mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 13:31:32 +08:00
feat: enhance AI Core with image generation capabilities
- Introduced `createImageClient` method in `ApiClientFactory` to support image generation for various providers. - Updated `UniversalAiSdkClient` to include `generateImage` method, allowing image generation through the unified client interface. - Refactored client creation functions to utilize the new `ProviderOptions` type for improved type safety. - Enhanced the provider registry to indicate which providers support image generation, streamlining client creation and usage. - Updated type definitions in `types.ts` to reflect changes in client options and middleware support.
This commit is contained in:
parent
8910281b09
commit
0f22fa18c3
@ -3,7 +3,8 @@
|
||||
* 整合现有实现的改进版API客户端工厂
|
||||
*/
|
||||
|
||||
import type { LanguageModelV1 } from 'ai'
|
||||
import type { ImageModelV1 } from '@ai-sdk/provider'
|
||||
import { type LanguageModelV1, wrapLanguageModel } from 'ai'
|
||||
|
||||
import { aiProviderRegistry } from '../providers/registry'
|
||||
import { ProviderOptions } from './types'
|
||||
@ -11,7 +12,7 @@ import { ProviderOptions } from './types'
|
||||
// 客户端配置接口
|
||||
export interface ClientConfig {
|
||||
providerId: string
|
||||
options?: any
|
||||
options?: ProviderOptions
|
||||
}
|
||||
|
||||
// 错误类型
|
||||
@ -69,7 +70,17 @@ export class ApiClientFactory {
|
||||
|
||||
// 返回模型实例
|
||||
if (typeof provider === 'function') {
|
||||
return provider(modelId)
|
||||
let model = provider(modelId)
|
||||
|
||||
// 应用 AI SDK 中间件
|
||||
if (providerConfig.aiSdkMiddlewares) {
|
||||
model = wrapLanguageModel({
|
||||
model: model,
|
||||
middleware: providerConfig.aiSdkMiddlewares
|
||||
})
|
||||
}
|
||||
|
||||
return model
|
||||
} else {
|
||||
throw new ClientFactoryError(`Unknown model access pattern for provider "${providerId}"`)
|
||||
}
|
||||
@ -85,6 +96,54 @@ export class ApiClientFactory {
|
||||
}
|
||||
}
|
||||
|
||||
static async createImageClient(
|
||||
providerId: string,
|
||||
modelId: string = 'default',
|
||||
options: ProviderOptions
|
||||
): Promise<ImageModelV1> {
|
||||
try {
|
||||
if (!aiProviderRegistry.isSupported(providerId)) {
|
||||
throw new ClientFactoryError(`Provider "${providerId}" is not supported`, providerId)
|
||||
}
|
||||
|
||||
const providerConfig = aiProviderRegistry.getProvider(providerId)
|
||||
if (!providerConfig) {
|
||||
throw new ClientFactoryError(`Provider "${providerId}" is not registered`, providerId)
|
||||
}
|
||||
|
||||
if (!providerConfig.supportsImageGeneration) {
|
||||
throw new ClientFactoryError(`Provider "${providerId}" does not support image generation`, providerId)
|
||||
}
|
||||
|
||||
const module = await providerConfig.import()
|
||||
|
||||
const creatorFunction = module[providerConfig.creatorFunctionName]
|
||||
|
||||
if (typeof creatorFunction !== 'function') {
|
||||
throw new ClientFactoryError(
|
||||
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"`
|
||||
)
|
||||
}
|
||||
|
||||
const provider = creatorFunction(options)
|
||||
|
||||
if (provider && typeof provider.image === 'function') {
|
||||
return provider.image(modelId)
|
||||
} else {
|
||||
throw new ClientFactoryError(`Image model function not found for provider "${providerId}"`)
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof ClientFactoryError) {
|
||||
throw error
|
||||
}
|
||||
throw new ClientFactoryError(
|
||||
`Failed to create image client for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
providerId,
|
||||
error instanceof Error ? error : undefined
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取支持的 Providers 列表
|
||||
*/
|
||||
@ -122,3 +181,6 @@ export const createClient = (providerId: string, modelId?: string, options?: any
|
||||
export const getSupportedProviders = () => ApiClientFactory.getSupportedProviders()
|
||||
|
||||
export const getClientInfo = (providerId: string) => ApiClientFactory.getClientInfo(providerId)
|
||||
|
||||
export const createImageClient = (providerId: string, modelId?: string, options?: any) =>
|
||||
ApiClientFactory.createImageClient(providerId, modelId, options)
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
* 统一的AI SDK客户端实现
|
||||
*/
|
||||
|
||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
import { experimental_generateImage as generateImage, generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
|
||||
import { ApiClientFactory } from './ApiClientFactory'
|
||||
import { ProviderOptions } from './types'
|
||||
@ -78,6 +78,17 @@ export class UniversalAiSdkClient {
|
||||
})
|
||||
}
|
||||
|
||||
async generateImage(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateImage>> {
|
||||
const model = await ApiClientFactory.createImageClient(this.providerId, modelId, this.options)
|
||||
return generateImage({
|
||||
model,
|
||||
...params
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取客户端信息
|
||||
*/
|
||||
@ -89,6 +100,6 @@ export class UniversalAiSdkClient {
|
||||
/**
|
||||
* 创建客户端实例的工厂函数
|
||||
*/
|
||||
export function createUniversalClient(providerId: string, options: any = {}): UniversalAiSdkClient {
|
||||
export function createUniversalClient(providerId: string, options: ProviderOptions): UniversalAiSdkClient {
|
||||
return new UniversalAiSdkClient(providerId, options)
|
||||
}
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
export type ProviderOptions = {
|
||||
name: string
|
||||
apiKey?: string
|
||||
apiHost: string
|
||||
baseURL?: string
|
||||
apiVersion?: string
|
||||
headers?: Record<string, string | unknown>
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
|
||||
// 导入内部使用的类和函数
|
||||
import { ApiClientFactory } from './clients/ApiClientFactory'
|
||||
import { ProviderOptions } from './clients/types'
|
||||
import { createUniversalClient } from './clients/UniversalAiSdkClient'
|
||||
import { aiProviderRegistry, isProviderSupported } from './providers/registry'
|
||||
|
||||
@ -61,19 +62,23 @@ export const AiCore = {
|
||||
}
|
||||
|
||||
// 便捷的预配置clients创建函数
|
||||
export const createOpenAIClient = (options: { apiKey: string; baseURL?: string }) => {
|
||||
export const createOpenAIClient = (options: ProviderOptions) => {
|
||||
return createUniversalClient('openai', options)
|
||||
}
|
||||
|
||||
export const createAnthropicClient = (options: { apiKey: string; baseURL?: string }) => {
|
||||
export const createOpenAICompatibleClient = (options: ProviderOptions) => {
|
||||
return createUniversalClient('openai-compatible', options)
|
||||
}
|
||||
|
||||
export const createAnthropicClient = (options: ProviderOptions) => {
|
||||
return createUniversalClient('anthropic', options)
|
||||
}
|
||||
|
||||
export const createGoogleClient = (options: { apiKey: string; baseURL?: string }) => {
|
||||
export const createGoogleClient = (options: ProviderOptions) => {
|
||||
return createUniversalClient('google', options)
|
||||
}
|
||||
|
||||
export const createXAIClient = (options: { apiKey: string; baseURL?: string }) => {
|
||||
export const createXAIClient = (options: ProviderOptions) => {
|
||||
return createUniversalClient('xai', options)
|
||||
}
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
/**
|
||||
* AI 请求上下文
|
||||
@ -37,6 +37,9 @@ export interface AiPlugin {
|
||||
tools?: TOOLS
|
||||
stopStream: () => void
|
||||
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
|
||||
|
||||
// AI SDK 原生中间件
|
||||
aiSdkMiddlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import type { LanguageModelV1Middleware } from 'ai'
|
||||
|
||||
/**
|
||||
* AI Provider 注册表
|
||||
* 统一管理所有 AI SDK Providers 的动态导入和工厂函数
|
||||
@ -11,6 +13,10 @@ export interface ProviderConfig {
|
||||
import: () => Promise<any>
|
||||
// 创建函数名称
|
||||
creatorFunctionName: string
|
||||
// 是否支持图片生成
|
||||
supportsImageGeneration?: boolean
|
||||
// AI SDK 原生中间件
|
||||
aiSdkMiddlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
/**
|
||||
@ -43,7 +49,8 @@ export class AiProviderRegistry {
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
import: () => import('@ai-sdk/openai'),
|
||||
creatorFunctionName: 'createOpenAI'
|
||||
creatorFunctionName: 'createOpenAI',
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-compatible',
|
||||
@ -55,109 +62,127 @@ export class AiProviderRegistry {
|
||||
id: 'anthropic',
|
||||
name: 'Anthropic',
|
||||
import: () => import('@ai-sdk/anthropic'),
|
||||
creatorFunctionName: 'createAnthropic'
|
||||
creatorFunctionName: 'createAnthropic',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'google',
|
||||
name: 'Google Generative AI',
|
||||
import: () => import('@ai-sdk/google'),
|
||||
creatorFunctionName: 'createGoogleGenerativeAI'
|
||||
creatorFunctionName: 'createGoogleGenerativeAI',
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'google-vertex',
|
||||
name: 'Google Vertex AI',
|
||||
import: () => import('@ai-sdk/google-vertex'),
|
||||
creatorFunctionName: 'createVertex'
|
||||
creatorFunctionName: 'createVertex',
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'mistral',
|
||||
name: 'Mistral AI',
|
||||
import: () => import('@ai-sdk/mistral'),
|
||||
creatorFunctionName: 'createMistral'
|
||||
creatorFunctionName: 'createMistral',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'xai',
|
||||
name: 'xAI (Grok)',
|
||||
import: () => import('@ai-sdk/xai'),
|
||||
creatorFunctionName: 'createXai'
|
||||
creatorFunctionName: 'createXai',
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'azure',
|
||||
name: 'Azure OpenAI',
|
||||
import: () => import('@ai-sdk/azure'),
|
||||
creatorFunctionName: 'createAzure'
|
||||
creatorFunctionName: 'createAzure',
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'bedrock',
|
||||
name: 'Amazon Bedrock',
|
||||
import: () => import('@ai-sdk/amazon-bedrock'),
|
||||
creatorFunctionName: 'createAmazonBedrock'
|
||||
creatorFunctionName: 'createAmazonBedrock',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'cohere',
|
||||
name: 'Cohere',
|
||||
import: () => import('@ai-sdk/cohere'),
|
||||
creatorFunctionName: 'createCohere'
|
||||
creatorFunctionName: 'createCohere',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'groq',
|
||||
name: 'Groq',
|
||||
import: () => import('@ai-sdk/groq'),
|
||||
creatorFunctionName: 'createGroq'
|
||||
creatorFunctionName: 'createGroq',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'together',
|
||||
name: 'Together.ai',
|
||||
import: () => import('@ai-sdk/togetherai'),
|
||||
creatorFunctionName: 'createTogetherAI'
|
||||
creatorFunctionName: 'createTogetherAI',
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'fireworks',
|
||||
name: 'Fireworks',
|
||||
import: () => import('@ai-sdk/fireworks'),
|
||||
creatorFunctionName: 'createFireworks'
|
||||
creatorFunctionName: 'createFireworks',
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'deepseek',
|
||||
name: 'DeepSeek',
|
||||
import: () => import('@ai-sdk/deepseek'),
|
||||
creatorFunctionName: 'createDeepSeek'
|
||||
creatorFunctionName: 'createDeepSeek',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'cerebras',
|
||||
name: 'Cerebras',
|
||||
import: () => import('@ai-sdk/cerebras'),
|
||||
creatorFunctionName: 'createCerebras'
|
||||
creatorFunctionName: 'createCerebras',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'deepinfra',
|
||||
name: 'DeepInfra',
|
||||
import: () => import('@ai-sdk/deepinfra'),
|
||||
creatorFunctionName: 'createDeepInfra'
|
||||
creatorFunctionName: 'createDeepInfra',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'replicate',
|
||||
name: 'Replicate',
|
||||
import: () => import('@ai-sdk/replicate'),
|
||||
creatorFunctionName: 'createReplicate'
|
||||
creatorFunctionName: 'createReplicate',
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'perplexity',
|
||||
name: 'Perplexity',
|
||||
import: () => import('@ai-sdk/perplexity'),
|
||||
creatorFunctionName: 'createPerplexity'
|
||||
creatorFunctionName: 'createPerplexity',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'fal',
|
||||
name: 'Fal AI',
|
||||
import: () => import('@ai-sdk/fal'),
|
||||
creatorFunctionName: 'createFal'
|
||||
creatorFunctionName: 'createFal',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'vercel',
|
||||
name: 'Vercel',
|
||||
import: () => import('@ai-sdk/vercel'),
|
||||
creatorFunctionName: 'createVercel'
|
||||
creatorFunctionName: 'createVercel',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
]
|
||||
|
||||
@ -167,31 +192,36 @@ export class AiProviderRegistry {
|
||||
id: 'ollama',
|
||||
name: 'Ollama',
|
||||
import: () => import('ollama-ai-provider'),
|
||||
creatorFunctionName: 'createOllama'
|
||||
creatorFunctionName: 'createOllama',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'qwen',
|
||||
name: 'Qwen',
|
||||
import: () => import('qwen-ai-provider'),
|
||||
creatorFunctionName: 'createQwen'
|
||||
creatorFunctionName: 'createQwen',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'zhipu',
|
||||
name: 'Zhipu AI',
|
||||
import: () => import('zhipu-ai-provider'),
|
||||
creatorFunctionName: 'createZhipu'
|
||||
creatorFunctionName: 'createZhipu',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'anthropic-vertex',
|
||||
name: 'Anthropic Vertex AI',
|
||||
import: () => import('anthropic-vertex-ai'),
|
||||
creatorFunctionName: 'createAnthropicVertex'
|
||||
creatorFunctionName: 'createAnthropicVertex',
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'openrouter',
|
||||
name: 'OpenRouter',
|
||||
import: () => import('@openrouter/ai-sdk-provider'),
|
||||
creatorFunctionName: 'createOpenRouter'
|
||||
creatorFunctionName: 'createOpenRouter',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user