feat: enhance AI Core with image generation capabilities

- Introduced `createImageClient` method in `ApiClientFactory` to support image generation for various providers.
- Updated `UniversalAiSdkClient` to include `generateImage` method, allowing image generation through the unified client interface.
- Refactored client creation functions to utilize the new `ProviderOptions` type for improved type safety.
- Enhanced the provider registry to indicate which providers support image generation, streamlining client creation and usage.
- Updated type definitions in `types.ts` to reflect changes in client options and middleware support.
This commit is contained in:
suyao 2025-06-17 22:25:33 +08:00 committed by MyPrototypeWhat
parent ed2363e561
commit 2df1cddb43
6 changed files with 156 additions and 33 deletions

View File

@ -3,14 +3,15 @@
* API客户端工厂
*/
import type { LanguageModelV1 } from 'ai'
import type { ImageModelV1 } from '@ai-sdk/provider'
import { type LanguageModelV1, wrapLanguageModel } from 'ai'
import { aiProviderRegistry } from '../providers/registry'
// 客户端配置接口
export interface ClientConfig {
providerId: string
options?: any
options?: ProviderOptions
}
// 错误类型
@ -68,7 +69,17 @@ export class ApiClientFactory {
// 返回模型实例
if (typeof provider === 'function') {
return provider(modelId)
let model = provider(modelId)
// 应用 AI SDK 中间件
if (providerConfig.aiSdkMiddlewares) {
model = wrapLanguageModel({
model: model,
middleware: providerConfig.aiSdkMiddlewares
})
}
return model
} else {
throw new ClientFactoryError(`Unknown model access pattern for provider "${providerId}"`)
}
@ -84,6 +95,54 @@ export class ApiClientFactory {
}
}
static async createImageClient(
providerId: string,
modelId: string = 'default',
options: ProviderOptions
): Promise<ImageModelV1> {
try {
if (!aiProviderRegistry.isSupported(providerId)) {
throw new ClientFactoryError(`Provider "${providerId}" is not supported`, providerId)
}
const providerConfig = aiProviderRegistry.getProvider(providerId)
if (!providerConfig) {
throw new ClientFactoryError(`Provider "${providerId}" is not registered`, providerId)
}
if (!providerConfig.supportsImageGeneration) {
throw new ClientFactoryError(`Provider "${providerId}" does not support image generation`, providerId)
}
const module = await providerConfig.import()
const creatorFunction = module[providerConfig.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ClientFactoryError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"`
)
}
const provider = creatorFunction(options)
if (provider && typeof provider.image === 'function') {
return provider.image(modelId)
} else {
throw new ClientFactoryError(`Image model function not found for provider "${providerId}"`)
}
} catch (error) {
if (error instanceof ClientFactoryError) {
throw error
}
throw new ClientFactoryError(
`Failed to create image client for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId,
error instanceof Error ? error : undefined
)
}
}
/**
* Providers
*/
@ -121,3 +180,6 @@ export const createClient = (providerId: string, modelId?: string, options?: any
export const getSupportedProviders = () => ApiClientFactory.getSupportedProviders()
export const getClientInfo = (providerId: string) => ApiClientFactory.getClientInfo(providerId)
export const createImageClient = (providerId: string, modelId?: string, options?: any) =>
ApiClientFactory.createImageClient(providerId, modelId, options)

View File

@ -3,7 +3,7 @@
* AI SDK客户端实现
*/
import { generateObject, generateText, streamObject, streamText } from 'ai'
import { experimental_generateImage as generateImage, generateObject, generateText, streamObject, streamText } from 'ai'
import { ApiClientFactory } from './ApiClientFactory'
@ -65,6 +65,17 @@ export class UniversalAiSdkClient {
})
}
async generateImage(
modelId: string,
params: Omit<Parameters<typeof generateImage>[0], 'model'>
): Promise<ReturnType<typeof generateImage>> {
const model = await ApiClientFactory.createImageClient(this.providerId, modelId, this.options)
return generateImage({
model,
...params
})
}
/**
*
*/
@ -76,6 +87,6 @@ export class UniversalAiSdkClient {
/**
*
*/
export function createUniversalClient(providerId: string, options: any = {}): UniversalAiSdkClient {
export function createUniversalClient(providerId: string, options: ProviderOptions): UniversalAiSdkClient {
return new UniversalAiSdkClient(providerId, options)
}

View File

@ -0,0 +1,6 @@
export type ProviderOptions = {
apiKey?: string
baseURL?: string
apiVersion?: string
headers?: Record<string, string | unknown>
}

View File

@ -5,6 +5,7 @@
// 导入内部使用的类和函数
import { ApiClientFactory } from './clients/ApiClientFactory'
import { ProviderOptions } from './clients/types'
import { createUniversalClient } from './clients/UniversalAiSdkClient'
import { aiProviderRegistry, isProviderSupported } from './providers/registry'
@ -61,19 +62,23 @@ export const AiCore = {
}
// 便捷的预配置clients创建函数
export const createOpenAIClient = (options: { apiKey: string; baseURL?: string }) => {
export const createOpenAIClient = (options: ProviderOptions) => {
return createUniversalClient('openai', options)
}
export const createAnthropicClient = (options: { apiKey: string; baseURL?: string }) => {
export const createOpenAICompatibleClient = (options: ProviderOptions) => {
return createUniversalClient('openai-compatible', options)
}
export const createAnthropicClient = (options: ProviderOptions) => {
return createUniversalClient('anthropic', options)
}
export const createGoogleClient = (options: { apiKey: string; baseURL?: string }) => {
export const createGoogleClient = (options: ProviderOptions) => {
return createUniversalClient('google', options)
}
export const createXAIClient = (options: { apiKey: string; baseURL?: string }) => {
export const createXAIClient = (options: ProviderOptions) => {
return createUniversalClient('xai', options)
}

View File

@ -1,4 +1,4 @@
import type { TextStreamPart, ToolSet } from 'ai'
import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
/**
* AI
@ -37,6 +37,9 @@ export interface AiPlugin {
tools?: TOOLS
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
// AI SDK 原生中间件
aiSdkMiddlewares?: LanguageModelV1Middleware[]
}
/**

View File

@ -1,3 +1,5 @@
import type { LanguageModelV1Middleware } from 'ai'
/**
* AI Provider
* AI SDK Providers
@ -11,6 +13,10 @@ export interface ProviderConfig {
import: () => Promise<any>
// 创建函数名称
creatorFunctionName: string
// 是否支持图片生成
supportsImageGeneration?: boolean
// AI SDK 原生中间件
aiSdkMiddlewares?: LanguageModelV1Middleware[]
}
/**
@ -43,115 +49,134 @@ export class AiProviderRegistry {
id: 'openai',
name: 'OpenAI',
import: () => import('@ai-sdk/openai'),
creatorFunctionName: 'createOpenAI'
creatorFunctionName: 'createOpenAI',
supportsImageGeneration: true
},
{
id: 'anthropic',
name: 'Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic'
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false
},
{
id: 'google',
name: 'Google Generative AI',
import: () => import('@ai-sdk/google'),
creatorFunctionName: 'createGoogleGenerativeAI'
creatorFunctionName: 'createGoogleGenerativeAI',
supportsImageGeneration: true
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex'),
creatorFunctionName: 'createVertex'
creatorFunctionName: 'createVertex',
supportsImageGeneration: true
},
{
id: 'mistral',
name: 'Mistral AI',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral'
creatorFunctionName: 'createMistral',
supportsImageGeneration: false
},
{
id: 'xai',
name: 'xAI (Grok)',
import: () => import('@ai-sdk/xai'),
creatorFunctionName: 'createXai'
creatorFunctionName: 'createXai',
supportsImageGeneration: true
},
{
id: 'azure',
name: 'Azure OpenAI',
import: () => import('@ai-sdk/azure'),
creatorFunctionName: 'createAzure'
creatorFunctionName: 'createAzure',
supportsImageGeneration: true
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock'
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: false
},
{
id: 'cohere',
name: 'Cohere',
import: () => import('@ai-sdk/cohere'),
creatorFunctionName: 'createCohere'
creatorFunctionName: 'createCohere',
supportsImageGeneration: false
},
{
id: 'groq',
name: 'Groq',
import: () => import('@ai-sdk/groq'),
creatorFunctionName: 'createGroq'
creatorFunctionName: 'createGroq',
supportsImageGeneration: false
},
{
id: 'together',
name: 'Together.ai',
import: () => import('@ai-sdk/togetherai'),
creatorFunctionName: 'createTogetherAI'
creatorFunctionName: 'createTogetherAI',
supportsImageGeneration: true
},
{
id: 'fireworks',
name: 'Fireworks',
import: () => import('@ai-sdk/fireworks'),
creatorFunctionName: 'createFireworks'
creatorFunctionName: 'createFireworks',
supportsImageGeneration: true
},
{
id: 'deepseek',
name: 'DeepSeek',
import: () => import('@ai-sdk/deepseek'),
creatorFunctionName: 'createDeepSeek'
creatorFunctionName: 'createDeepSeek',
supportsImageGeneration: false
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras'
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
},
{
id: 'deepinfra',
name: 'DeepInfra',
import: () => import('@ai-sdk/deepinfra'),
creatorFunctionName: 'createDeepInfra'
creatorFunctionName: 'createDeepInfra',
supportsImageGeneration: false
},
{
id: 'replicate',
name: 'Replicate',
import: () => import('@ai-sdk/replicate'),
creatorFunctionName: 'createReplicate'
creatorFunctionName: 'createReplicate',
supportsImageGeneration: true
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity'
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false
},
{
id: 'fal',
name: 'Fal AI',
import: () => import('@ai-sdk/fal'),
creatorFunctionName: 'createFal'
creatorFunctionName: 'createFal',
supportsImageGeneration: false
},
{
id: 'vercel',
name: 'Vercel',
import: () => import('@ai-sdk/vercel'),
creatorFunctionName: 'createVercel'
creatorFunctionName: 'createVercel',
supportsImageGeneration: false
}
]
@ -161,25 +186,36 @@ export class AiProviderRegistry {
id: 'ollama',
name: 'Ollama',
import: () => import('ollama-ai-provider'),
creatorFunctionName: 'createOllama'
creatorFunctionName: 'createOllama',
supportsImageGeneration: false
},
{
id: 'qwen',
name: 'Qwen',
import: () => import('qwen-ai-provider'),
creatorFunctionName: 'createQwen'
creatorFunctionName: 'createQwen',
supportsImageGeneration: false
},
{
id: 'zhipu',
name: 'Zhipu AI',
import: () => import('zhipu-ai-provider'),
creatorFunctionName: 'createZhipu'
creatorFunctionName: 'createZhipu',
supportsImageGeneration: false
},
{
id: 'anthropic-vertex',
name: 'Anthropic Vertex AI',
import: () => import('anthropic-vertex-ai'),
creatorFunctionName: 'createAnthropicVertex'
creatorFunctionName: 'createAnthropicVertex',
supportsImageGeneration: false
},
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: false
}
]