diff --git a/packages/aiCore/src/clients/ApiClientFactory.ts b/packages/aiCore/src/clients/ApiClientFactory.ts index 820c182b08..045b2b3b56 100644 --- a/packages/aiCore/src/clients/ApiClientFactory.ts +++ b/packages/aiCore/src/clients/ApiClientFactory.ts @@ -3,7 +3,8 @@ * 整合现有实现的改进版API客户端工厂 */ -import type { LanguageModelV1 } from 'ai' +import type { ImageModelV1 } from '@ai-sdk/provider' +import { type LanguageModelV1, wrapLanguageModel } from 'ai' import { aiProviderRegistry } from '../providers/registry' import { ProviderOptions } from './types' @@ -11,7 +12,7 @@ import { ProviderOptions } from './types' // 客户端配置接口 export interface ClientConfig { providerId: string - options?: any + options?: ProviderOptions } // 错误类型 @@ -69,7 +70,17 @@ export class ApiClientFactory { // 返回模型实例 if (typeof provider === 'function') { - return provider(modelId) + let model = provider(modelId) + + // 应用 AI SDK 中间件 + if (providerConfig.aiSdkMiddlewares) { + model = wrapLanguageModel({ + model: model, + middleware: providerConfig.aiSdkMiddlewares + }) + } + + return model } else { throw new ClientFactoryError(`Unknown model access pattern for provider "${providerId}"`) } @@ -85,6 +96,54 @@ export class ApiClientFactory { } } + static async createImageClient( + providerId: string, + modelId: string = 'default', + options: ProviderOptions + ): Promise { + try { + if (!aiProviderRegistry.isSupported(providerId)) { + throw new ClientFactoryError(`Provider "${providerId}" is not supported`, providerId) + } + + const providerConfig = aiProviderRegistry.getProvider(providerId) + if (!providerConfig) { + throw new ClientFactoryError(`Provider "${providerId}" is not registered`, providerId) + } + + if (!providerConfig.supportsImageGeneration) { + throw new ClientFactoryError(`Provider "${providerId}" does not support image generation`, providerId) + } + + const module = await providerConfig.import() + + const creatorFunction = module[providerConfig.creatorFunctionName] + + if (typeof creatorFunction !== 'function') { + throw new ClientFactoryError( + `Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"` + ) + } + + const provider = creatorFunction(options) + + if (provider && typeof provider.image === 'function') { + return provider.image(modelId) + } else { + throw new ClientFactoryError(`Image model function not found for provider "${providerId}"`) + } + } catch (error) { + if (error instanceof ClientFactoryError) { + throw error + } + throw new ClientFactoryError( + `Failed to create image client for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`, + providerId, + error instanceof Error ? error : undefined + ) + } + } + /** * 获取支持的 Providers 列表 */ @@ -122,3 +181,6 @@ export const createClient = (providerId: string, modelId?: string, options?: any export const getSupportedProviders = () => ApiClientFactory.getSupportedProviders() export const getClientInfo = (providerId: string) => ApiClientFactory.getClientInfo(providerId) + +export const createImageClient = (providerId: string, modelId?: string, options?: any) => + ApiClientFactory.createImageClient(providerId, modelId, options) diff --git a/packages/aiCore/src/clients/UniversalAiSdkClient.ts b/packages/aiCore/src/clients/UniversalAiSdkClient.ts index eeeb76a331..ec85451f11 100644 --- a/packages/aiCore/src/clients/UniversalAiSdkClient.ts +++ b/packages/aiCore/src/clients/UniversalAiSdkClient.ts @@ -3,7 +3,7 @@ * 统一的AI SDK客户端实现 */ -import { generateObject, generateText, streamObject, streamText } from 'ai' +import { experimental_generateImage as generateImage, generateObject, generateText, streamObject, streamText } from 'ai' import { ApiClientFactory } from './ApiClientFactory' import { ProviderOptions } from './types' @@ -78,6 +78,17 @@ export class UniversalAiSdkClient { }) } + async generateImage( + modelId: string, + params: Omit[0], 'model'> + ): Promise> { + const model = await ApiClientFactory.createImageClient(this.providerId, modelId, this.options) + return generateImage({ + model, + ...params + }) + } + /** * 获取客户端信息 */ @@ -89,6 +100,6 @@ export class UniversalAiSdkClient { /** * 创建客户端实例的工厂函数 */ -export function createUniversalClient(providerId: string, options: any = {}): UniversalAiSdkClient { +export function createUniversalClient(providerId: string, options: ProviderOptions): UniversalAiSdkClient { return new UniversalAiSdkClient(providerId, options) } diff --git a/packages/aiCore/src/clients/types.ts b/packages/aiCore/src/clients/types.ts index 3957f9bcce..f080edb225 100644 --- a/packages/aiCore/src/clients/types.ts +++ b/packages/aiCore/src/clients/types.ts @@ -1,7 +1,6 @@ export type ProviderOptions = { - name: string apiKey?: string - apiHost: string + baseURL?: string apiVersion?: string headers?: Record } diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index 137d4d371c..59201cc913 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -5,6 +5,7 @@ // 导入内部使用的类和函数 import { ApiClientFactory } from './clients/ApiClientFactory' +import { ProviderOptions } from './clients/types' import { createUniversalClient } from './clients/UniversalAiSdkClient' import { aiProviderRegistry, isProviderSupported } from './providers/registry' @@ -61,19 +62,23 @@ export const AiCore = { } // 便捷的预配置clients创建函数 -export const createOpenAIClient = (options: { apiKey: string; baseURL?: string }) => { +export const createOpenAIClient = (options: ProviderOptions) => { return createUniversalClient('openai', options) } -export const createAnthropicClient = (options: { apiKey: string; baseURL?: string }) => { +export const createOpenAICompatibleClient = (options: ProviderOptions) => { + return createUniversalClient('openai-compatible', options) +} + +export const createAnthropicClient = (options: ProviderOptions) => { return createUniversalClient('anthropic', options) } -export const createGoogleClient = (options: { apiKey: string; baseURL?: string }) => { +export const createGoogleClient = (options: ProviderOptions) => { return createUniversalClient('google', options) } -export const createXAIClient = (options: { apiKey: string; baseURL?: string }) => { +export const createXAIClient = (options: ProviderOptions) => { return createUniversalClient('xai', options) } diff --git a/packages/aiCore/src/middleware/types.ts b/packages/aiCore/src/middleware/types.ts index ef075420a0..94594744d8 100644 --- a/packages/aiCore/src/middleware/types.ts +++ b/packages/aiCore/src/middleware/types.ts @@ -1,4 +1,4 @@ -import type { TextStreamPart, ToolSet } from 'ai' +import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai' /** * AI 请求上下文 @@ -37,6 +37,9 @@ export interface AiPlugin { tools?: TOOLS stopStream: () => void }) => TransformStream, TextStreamPart> + + // AI SDK 原生中间件 + aiSdkMiddlewares?: LanguageModelV1Middleware[] } /** diff --git a/packages/aiCore/src/providers/registry.ts b/packages/aiCore/src/providers/registry.ts index 0012622ecd..1330efa2da 100644 --- a/packages/aiCore/src/providers/registry.ts +++ b/packages/aiCore/src/providers/registry.ts @@ -1,3 +1,5 @@ +import type { LanguageModelV1Middleware } from 'ai' + /** * AI Provider 注册表 * 统一管理所有 AI SDK Providers 的动态导入和工厂函数 @@ -11,6 +13,10 @@ export interface ProviderConfig { import: () => Promise // 创建函数名称 creatorFunctionName: string + // 是否支持图片生成 + supportsImageGeneration?: boolean + // AI SDK 原生中间件 + aiSdkMiddlewares?: LanguageModelV1Middleware[] } /** @@ -43,7 +49,8 @@ export class AiProviderRegistry { id: 'openai', name: 'OpenAI', import: () => import('@ai-sdk/openai'), - creatorFunctionName: 'createOpenAI' + creatorFunctionName: 'createOpenAI', + supportsImageGeneration: true }, { id: 'openai-compatible', @@ -55,109 +62,127 @@ export class AiProviderRegistry { id: 'anthropic', name: 'Anthropic', import: () => import('@ai-sdk/anthropic'), - creatorFunctionName: 'createAnthropic' + creatorFunctionName: 'createAnthropic', + supportsImageGeneration: false }, { id: 'google', name: 'Google Generative AI', import: () => import('@ai-sdk/google'), - creatorFunctionName: 'createGoogleGenerativeAI' + creatorFunctionName: 'createGoogleGenerativeAI', + supportsImageGeneration: true }, { id: 'google-vertex', name: 'Google Vertex AI', import: () => import('@ai-sdk/google-vertex'), - creatorFunctionName: 'createVertex' + creatorFunctionName: 'createVertex', + supportsImageGeneration: true }, { id: 'mistral', name: 'Mistral AI', import: () => import('@ai-sdk/mistral'), - creatorFunctionName: 'createMistral' + creatorFunctionName: 'createMistral', + supportsImageGeneration: false }, { id: 'xai', name: 'xAI (Grok)', import: () => import('@ai-sdk/xai'), - creatorFunctionName: 'createXai' + creatorFunctionName: 'createXai', + supportsImageGeneration: true }, { id: 'azure', name: 'Azure OpenAI', import: () => import('@ai-sdk/azure'), - creatorFunctionName: 'createAzure' + creatorFunctionName: 'createAzure', + supportsImageGeneration: true }, { id: 'bedrock', name: 'Amazon Bedrock', import: () => import('@ai-sdk/amazon-bedrock'), - creatorFunctionName: 'createAmazonBedrock' + creatorFunctionName: 'createAmazonBedrock', + supportsImageGeneration: false }, { id: 'cohere', name: 'Cohere', import: () => import('@ai-sdk/cohere'), - creatorFunctionName: 'createCohere' + creatorFunctionName: 'createCohere', + supportsImageGeneration: false }, { id: 'groq', name: 'Groq', import: () => import('@ai-sdk/groq'), - creatorFunctionName: 'createGroq' + creatorFunctionName: 'createGroq', + supportsImageGeneration: false }, { id: 'together', name: 'Together.ai', import: () => import('@ai-sdk/togetherai'), - creatorFunctionName: 'createTogetherAI' + creatorFunctionName: 'createTogetherAI', + supportsImageGeneration: true }, { id: 'fireworks', name: 'Fireworks', import: () => import('@ai-sdk/fireworks'), - creatorFunctionName: 'createFireworks' + creatorFunctionName: 'createFireworks', + supportsImageGeneration: true }, { id: 'deepseek', name: 'DeepSeek', import: () => import('@ai-sdk/deepseek'), - creatorFunctionName: 'createDeepSeek' + creatorFunctionName: 'createDeepSeek', + supportsImageGeneration: false }, { id: 'cerebras', name: 'Cerebras', import: () => import('@ai-sdk/cerebras'), - creatorFunctionName: 'createCerebras' + creatorFunctionName: 'createCerebras', + supportsImageGeneration: false }, { id: 'deepinfra', name: 'DeepInfra', import: () => import('@ai-sdk/deepinfra'), - creatorFunctionName: 'createDeepInfra' + creatorFunctionName: 'createDeepInfra', + supportsImageGeneration: false }, { id: 'replicate', name: 'Replicate', import: () => import('@ai-sdk/replicate'), - creatorFunctionName: 'createReplicate' + creatorFunctionName: 'createReplicate', + supportsImageGeneration: true }, { id: 'perplexity', name: 'Perplexity', import: () => import('@ai-sdk/perplexity'), - creatorFunctionName: 'createPerplexity' + creatorFunctionName: 'createPerplexity', + supportsImageGeneration: false }, { id: 'fal', name: 'Fal AI', import: () => import('@ai-sdk/fal'), - creatorFunctionName: 'createFal' + creatorFunctionName: 'createFal', + supportsImageGeneration: false }, { id: 'vercel', name: 'Vercel', import: () => import('@ai-sdk/vercel'), - creatorFunctionName: 'createVercel' + creatorFunctionName: 'createVercel', + supportsImageGeneration: false } ] @@ -167,31 +192,36 @@ export class AiProviderRegistry { id: 'ollama', name: 'Ollama', import: () => import('ollama-ai-provider'), - creatorFunctionName: 'createOllama' + creatorFunctionName: 'createOllama', + supportsImageGeneration: false }, { id: 'qwen', name: 'Qwen', import: () => import('qwen-ai-provider'), - creatorFunctionName: 'createQwen' + creatorFunctionName: 'createQwen', + supportsImageGeneration: false }, { id: 'zhipu', name: 'Zhipu AI', import: () => import('zhipu-ai-provider'), - creatorFunctionName: 'createZhipu' + creatorFunctionName: 'createZhipu', + supportsImageGeneration: false }, { id: 'anthropic-vertex', name: 'Anthropic Vertex AI', import: () => import('anthropic-vertex-ai'), - creatorFunctionName: 'createAnthropicVertex' + creatorFunctionName: 'createAnthropicVertex', + supportsImageGeneration: false }, { id: 'openrouter', name: 'OpenRouter', import: () => import('@openrouter/ai-sdk-provider'), - creatorFunctionName: 'createOpenRouter' + creatorFunctionName: 'createOpenRouter', + supportsImageGeneration: false } ]