mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-04 20:00:00 +08:00
feat: enhance AI Core with image generation capabilities
- Introduced `createImageClient` method in `ApiClientFactory` to support image generation for various providers. - Updated `UniversalAiSdkClient` to include `generateImage` method, allowing image generation through the unified client interface. - Refactored client creation functions to utilize the new `ProviderOptions` type for improved type safety. - Enhanced the provider registry to indicate which providers support image generation, streamlining client creation and usage. - Updated type definitions in `types.ts` to reflect changes in client options and middleware support.
This commit is contained in:
parent
ed2363e561
commit
2df1cddb43
@ -3,14 +3,15 @@
|
|||||||
* 整合现有实现的改进版API客户端工厂
|
* 整合现有实现的改进版API客户端工厂
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import type { LanguageModelV1 } from 'ai'
|
import type { ImageModelV1 } from '@ai-sdk/provider'
|
||||||
|
import { type LanguageModelV1, wrapLanguageModel } from 'ai'
|
||||||
|
|
||||||
import { aiProviderRegistry } from '../providers/registry'
|
import { aiProviderRegistry } from '../providers/registry'
|
||||||
|
|
||||||
// 客户端配置接口
|
// 客户端配置接口
|
||||||
export interface ClientConfig {
|
export interface ClientConfig {
|
||||||
providerId: string
|
providerId: string
|
||||||
options?: any
|
options?: ProviderOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
// 错误类型
|
// 错误类型
|
||||||
@ -68,7 +69,17 @@ export class ApiClientFactory {
|
|||||||
|
|
||||||
// 返回模型实例
|
// 返回模型实例
|
||||||
if (typeof provider === 'function') {
|
if (typeof provider === 'function') {
|
||||||
return provider(modelId)
|
let model = provider(modelId)
|
||||||
|
|
||||||
|
// 应用 AI SDK 中间件
|
||||||
|
if (providerConfig.aiSdkMiddlewares) {
|
||||||
|
model = wrapLanguageModel({
|
||||||
|
model: model,
|
||||||
|
middleware: providerConfig.aiSdkMiddlewares
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return model
|
||||||
} else {
|
} else {
|
||||||
throw new ClientFactoryError(`Unknown model access pattern for provider "${providerId}"`)
|
throw new ClientFactoryError(`Unknown model access pattern for provider "${providerId}"`)
|
||||||
}
|
}
|
||||||
@ -84,6 +95,54 @@ export class ApiClientFactory {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static async createImageClient(
|
||||||
|
providerId: string,
|
||||||
|
modelId: string = 'default',
|
||||||
|
options: ProviderOptions
|
||||||
|
): Promise<ImageModelV1> {
|
||||||
|
try {
|
||||||
|
if (!aiProviderRegistry.isSupported(providerId)) {
|
||||||
|
throw new ClientFactoryError(`Provider "${providerId}" is not supported`, providerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
const providerConfig = aiProviderRegistry.getProvider(providerId)
|
||||||
|
if (!providerConfig) {
|
||||||
|
throw new ClientFactoryError(`Provider "${providerId}" is not registered`, providerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!providerConfig.supportsImageGeneration) {
|
||||||
|
throw new ClientFactoryError(`Provider "${providerId}" does not support image generation`, providerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
const module = await providerConfig.import()
|
||||||
|
|
||||||
|
const creatorFunction = module[providerConfig.creatorFunctionName]
|
||||||
|
|
||||||
|
if (typeof creatorFunction !== 'function') {
|
||||||
|
throw new ClientFactoryError(
|
||||||
|
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const provider = creatorFunction(options)
|
||||||
|
|
||||||
|
if (provider && typeof provider.image === 'function') {
|
||||||
|
return provider.image(modelId)
|
||||||
|
} else {
|
||||||
|
throw new ClientFactoryError(`Image model function not found for provider "${providerId}"`)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof ClientFactoryError) {
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
throw new ClientFactoryError(
|
||||||
|
`Failed to create image client for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||||
|
providerId,
|
||||||
|
error instanceof Error ? error : undefined
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取支持的 Providers 列表
|
* 获取支持的 Providers 列表
|
||||||
*/
|
*/
|
||||||
@ -121,3 +180,6 @@ export const createClient = (providerId: string, modelId?: string, options?: any
|
|||||||
export const getSupportedProviders = () => ApiClientFactory.getSupportedProviders()
|
export const getSupportedProviders = () => ApiClientFactory.getSupportedProviders()
|
||||||
|
|
||||||
export const getClientInfo = (providerId: string) => ApiClientFactory.getClientInfo(providerId)
|
export const getClientInfo = (providerId: string) => ApiClientFactory.getClientInfo(providerId)
|
||||||
|
|
||||||
|
export const createImageClient = (providerId: string, modelId?: string, options?: any) =>
|
||||||
|
ApiClientFactory.createImageClient(providerId, modelId, options)
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
* 统一的AI SDK客户端实现
|
* 统一的AI SDK客户端实现
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
import { experimental_generateImage as generateImage, generateObject, generateText, streamObject, streamText } from 'ai'
|
||||||
|
|
||||||
import { ApiClientFactory } from './ApiClientFactory'
|
import { ApiClientFactory } from './ApiClientFactory'
|
||||||
|
|
||||||
@ -65,6 +65,17 @@ export class UniversalAiSdkClient {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async generateImage(
|
||||||
|
modelId: string,
|
||||||
|
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||||
|
): Promise<ReturnType<typeof generateImage>> {
|
||||||
|
const model = await ApiClientFactory.createImageClient(this.providerId, modelId, this.options)
|
||||||
|
return generateImage({
|
||||||
|
model,
|
||||||
|
...params
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取客户端信息
|
* 获取客户端信息
|
||||||
*/
|
*/
|
||||||
@ -76,6 +87,6 @@ export class UniversalAiSdkClient {
|
|||||||
/**
|
/**
|
||||||
* 创建客户端实例的工厂函数
|
* 创建客户端实例的工厂函数
|
||||||
*/
|
*/
|
||||||
export function createUniversalClient(providerId: string, options: any = {}): UniversalAiSdkClient {
|
export function createUniversalClient(providerId: string, options: ProviderOptions): UniversalAiSdkClient {
|
||||||
return new UniversalAiSdkClient(providerId, options)
|
return new UniversalAiSdkClient(providerId, options)
|
||||||
}
|
}
|
||||||
|
|||||||
6
packages/aiCore/src/clients/types.ts
Normal file
6
packages/aiCore/src/clients/types.ts
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
export type ProviderOptions = {
|
||||||
|
apiKey?: string
|
||||||
|
baseURL?: string
|
||||||
|
apiVersion?: string
|
||||||
|
headers?: Record<string, string | unknown>
|
||||||
|
}
|
||||||
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
// 导入内部使用的类和函数
|
// 导入内部使用的类和函数
|
||||||
import { ApiClientFactory } from './clients/ApiClientFactory'
|
import { ApiClientFactory } from './clients/ApiClientFactory'
|
||||||
|
import { ProviderOptions } from './clients/types'
|
||||||
import { createUniversalClient } from './clients/UniversalAiSdkClient'
|
import { createUniversalClient } from './clients/UniversalAiSdkClient'
|
||||||
import { aiProviderRegistry, isProviderSupported } from './providers/registry'
|
import { aiProviderRegistry, isProviderSupported } from './providers/registry'
|
||||||
|
|
||||||
@ -61,19 +62,23 @@ export const AiCore = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 便捷的预配置clients创建函数
|
// 便捷的预配置clients创建函数
|
||||||
export const createOpenAIClient = (options: { apiKey: string; baseURL?: string }) => {
|
export const createOpenAIClient = (options: ProviderOptions) => {
|
||||||
return createUniversalClient('openai', options)
|
return createUniversalClient('openai', options)
|
||||||
}
|
}
|
||||||
|
|
||||||
export const createAnthropicClient = (options: { apiKey: string; baseURL?: string }) => {
|
export const createOpenAICompatibleClient = (options: ProviderOptions) => {
|
||||||
|
return createUniversalClient('openai-compatible', options)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const createAnthropicClient = (options: ProviderOptions) => {
|
||||||
return createUniversalClient('anthropic', options)
|
return createUniversalClient('anthropic', options)
|
||||||
}
|
}
|
||||||
|
|
||||||
export const createGoogleClient = (options: { apiKey: string; baseURL?: string }) => {
|
export const createGoogleClient = (options: ProviderOptions) => {
|
||||||
return createUniversalClient('google', options)
|
return createUniversalClient('google', options)
|
||||||
}
|
}
|
||||||
|
|
||||||
export const createXAIClient = (options: { apiKey: string; baseURL?: string }) => {
|
export const createXAIClient = (options: ProviderOptions) => {
|
||||||
return createUniversalClient('xai', options)
|
return createUniversalClient('xai', options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import type { TextStreamPart, ToolSet } from 'ai'
|
import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI 请求上下文
|
* AI 请求上下文
|
||||||
@ -37,6 +37,9 @@ export interface AiPlugin {
|
|||||||
tools?: TOOLS
|
tools?: TOOLS
|
||||||
stopStream: () => void
|
stopStream: () => void
|
||||||
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
|
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
|
||||||
|
|
||||||
|
// AI SDK 原生中间件
|
||||||
|
aiSdkMiddlewares?: LanguageModelV1Middleware[]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import type { LanguageModelV1Middleware } from 'ai'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI Provider 注册表
|
* AI Provider 注册表
|
||||||
* 统一管理所有 AI SDK Providers 的动态导入和工厂函数
|
* 统一管理所有 AI SDK Providers 的动态导入和工厂函数
|
||||||
@ -11,6 +13,10 @@ export interface ProviderConfig {
|
|||||||
import: () => Promise<any>
|
import: () => Promise<any>
|
||||||
// 创建函数名称
|
// 创建函数名称
|
||||||
creatorFunctionName: string
|
creatorFunctionName: string
|
||||||
|
// 是否支持图片生成
|
||||||
|
supportsImageGeneration?: boolean
|
||||||
|
// AI SDK 原生中间件
|
||||||
|
aiSdkMiddlewares?: LanguageModelV1Middleware[]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -43,115 +49,134 @@ export class AiProviderRegistry {
|
|||||||
id: 'openai',
|
id: 'openai',
|
||||||
name: 'OpenAI',
|
name: 'OpenAI',
|
||||||
import: () => import('@ai-sdk/openai'),
|
import: () => import('@ai-sdk/openai'),
|
||||||
creatorFunctionName: 'createOpenAI'
|
creatorFunctionName: 'createOpenAI',
|
||||||
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'anthropic',
|
id: 'anthropic',
|
||||||
name: 'Anthropic',
|
name: 'Anthropic',
|
||||||
import: () => import('@ai-sdk/anthropic'),
|
import: () => import('@ai-sdk/anthropic'),
|
||||||
creatorFunctionName: 'createAnthropic'
|
creatorFunctionName: 'createAnthropic',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'google',
|
id: 'google',
|
||||||
name: 'Google Generative AI',
|
name: 'Google Generative AI',
|
||||||
import: () => import('@ai-sdk/google'),
|
import: () => import('@ai-sdk/google'),
|
||||||
creatorFunctionName: 'createGoogleGenerativeAI'
|
creatorFunctionName: 'createGoogleGenerativeAI',
|
||||||
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'google-vertex',
|
id: 'google-vertex',
|
||||||
name: 'Google Vertex AI',
|
name: 'Google Vertex AI',
|
||||||
import: () => import('@ai-sdk/google-vertex'),
|
import: () => import('@ai-sdk/google-vertex'),
|
||||||
creatorFunctionName: 'createVertex'
|
creatorFunctionName: 'createVertex',
|
||||||
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'mistral',
|
id: 'mistral',
|
||||||
name: 'Mistral AI',
|
name: 'Mistral AI',
|
||||||
import: () => import('@ai-sdk/mistral'),
|
import: () => import('@ai-sdk/mistral'),
|
||||||
creatorFunctionName: 'createMistral'
|
creatorFunctionName: 'createMistral',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'xai',
|
id: 'xai',
|
||||||
name: 'xAI (Grok)',
|
name: 'xAI (Grok)',
|
||||||
import: () => import('@ai-sdk/xai'),
|
import: () => import('@ai-sdk/xai'),
|
||||||
creatorFunctionName: 'createXai'
|
creatorFunctionName: 'createXai',
|
||||||
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'azure',
|
id: 'azure',
|
||||||
name: 'Azure OpenAI',
|
name: 'Azure OpenAI',
|
||||||
import: () => import('@ai-sdk/azure'),
|
import: () => import('@ai-sdk/azure'),
|
||||||
creatorFunctionName: 'createAzure'
|
creatorFunctionName: 'createAzure',
|
||||||
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'bedrock',
|
id: 'bedrock',
|
||||||
name: 'Amazon Bedrock',
|
name: 'Amazon Bedrock',
|
||||||
import: () => import('@ai-sdk/amazon-bedrock'),
|
import: () => import('@ai-sdk/amazon-bedrock'),
|
||||||
creatorFunctionName: 'createAmazonBedrock'
|
creatorFunctionName: 'createAmazonBedrock',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'cohere',
|
id: 'cohere',
|
||||||
name: 'Cohere',
|
name: 'Cohere',
|
||||||
import: () => import('@ai-sdk/cohere'),
|
import: () => import('@ai-sdk/cohere'),
|
||||||
creatorFunctionName: 'createCohere'
|
creatorFunctionName: 'createCohere',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'groq',
|
id: 'groq',
|
||||||
name: 'Groq',
|
name: 'Groq',
|
||||||
import: () => import('@ai-sdk/groq'),
|
import: () => import('@ai-sdk/groq'),
|
||||||
creatorFunctionName: 'createGroq'
|
creatorFunctionName: 'createGroq',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'together',
|
id: 'together',
|
||||||
name: 'Together.ai',
|
name: 'Together.ai',
|
||||||
import: () => import('@ai-sdk/togetherai'),
|
import: () => import('@ai-sdk/togetherai'),
|
||||||
creatorFunctionName: 'createTogetherAI'
|
creatorFunctionName: 'createTogetherAI',
|
||||||
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'fireworks',
|
id: 'fireworks',
|
||||||
name: 'Fireworks',
|
name: 'Fireworks',
|
||||||
import: () => import('@ai-sdk/fireworks'),
|
import: () => import('@ai-sdk/fireworks'),
|
||||||
creatorFunctionName: 'createFireworks'
|
creatorFunctionName: 'createFireworks',
|
||||||
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'deepseek',
|
id: 'deepseek',
|
||||||
name: 'DeepSeek',
|
name: 'DeepSeek',
|
||||||
import: () => import('@ai-sdk/deepseek'),
|
import: () => import('@ai-sdk/deepseek'),
|
||||||
creatorFunctionName: 'createDeepSeek'
|
creatorFunctionName: 'createDeepSeek',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'cerebras',
|
id: 'cerebras',
|
||||||
name: 'Cerebras',
|
name: 'Cerebras',
|
||||||
import: () => import('@ai-sdk/cerebras'),
|
import: () => import('@ai-sdk/cerebras'),
|
||||||
creatorFunctionName: 'createCerebras'
|
creatorFunctionName: 'createCerebras',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'deepinfra',
|
id: 'deepinfra',
|
||||||
name: 'DeepInfra',
|
name: 'DeepInfra',
|
||||||
import: () => import('@ai-sdk/deepinfra'),
|
import: () => import('@ai-sdk/deepinfra'),
|
||||||
creatorFunctionName: 'createDeepInfra'
|
creatorFunctionName: 'createDeepInfra',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'replicate',
|
id: 'replicate',
|
||||||
name: 'Replicate',
|
name: 'Replicate',
|
||||||
import: () => import('@ai-sdk/replicate'),
|
import: () => import('@ai-sdk/replicate'),
|
||||||
creatorFunctionName: 'createReplicate'
|
creatorFunctionName: 'createReplicate',
|
||||||
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'perplexity',
|
id: 'perplexity',
|
||||||
name: 'Perplexity',
|
name: 'Perplexity',
|
||||||
import: () => import('@ai-sdk/perplexity'),
|
import: () => import('@ai-sdk/perplexity'),
|
||||||
creatorFunctionName: 'createPerplexity'
|
creatorFunctionName: 'createPerplexity',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'fal',
|
id: 'fal',
|
||||||
name: 'Fal AI',
|
name: 'Fal AI',
|
||||||
import: () => import('@ai-sdk/fal'),
|
import: () => import('@ai-sdk/fal'),
|
||||||
creatorFunctionName: 'createFal'
|
creatorFunctionName: 'createFal',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'vercel',
|
id: 'vercel',
|
||||||
name: 'Vercel',
|
name: 'Vercel',
|
||||||
import: () => import('@ai-sdk/vercel'),
|
import: () => import('@ai-sdk/vercel'),
|
||||||
creatorFunctionName: 'createVercel'
|
creatorFunctionName: 'createVercel',
|
||||||
|
supportsImageGeneration: false
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -161,25 +186,36 @@ export class AiProviderRegistry {
|
|||||||
id: 'ollama',
|
id: 'ollama',
|
||||||
name: 'Ollama',
|
name: 'Ollama',
|
||||||
import: () => import('ollama-ai-provider'),
|
import: () => import('ollama-ai-provider'),
|
||||||
creatorFunctionName: 'createOllama'
|
creatorFunctionName: 'createOllama',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'qwen',
|
id: 'qwen',
|
||||||
name: 'Qwen',
|
name: 'Qwen',
|
||||||
import: () => import('qwen-ai-provider'),
|
import: () => import('qwen-ai-provider'),
|
||||||
creatorFunctionName: 'createQwen'
|
creatorFunctionName: 'createQwen',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'zhipu',
|
id: 'zhipu',
|
||||||
name: 'Zhipu AI',
|
name: 'Zhipu AI',
|
||||||
import: () => import('zhipu-ai-provider'),
|
import: () => import('zhipu-ai-provider'),
|
||||||
creatorFunctionName: 'createZhipu'
|
creatorFunctionName: 'createZhipu',
|
||||||
|
supportsImageGeneration: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'anthropic-vertex',
|
id: 'anthropic-vertex',
|
||||||
name: 'Anthropic Vertex AI',
|
name: 'Anthropic Vertex AI',
|
||||||
import: () => import('anthropic-vertex-ai'),
|
import: () => import('anthropic-vertex-ai'),
|
||||||
creatorFunctionName: 'createAnthropicVertex'
|
creatorFunctionName: 'createAnthropicVertex',
|
||||||
|
supportsImageGeneration: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'openrouter',
|
||||||
|
name: 'OpenRouter',
|
||||||
|
import: () => import('@openrouter/ai-sdk-provider'),
|
||||||
|
creatorFunctionName: 'createOpenRouter',
|
||||||
|
supportsImageGeneration: false
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user