feat: enhance AI Core with image generation capabilities

- Introduced `createImageClient` method in `ApiClientFactory` to support image generation for various providers.
- Updated `UniversalAiSdkClient` to include `generateImage` method, allowing image generation through the unified client interface.
- Refactored client creation functions to utilize the new `ProviderOptions` type for improved type safety.
- Enhanced the provider registry to indicate which providers support image generation, streamlining client creation and usage.
- Updated type definitions in `types.ts` to reflect changes in client options and middleware support.
This commit is contained in:
suyao 2025-06-17 22:25:33 +08:00 committed by MyPrototypeWhat
parent ed2363e561
commit 2df1cddb43
6 changed files with 156 additions and 33 deletions

View File

@ -3,14 +3,15 @@
* API客户端工厂 * API客户端工厂
*/ */
import type { LanguageModelV1 } from 'ai' import type { ImageModelV1 } from '@ai-sdk/provider'
import { type LanguageModelV1, wrapLanguageModel } from 'ai'
import { aiProviderRegistry } from '../providers/registry' import { aiProviderRegistry } from '../providers/registry'
// 客户端配置接口 // 客户端配置接口
export interface ClientConfig { export interface ClientConfig {
providerId: string providerId: string
options?: any options?: ProviderOptions
} }
// 错误类型 // 错误类型
@ -68,7 +69,17 @@ export class ApiClientFactory {
// 返回模型实例 // 返回模型实例
if (typeof provider === 'function') { if (typeof provider === 'function') {
return provider(modelId) let model = provider(modelId)
// 应用 AI SDK 中间件
if (providerConfig.aiSdkMiddlewares) {
model = wrapLanguageModel({
model: model,
middleware: providerConfig.aiSdkMiddlewares
})
}
return model
} else { } else {
throw new ClientFactoryError(`Unknown model access pattern for provider "${providerId}"`) throw new ClientFactoryError(`Unknown model access pattern for provider "${providerId}"`)
} }
@ -84,6 +95,54 @@ export class ApiClientFactory {
} }
} }
static async createImageClient(
providerId: string,
modelId: string = 'default',
options: ProviderOptions
): Promise<ImageModelV1> {
try {
if (!aiProviderRegistry.isSupported(providerId)) {
throw new ClientFactoryError(`Provider "${providerId}" is not supported`, providerId)
}
const providerConfig = aiProviderRegistry.getProvider(providerId)
if (!providerConfig) {
throw new ClientFactoryError(`Provider "${providerId}" is not registered`, providerId)
}
if (!providerConfig.supportsImageGeneration) {
throw new ClientFactoryError(`Provider "${providerId}" does not support image generation`, providerId)
}
const module = await providerConfig.import()
const creatorFunction = module[providerConfig.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ClientFactoryError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"`
)
}
const provider = creatorFunction(options)
if (provider && typeof provider.image === 'function') {
return provider.image(modelId)
} else {
throw new ClientFactoryError(`Image model function not found for provider "${providerId}"`)
}
} catch (error) {
if (error instanceof ClientFactoryError) {
throw error
}
throw new ClientFactoryError(
`Failed to create image client for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId,
error instanceof Error ? error : undefined
)
}
}
/** /**
* Providers * Providers
*/ */
@ -121,3 +180,6 @@ export const createClient = (providerId: string, modelId?: string, options?: any
export const getSupportedProviders = () => ApiClientFactory.getSupportedProviders() export const getSupportedProviders = () => ApiClientFactory.getSupportedProviders()
export const getClientInfo = (providerId: string) => ApiClientFactory.getClientInfo(providerId) export const getClientInfo = (providerId: string) => ApiClientFactory.getClientInfo(providerId)
export const createImageClient = (providerId: string, modelId?: string, options?: any) =>
ApiClientFactory.createImageClient(providerId, modelId, options)

View File

@ -3,7 +3,7 @@
* AI SDK客户端实现 * AI SDK客户端实现
*/ */
import { generateObject, generateText, streamObject, streamText } from 'ai' import { experimental_generateImage as generateImage, generateObject, generateText, streamObject, streamText } from 'ai'
import { ApiClientFactory } from './ApiClientFactory' import { ApiClientFactory } from './ApiClientFactory'
@ -65,6 +65,17 @@ export class UniversalAiSdkClient {
}) })
} }
async generateImage(
modelId: string,
params: Omit<Parameters<typeof generateImage>[0], 'model'>
): Promise<ReturnType<typeof generateImage>> {
const model = await ApiClientFactory.createImageClient(this.providerId, modelId, this.options)
return generateImage({
model,
...params
})
}
/** /**
* *
*/ */
@ -76,6 +87,6 @@ export class UniversalAiSdkClient {
/** /**
* *
*/ */
export function createUniversalClient(providerId: string, options: any = {}): UniversalAiSdkClient { export function createUniversalClient(providerId: string, options: ProviderOptions): UniversalAiSdkClient {
return new UniversalAiSdkClient(providerId, options) return new UniversalAiSdkClient(providerId, options)
} }

View File

@ -0,0 +1,6 @@
export type ProviderOptions = {
apiKey?: string
baseURL?: string
apiVersion?: string
headers?: Record<string, string | unknown>
}

View File

@ -5,6 +5,7 @@
// 导入内部使用的类和函数 // 导入内部使用的类和函数
import { ApiClientFactory } from './clients/ApiClientFactory' import { ApiClientFactory } from './clients/ApiClientFactory'
import { ProviderOptions } from './clients/types'
import { createUniversalClient } from './clients/UniversalAiSdkClient' import { createUniversalClient } from './clients/UniversalAiSdkClient'
import { aiProviderRegistry, isProviderSupported } from './providers/registry' import { aiProviderRegistry, isProviderSupported } from './providers/registry'
@ -61,19 +62,23 @@ export const AiCore = {
} }
// 便捷的预配置clients创建函数 // 便捷的预配置clients创建函数
export const createOpenAIClient = (options: { apiKey: string; baseURL?: string }) => { export const createOpenAIClient = (options: ProviderOptions) => {
return createUniversalClient('openai', options) return createUniversalClient('openai', options)
} }
export const createAnthropicClient = (options: { apiKey: string; baseURL?: string }) => { export const createOpenAICompatibleClient = (options: ProviderOptions) => {
return createUniversalClient('openai-compatible', options)
}
export const createAnthropicClient = (options: ProviderOptions) => {
return createUniversalClient('anthropic', options) return createUniversalClient('anthropic', options)
} }
export const createGoogleClient = (options: { apiKey: string; baseURL?: string }) => { export const createGoogleClient = (options: ProviderOptions) => {
return createUniversalClient('google', options) return createUniversalClient('google', options)
} }
export const createXAIClient = (options: { apiKey: string; baseURL?: string }) => { export const createXAIClient = (options: ProviderOptions) => {
return createUniversalClient('xai', options) return createUniversalClient('xai', options)
} }

View File

@ -1,4 +1,4 @@
import type { TextStreamPart, ToolSet } from 'ai' import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
/** /**
* AI * AI
@ -37,6 +37,9 @@ export interface AiPlugin {
tools?: TOOLS tools?: TOOLS
stopStream: () => void stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>> }) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
// AI SDK 原生中间件
aiSdkMiddlewares?: LanguageModelV1Middleware[]
} }
/** /**

View File

@ -1,3 +1,5 @@
import type { LanguageModelV1Middleware } from 'ai'
/** /**
* AI Provider * AI Provider
* AI SDK Providers * AI SDK Providers
@ -11,6 +13,10 @@ export interface ProviderConfig {
import: () => Promise<any> import: () => Promise<any>
// 创建函数名称 // 创建函数名称
creatorFunctionName: string creatorFunctionName: string
// 是否支持图片生成
supportsImageGeneration?: boolean
// AI SDK 原生中间件
aiSdkMiddlewares?: LanguageModelV1Middleware[]
} }
/** /**
@ -43,115 +49,134 @@ export class AiProviderRegistry {
id: 'openai', id: 'openai',
name: 'OpenAI', name: 'OpenAI',
import: () => import('@ai-sdk/openai'), import: () => import('@ai-sdk/openai'),
creatorFunctionName: 'createOpenAI' creatorFunctionName: 'createOpenAI',
supportsImageGeneration: true
}, },
{ {
id: 'anthropic', id: 'anthropic',
name: 'Anthropic', name: 'Anthropic',
import: () => import('@ai-sdk/anthropic'), import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic' creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false
}, },
{ {
id: 'google', id: 'google',
name: 'Google Generative AI', name: 'Google Generative AI',
import: () => import('@ai-sdk/google'), import: () => import('@ai-sdk/google'),
creatorFunctionName: 'createGoogleGenerativeAI' creatorFunctionName: 'createGoogleGenerativeAI',
supportsImageGeneration: true
}, },
{ {
id: 'google-vertex', id: 'google-vertex',
name: 'Google Vertex AI', name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex'), import: () => import('@ai-sdk/google-vertex'),
creatorFunctionName: 'createVertex' creatorFunctionName: 'createVertex',
supportsImageGeneration: true
}, },
{ {
id: 'mistral', id: 'mistral',
name: 'Mistral AI', name: 'Mistral AI',
import: () => import('@ai-sdk/mistral'), import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral' creatorFunctionName: 'createMistral',
supportsImageGeneration: false
}, },
{ {
id: 'xai', id: 'xai',
name: 'xAI (Grok)', name: 'xAI (Grok)',
import: () => import('@ai-sdk/xai'), import: () => import('@ai-sdk/xai'),
creatorFunctionName: 'createXai' creatorFunctionName: 'createXai',
supportsImageGeneration: true
}, },
{ {
id: 'azure', id: 'azure',
name: 'Azure OpenAI', name: 'Azure OpenAI',
import: () => import('@ai-sdk/azure'), import: () => import('@ai-sdk/azure'),
creatorFunctionName: 'createAzure' creatorFunctionName: 'createAzure',
supportsImageGeneration: true
}, },
{ {
id: 'bedrock', id: 'bedrock',
name: 'Amazon Bedrock', name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'), import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock' creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: false
}, },
{ {
id: 'cohere', id: 'cohere',
name: 'Cohere', name: 'Cohere',
import: () => import('@ai-sdk/cohere'), import: () => import('@ai-sdk/cohere'),
creatorFunctionName: 'createCohere' creatorFunctionName: 'createCohere',
supportsImageGeneration: false
}, },
{ {
id: 'groq', id: 'groq',
name: 'Groq', name: 'Groq',
import: () => import('@ai-sdk/groq'), import: () => import('@ai-sdk/groq'),
creatorFunctionName: 'createGroq' creatorFunctionName: 'createGroq',
supportsImageGeneration: false
}, },
{ {
id: 'together', id: 'together',
name: 'Together.ai', name: 'Together.ai',
import: () => import('@ai-sdk/togetherai'), import: () => import('@ai-sdk/togetherai'),
creatorFunctionName: 'createTogetherAI' creatorFunctionName: 'createTogetherAI',
supportsImageGeneration: true
}, },
{ {
id: 'fireworks', id: 'fireworks',
name: 'Fireworks', name: 'Fireworks',
import: () => import('@ai-sdk/fireworks'), import: () => import('@ai-sdk/fireworks'),
creatorFunctionName: 'createFireworks' creatorFunctionName: 'createFireworks',
supportsImageGeneration: true
}, },
{ {
id: 'deepseek', id: 'deepseek',
name: 'DeepSeek', name: 'DeepSeek',
import: () => import('@ai-sdk/deepseek'), import: () => import('@ai-sdk/deepseek'),
creatorFunctionName: 'createDeepSeek' creatorFunctionName: 'createDeepSeek',
supportsImageGeneration: false
}, },
{ {
id: 'cerebras', id: 'cerebras',
name: 'Cerebras', name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'), import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras' creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
}, },
{ {
id: 'deepinfra', id: 'deepinfra',
name: 'DeepInfra', name: 'DeepInfra',
import: () => import('@ai-sdk/deepinfra'), import: () => import('@ai-sdk/deepinfra'),
creatorFunctionName: 'createDeepInfra' creatorFunctionName: 'createDeepInfra',
supportsImageGeneration: false
}, },
{ {
id: 'replicate', id: 'replicate',
name: 'Replicate', name: 'Replicate',
import: () => import('@ai-sdk/replicate'), import: () => import('@ai-sdk/replicate'),
creatorFunctionName: 'createReplicate' creatorFunctionName: 'createReplicate',
supportsImageGeneration: true
}, },
{ {
id: 'perplexity', id: 'perplexity',
name: 'Perplexity', name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'), import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity' creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false
}, },
{ {
id: 'fal', id: 'fal',
name: 'Fal AI', name: 'Fal AI',
import: () => import('@ai-sdk/fal'), import: () => import('@ai-sdk/fal'),
creatorFunctionName: 'createFal' creatorFunctionName: 'createFal',
supportsImageGeneration: false
}, },
{ {
id: 'vercel', id: 'vercel',
name: 'Vercel', name: 'Vercel',
import: () => import('@ai-sdk/vercel'), import: () => import('@ai-sdk/vercel'),
creatorFunctionName: 'createVercel' creatorFunctionName: 'createVercel',
supportsImageGeneration: false
} }
] ]
@ -161,25 +186,36 @@ export class AiProviderRegistry {
id: 'ollama', id: 'ollama',
name: 'Ollama', name: 'Ollama',
import: () => import('ollama-ai-provider'), import: () => import('ollama-ai-provider'),
creatorFunctionName: 'createOllama' creatorFunctionName: 'createOllama',
supportsImageGeneration: false
}, },
{ {
id: 'qwen', id: 'qwen',
name: 'Qwen', name: 'Qwen',
import: () => import('qwen-ai-provider'), import: () => import('qwen-ai-provider'),
creatorFunctionName: 'createQwen' creatorFunctionName: 'createQwen',
supportsImageGeneration: false
}, },
{ {
id: 'zhipu', id: 'zhipu',
name: 'Zhipu AI', name: 'Zhipu AI',
import: () => import('zhipu-ai-provider'), import: () => import('zhipu-ai-provider'),
creatorFunctionName: 'createZhipu' creatorFunctionName: 'createZhipu',
supportsImageGeneration: false
}, },
{ {
id: 'anthropic-vertex', id: 'anthropic-vertex',
name: 'Anthropic Vertex AI', name: 'Anthropic Vertex AI',
import: () => import('anthropic-vertex-ai'), import: () => import('anthropic-vertex-ai'),
creatorFunctionName: 'createAnthropicVertex' creatorFunctionName: 'createAnthropicVertex',
supportsImageGeneration: false
},
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: false
} }
] ]