From 84eef25ff9501b648a815fdb0b2a55a08ee5ddd1 Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Tue, 26 Aug 2025 16:17:01 +0800 Subject: [PATCH] feat(aiCore): enhance dynamic provider registration and refactor HubProvider - Introduced dynamic provider registration functionality, allowing for flexible management of providers through a new registry system. - Refactored HubProvider to streamline model resolution and improve error handling for unsupported models. - Added utility functions for managing dynamic providers, including registration, cleanup, and alias resolution. - Updated index exports to include new dynamic provider APIs, enhancing overall usability and integration. - Removed outdated provider files and simplified the provider management structure for better maintainability. --- .../aiCore/src/core/providers/HubProvider.ts | 92 +++--------- .../src/core/providers/RegistryManagement.ts | 78 +++++++++- packages/aiCore/src/core/providers/index.ts | 14 ++ .../aiCore/src/core/providers/registry.ts | 100 ++++++++++++- packages/aiCore/src/core/runtime/executor.ts | 6 +- packages/aiCore/src/index.ts | 14 ++ .../src/aiCore/chunk/AiSdkToChunkAdapter.ts | 6 + .../legacy/clients/aws/AwsBedrockAPIClient.ts | 2 +- .../provider/ProviderConfigProcessor.ts | 135 ++++++++++-------- src/renderer/src/aiCore/provider/aihubmix.ts | 56 -------- .../src/aiCore/provider/config/aihubmix.ts | 57 ++++++++ .../src/aiCore/provider/config/helper.ts | 22 +++ .../src/aiCore/provider/config/index.ts | 16 +++ .../src/aiCore/provider/config/newApi.ts | 52 +++++++ .../src/aiCore/provider/config/types.ts | 7 + src/renderer/src/aiCore/provider/factory.ts | 102 ++++++++----- .../src/aiCore/provider/providerConfigs.ts | 42 +++--- src/renderer/src/aiCore/utils/options.ts | 6 +- 18 files changed, 551 insertions(+), 256 deletions(-) delete mode 100644 src/renderer/src/aiCore/provider/aihubmix.ts create mode 100644 src/renderer/src/aiCore/provider/config/aihubmix.ts create mode 100644 src/renderer/src/aiCore/provider/config/helper.ts create mode 100644 src/renderer/src/aiCore/provider/config/index.ts create mode 100644 src/renderer/src/aiCore/provider/config/newApi.ts create mode 100644 src/renderer/src/aiCore/provider/config/types.ts diff --git a/packages/aiCore/src/core/providers/HubProvider.ts b/packages/aiCore/src/core/providers/HubProvider.ts index a5fd77888f..22ec1df082 100644 --- a/packages/aiCore/src/core/providers/HubProvider.ts +++ b/packages/aiCore/src/core/providers/HubProvider.ts @@ -5,7 +5,7 @@ * 例如: aihubmix:anthropic:claude-3.5-sonnet */ -import { ProviderV2 } from '@ai-sdk/provider' +import { EmbeddingModelV2, ImageModelV2, ProviderV2, SpeechModelV2, TranscriptionModelV2 } from '@ai-sdk/provider' import { customProvider } from 'ai' import { globalRegistryManagement } from './RegistryManagement' @@ -47,13 +47,7 @@ function parseHubModelId(modelId: string): { provider: string; actualModelId: st * 创建Hub Provider */ export function createHubProvider(config: HubProviderConfig): ProviderV2 { - const { hubId, debug = false } = config - - function logDebug(message: string, ...args: any[]) { - if (debug) { - console.log(`[HubProvider:${hubId}] ${message}`, ...args) - } - } + const { hubId } = config function getTargetProvider(providerId: string): ProviderV2 { // 从全局注册表获取provider实例 @@ -77,72 +71,26 @@ export function createHubProvider(config: HubProviderConfig): ProviderV2 { } } + function resolveModel(modelId: string, modelType: string, methodName: keyof ProviderV2): T { + const { provider, actualModelId } = parseHubModelId(modelId) + const targetProvider = getTargetProvider(provider) + + if (!targetProvider[methodName]) { + throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider) + } + + return (targetProvider[methodName] as any)(actualModelId) + } + return customProvider({ fallbackProvider: { - languageModel: (modelId: string) => { - logDebug('Resolving language model:', modelId) - - const { provider, actualModelId } = parseHubModelId(modelId) - const targetProvider = getTargetProvider(provider) - - if (!targetProvider.languageModel) { - throw new HubProviderError(`Provider "${provider}" does not support language models`, hubId, provider) - } - - return targetProvider.languageModel(actualModelId) - }, - - textEmbeddingModel: (modelId: string) => { - logDebug('Resolving text embedding model:', modelId) - - const { provider, actualModelId } = parseHubModelId(modelId) - const targetProvider = getTargetProvider(provider) - - if (!targetProvider.textEmbeddingModel) { - throw new HubProviderError(`Provider "${provider}" does not support text embedding models`, hubId, provider) - } - - return targetProvider.textEmbeddingModel(actualModelId) - }, - - imageModel: (modelId: string) => { - logDebug('Resolving image model:', modelId) - - const { provider, actualModelId } = parseHubModelId(modelId) - const targetProvider = getTargetProvider(provider) - - if (!targetProvider.imageModel) { - throw new HubProviderError(`Provider "${provider}" does not support image models`, hubId, provider) - } - - return targetProvider.imageModel(actualModelId) - }, - - transcriptionModel: (modelId: string) => { - logDebug('Resolving transcription model:', modelId) - - const { provider, actualModelId } = parseHubModelId(modelId) - const targetProvider = getTargetProvider(provider) - - if (!targetProvider.transcriptionModel) { - throw new HubProviderError(`Provider "${provider}" does not support transcription models`, hubId, provider) - } - - return targetProvider.transcriptionModel(actualModelId) - }, - - speechModel: (modelId: string) => { - logDebug('Resolving speech model:', modelId) - - const { provider, actualModelId } = parseHubModelId(modelId) - const targetProvider = getTargetProvider(provider) - - if (!targetProvider.speechModel) { - throw new HubProviderError(`Provider "${provider}" does not support speech models`, hubId, provider) - } - - return targetProvider.speechModel(actualModelId) - } + languageModel: (modelId: string) => resolveModel(modelId, 'language models', 'languageModel'), + textEmbeddingModel: (modelId: string) => + resolveModel>(modelId, 'text embedding models', 'textEmbeddingModel'), + imageModel: (modelId: string) => resolveModel(modelId, 'image models', 'imageModel'), + transcriptionModel: (modelId: string) => + resolveModel(modelId, 'transcription models', 'transcriptionModel'), + speechModel: (modelId: string) => resolveModel(modelId, 'speech models', 'speechModel') } }) } diff --git a/packages/aiCore/src/core/providers/RegistryManagement.ts b/packages/aiCore/src/core/providers/RegistryManagement.ts index bbde41e3e7..1a454ad625 100644 --- a/packages/aiCore/src/core/providers/RegistryManagement.ts +++ b/packages/aiCore/src/core/providers/RegistryManagement.ts @@ -15,6 +15,7 @@ export const DEFAULT_SEPARATOR = ':' export class RegistryManagement { private providers: PROVIDERS = {} + private aliases: Set = new Set() // 记录哪些key是别名 private separator: SEPARATOR private registry: ProviderRegistryProvider | null = null @@ -25,8 +26,18 @@ export class RegistryManagement { + this.providers[alias] = provider // 直接存储引用 + this.aliases.add(alias) // 标记为别名 + }) + } + this.rebuildRegistry() return this } @@ -48,9 +59,31 @@ export class RegistryManagement { + if (this.providers[alias] === provider) { + aliasesToRemove.push(alias) + } + }) + + aliasesToRemove.forEach((alias) => { + delete this.providers[alias] + this.aliases.delete(alias) + }) + } else { + // 如果移除的是别名,只删除别名记录 + this.aliases.delete(id) + } + delete this.providers[id] this.rebuildRegistry() return this @@ -121,10 +154,10 @@ export class RegistryManagement !this.aliases.has(id)) } /** @@ -139,9 +172,46 @@ export class RegistryManagement { + const result: Record = {} + this.aliases.forEach((alias) => { + result[alias] = this.resolveProviderId(alias) + }) + return result + } } /** diff --git a/packages/aiCore/src/core/providers/index.ts b/packages/aiCore/src/core/providers/index.ts index 779f73e46c..434145234a 100644 --- a/packages/aiCore/src/core/providers/index.ts +++ b/packages/aiCore/src/core/providers/index.ts @@ -28,6 +28,20 @@ export { reinitializeProvider } from './registry' +// 动态Provider注册功能 +export { + cleanup, + getAllAliases, + getAllDynamicMappings, + getDynamicProviders, + getProviderMapping, + isAlias, + isDynamicProvider, + registerDynamicProvider, + registerMultipleProviders, + resolveProviderId +} from './registry' + // ==================== 保留的导出(兼容性)==================== // 基础Provider数据源 diff --git a/packages/aiCore/src/core/providers/registry.ts b/packages/aiCore/src/core/providers/registry.ts index 1ddf360533..e8ac4b6998 100644 --- a/packages/aiCore/src/core/providers/registry.ts +++ b/packages/aiCore/src/core/providers/registry.ts @@ -8,7 +8,7 @@ import { customProvider } from 'ai' import { isOpenAIChatCompletionOnlyModel } from '../../utils/model' import { globalRegistryManagement } from './RegistryManagement' -import { baseProviders } from './schemas' +import { baseProviders, type DynamicProviderRegistration } from './schemas' /** * Provider 初始化错误类型 @@ -227,6 +227,104 @@ export function hasInitializedProviders(): boolean { return globalRegistryManagement.hasProviders() } +// ==================== 动态Provider注册功能 ==================== + +// 全局动态provider存储 +const dynamicProviders = new Map() + +/** + * 注册动态provider + */ +export function registerDynamicProvider(config: DynamicProviderRegistration): boolean { + try { + // 验证配置 + if (!config.id || !config.name) { + return false + } + + // 检查是否与基础provider冲突 + if (baseProviders.find((p) => p.id === config.id)) { + console.warn(`Dynamic provider "${config.id}" conflicts with base provider`) + return false + } + + // 存储动态provider配置 + dynamicProviders.set(config.id, config) + + // 如果有creator函数,立即初始化 + if (config.creator) { + try { + const provider = config.creator({}) as any // 使用空配置初始化,类型断言为any + const aliases = config.mappings ? Object.keys(config.mappings) : undefined + globalRegistryManagement.registerProvider(config.id, provider, aliases) + } catch (error) { + console.error(`Failed to initialize dynamic provider "${config.id}":`, error) + return false + } + } + + return true + } catch (error) { + console.error(`Failed to register dynamic provider:`, error) + return false + } +} + +/** + * 批量注册动态providers + */ +export function registerMultipleProviders(configs: DynamicProviderRegistration[]): number { + let successCount = 0 + configs.forEach((config) => { + if (registerDynamicProvider(config)) { + successCount++ + } + }) + return successCount +} + +/** + * 获取provider映射(解析别名) + */ +export function getProviderMapping(providerId: string): string { + return globalRegistryManagement.resolveProviderId(providerId) +} + +/** + * 检查是否为动态provider + */ +export function isDynamicProvider(providerId: string): boolean { + return dynamicProviders.has(providerId) +} + +/** + * 获取所有动态providers + */ +export function getDynamicProviders(): DynamicProviderRegistration[] { + return Array.from(dynamicProviders.values()) +} + +/** + * 获取所有别名映射关系 + */ +export function getAllDynamicMappings(): Record { + return globalRegistryManagement.getAllAliases() +} + +/** + * 清理所有动态providers + */ +export function cleanup(): void { + dynamicProviders.clear() + globalRegistryManagement.clear() +} + +// ==================== 导出别名相关API ==================== + +export const resolveProviderId = (id: string) => globalRegistryManagement.resolveProviderId(id) +export const isAlias = (id: string) => globalRegistryManagement.isAlias(id) +export const getAllAliases = () => globalRegistryManagement.getAllAliases() + // ==================== 导出错误类型和工具函数 ==================== export { isOpenAIChatCompletionOnlyModel, ProviderInitializationError } diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts index dfe0a800f2..e7bd6afa60 100644 --- a/packages/aiCore/src/core/runtime/executor.ts +++ b/packages/aiCore/src/core/runtime/executor.ts @@ -38,7 +38,7 @@ export class RuntimeExecutor { this.pluginEngine = new PluginEngine(config.providerId, config.plugins || []) } - createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) { + private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) { return definePlugin({ name: '_internal_resolveModel', enforce: 'post', @@ -50,7 +50,7 @@ export class RuntimeExecutor { }) } - createResolveImageModelPlugin() { + private createResolveImageModelPlugin() { return definePlugin({ name: '_internal_resolveImageModel', enforce: 'post', @@ -61,7 +61,7 @@ export class RuntimeExecutor { }) } - createConfigureContextPlugin() { + private createConfigureContextPlugin() { return definePlugin({ name: '_internal_configureContext', configureContext: async (context: AiRequestContext) => { diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index 297af35fe2..3617a4bb9f 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -133,6 +133,20 @@ export { reinitializeProvider } from './core/providers/registry' +// ==================== 动态Provider注册和别名映射 ==================== +export { + cleanup, + getAllAliases, + getAllDynamicMappings, + getDynamicProviders, + getProviderMapping, + isAlias, + isDynamicProvider, + registerDynamicProvider, + registerMultipleProviders, + resolveProviderId +} from './core/providers/registry' + // ==================== Zod Schema 和验证 ==================== export { baseProviderIds, validateProviderId } from './core/providers' diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts index 57be46f6d7..ef86ce2035 100644 --- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -284,6 +284,12 @@ export class AiSdkToChunkAdapter { } }) break + case 'abort': + this.onChunk({ + type: ChunkType.ERROR, + error: new DOMException('Request was aborted', 'AbortError') + }) + break case 'error': this.onChunk({ type: ChunkType.ERROR, diff --git a/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts index 9eec9e9605..958fd76686 100644 --- a/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts @@ -6,7 +6,7 @@ import { InvokeModelWithResponseStreamCommand } from '@aws-sdk/client-bedrock-runtime' import { loggerService } from '@logger' -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { findTokenLimit, isReasoningModel } from '@renderer/config/models' import { diff --git a/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts b/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts index b1bb0abc12..343fe92bf8 100644 --- a/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts +++ b/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts @@ -1,109 +1,124 @@ import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core' -import { isDedicatedImageGenerationModel } from '@renderer/config/models' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { getProviderByModel } from '@renderer/services/AssistantService' import type { Model, Provider } from '@renderer/types' import { formatApiHost } from '@renderer/utils/api' import { cloneDeep } from 'lodash' -import { createAihubmixProvider } from './aihubmix' +import { aihubmixProviderCreator, newApiResolverCreator } from './config' import { getAiSdkProviderId } from './factory' -export function getActualProvider(model: Model): Provider { - const provider = getProviderByModel(model) - // 如果是 vertexai 类型且没有 googleCredentials,转换为 VertexProvider - let actualProvider = cloneDeep(provider) +/** + * 处理特殊provider的转换逻辑 + */ +function handleSpecialProviders(model: Model, provider: Provider): Provider { if (provider.type === 'vertexai' && !isVertexProvider(provider)) { if (!isVertexAIConfigured()) { throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') } - actualProvider = createVertexProvider(provider) + return createVertexProvider(provider) } if (provider.id === 'aihubmix') { - actualProvider = createAihubmixProvider(model, actualProvider) + return aihubmixProviderCreator(model, provider) } - if (actualProvider.type === 'gemini') { - actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta') + if (provider.id === 'newapi') { + return newApiResolverCreator(model, provider) + } + return provider +} + +/** + * 格式化provider的API Host + */ +function formatProviderApiHost(provider: Provider): Provider { + const formatted = { ...provider } + if (formatted.type === 'gemini') { + formatted.apiHost = formatApiHost(formatted.apiHost, 'v1beta') } else { - actualProvider.apiHost = formatApiHost(actualProvider.apiHost) + formatted.apiHost = formatApiHost(formatted.apiHost) } + return formatted +} + +/** + * 获取实际的Provider配置 + * 简化版:将逻辑分解为小函数 + */ +export function getActualProvider(model: Model): Provider { + const baseProvider = getProviderByModel(model) + + // 按顺序处理各种转换 + let actualProvider = cloneDeep(baseProvider) + actualProvider = handleSpecialProviders(model, actualProvider) + actualProvider = formatProviderApiHost(actualProvider) + return actualProvider } /** * 将 Provider 配置转换为新 AI SDK 格式 + * 简化版:利用新的别名映射系统 */ export function providerToAiSdkConfig(actualProvider: Provider): { providerId: ProviderId | 'openai-compatible' options: ProviderSettingsMap[keyof ProviderSettingsMap] } { const aiSdkProviderId = getAiSdkProviderId(actualProvider) - const actualProviderType = actualProvider.type - const openaiResponseOptions = - actualProviderType === 'openai-response' - ? { - mode: 'responses' - } - : aiSdkProviderId === 'openai' - ? { - mode: 'chat' - } - : undefined - console.log('openaiResponseOptions', openaiResponseOptions) - console.log('actualProvider', actualProvider) - console.log('aiSdkProviderId', aiSdkProviderId) - if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { - const options = ProviderConfigFactory.fromProvider( - aiSdkProviderId, - { - baseURL: actualProvider.apiHost, - apiKey: actualProvider.apiKey - }, - { ...openaiResponseOptions, headers: actualProvider.extra_headers } - ) + // 构建基础配置 + const baseConfig = { + baseURL: actualProvider.apiHost, + apiKey: actualProvider.apiKey + } + + // 处理OpenAI模式(简化逻辑) + const extraOptions: any = {} + if (actualProvider.type === 'openai-response') { + extraOptions.mode = 'responses' + } else if (aiSdkProviderId === 'openai') { + extraOptions.mode = 'chat' + } + + // 添加额外headers + if (actualProvider.extra_headers) { + extraOptions.headers = actualProvider.extra_headers + } + + // 如果AI SDK支持该provider,使用原生配置 + if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { + const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions) return { providerId: aiSdkProviderId as ProviderId, options } - } else { - console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`) - const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey) + } - return { - providerId: 'openai-compatible', - options: { - ...options, - name: actualProvider.id - } + // 否则fallback到openai-compatible + const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey) + return { + providerId: 'openai-compatible', + options: { + ...options, + name: actualProvider.id, + ...extraOptions } } } /** * 检查是否支持使用新的AI SDK + * 简化版:利用新的别名映射和动态provider系统 */ -export function isModernSdkSupported(provider: Provider, model?: Model): boolean { - // 目前支持主要的providers - const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai'] - - // 检查provider类型 - if (!supportedProviders.includes(provider.type)) { - return false - } - - // 对于 vertexai,检查配置是否完整 +export function isModernSdkSupported(provider: Provider): boolean { + // 特殊检查:vertexai需要配置完整 if (provider.type === 'vertexai' && !isVertexAIConfigured()) { return false } - // 图像生成模型现在支持新的 AI SDK - // (但需要确保 provider 是支持的 + // 使用getAiSdkProviderId获取映射后的providerId,然后检查AI SDK是否支持 + const aiSdkProviderId = getAiSdkProviderId(provider) - if (model && isDedicatedImageGenerationModel(model)) { - return true - } - - return true + // 如果映射到了支持的provider,则支持现代SDK + return AiCore.isSupported(aiSdkProviderId) } diff --git a/src/renderer/src/aiCore/provider/aihubmix.ts b/src/renderer/src/aiCore/provider/aihubmix.ts deleted file mode 100644 index 14b9a468e0..0000000000 --- a/src/renderer/src/aiCore/provider/aihubmix.ts +++ /dev/null @@ -1,56 +0,0 @@ -import { ProviderId } from '@cherrystudio/ai-core/types' -import { isOpenAIModel } from '@renderer/config/models' -import { Model, Provider } from '@renderer/types' - -export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' { - console.log('getAiSdkProviderIdForAihubmix', model) - const id = model.id.toLowerCase() - - if (id.startsWith('claude')) { - return 'anthropic' - } - // TODO:暂时注释,不清楚为什么排除,webSearch时会导致gemini模型走openai的逻辑 - if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) { - return 'google' - } - - if (isOpenAIModel(model)) { - return 'openai' - } - - return 'openai-compatible' -} - -export function createAihubmixProvider(model: Model, provider: Provider): Provider { - const providerId = getAiSdkProviderIdForAihubmix(model) - provider = { - ...provider, - extra_headers: { - ...provider.extra_headers, - 'APP-Code': 'MLTG2087' - } - } - if (providerId === 'google') { - return { - ...provider, - type: 'gemini', - apiHost: 'https://aihubmix.com/gemini' - } - } - - if (providerId === 'openai') { - return { - ...provider, - type: 'openai-response' - } - } - - if (providerId === 'anthropic') { - return { - ...provider, - type: 'anthropic' - } - } - - return provider -} diff --git a/src/renderer/src/aiCore/provider/config/aihubmix.ts b/src/renderer/src/aiCore/provider/config/aihubmix.ts new file mode 100644 index 0000000000..a881603047 --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/aihubmix.ts @@ -0,0 +1,57 @@ +/** + * AiHubMix规则集 + */ +import { isOpenAIModel } from '@renderer/config/models' +import { Provider } from '@renderer/types' + +import { startsWith } from './helper' +import { provider2Provider } from './helper' +import type { ModelRule } from './types' + +const extraProviderConfig = (provider: Provider) => { + return { + ...provider, + extra_headers: { + ...provider.extra_headers, + 'APP-Code': 'MLTG2087' + } + } +} + +const AIHUBMIX_RULES: ModelRule[] = [ + { + name: 'claude', + match: startsWith('claude'), + provider: (provider: Provider) => { + return extraProviderConfig({ + ...provider, + type: 'anthropic' + }) + } + }, + { + name: 'gemini', + match: (model) => + (startsWith('gemini')(model) || startsWith('imagen')(model)) && + !model.id.endsWith('-nothink') && + !model.id.endsWith('-search'), + provider: (provider: Provider) => { + return extraProviderConfig({ + ...provider, + apiHost: 'https://aihubmix.com/gemini' + }) + } + }, + { + name: 'openai', + match: isOpenAIModel, + provider: (provider: Provider) => { + return extraProviderConfig({ + ...provider, + type: 'openai-response' + }) + } + } +] + +export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES) diff --git a/src/renderer/src/aiCore/provider/config/helper.ts b/src/renderer/src/aiCore/provider/config/helper.ts new file mode 100644 index 0000000000..31ae9b4eb0 --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/helper.ts @@ -0,0 +1,22 @@ +import type { Model, Provider } from '@renderer/types' + +import type { ModelRule } from './types' + +export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase()) +export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type + +/** + * 解析模型对应的Provider ID + * @param model 模型对象 + * @param rules 匹配规则数组 + * @param fallback 默认fallback的providerId + * @returns 解析出的providerId + */ +export function provider2Provider(rules: ModelRule[], model: Model, provider: Provider): Provider { + for (const rule of rules) { + if (rule.match(model)) { + return rule.provider(provider) + } + } + return provider +} diff --git a/src/renderer/src/aiCore/provider/config/index.ts b/src/renderer/src/aiCore/provider/config/index.ts new file mode 100644 index 0000000000..7c19231d4e --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/index.ts @@ -0,0 +1,16 @@ +// /** +// * Provider解析规则模块导出 +// */ + +// // 导出类型 +// export type { ModelRule } from './types' + +// // 导出匹配函数和解析器 +// export { endpointIs, resolveProvider, startsWith } from './helper' + +// // 导出规则集 +// export { AIHUBMIX_RULES } from './aihubmix' +// export { NEWAPI_RULES } from './newApi' + +export { aihubmixProviderCreator } from './aihubmix' +export { newApiResolverCreator } from './newApi' diff --git a/src/renderer/src/aiCore/provider/config/newApi.ts b/src/renderer/src/aiCore/provider/config/newApi.ts new file mode 100644 index 0000000000..e7de0bb328 --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/newApi.ts @@ -0,0 +1,52 @@ +/** + * NewAPI规则集 + */ +import { Provider } from '@renderer/types' + +import { endpointIs, provider2Provider } from './helper' +import type { ModelRule } from './types' + +const NEWAPI_RULES: ModelRule[] = [ + { + name: 'anthropic', + match: endpointIs('anthropic'), + provider: (provider: Provider) => { + return { + ...provider, + type: 'anthropic' + } + } + }, + { + name: 'gemini', + match: endpointIs('gemini'), + provider: (provider: Provider) => { + return { + ...provider, + type: 'gemini' + } + } + }, + { + name: 'openai-response', + match: endpointIs('openai-response'), + provider: (provider: Provider) => { + return { + ...provider, + type: 'openai-response' + } + } + }, + { + name: 'openai', + match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model), + provider: (provider: Provider) => { + return { + ...provider, + type: 'openai' + } + } + } +] + +export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES) diff --git a/src/renderer/src/aiCore/provider/config/types.ts b/src/renderer/src/aiCore/provider/config/types.ts new file mode 100644 index 0000000000..5f3cc5a56b --- /dev/null +++ b/src/renderer/src/aiCore/provider/config/types.ts @@ -0,0 +1,7 @@ +import type { Model, Provider } from '@renderer/types' + +export interface ModelRule { + name: string + match: (model: Model) => boolean + provider: (provider: Provider) => Provider +} diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts index 88255b112e..0c15fc0911 100644 --- a/src/renderer/src/aiCore/provider/factory.ts +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -1,49 +1,75 @@ -import { AiCore, type ProviderId } from '@cherrystudio/ai-core' +import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core' +import { loggerService } from '@logger' import { Provider } from '@renderer/types' -// TODO -// 初始化新的Provider注册系统 -// initializeNewProviders() +import { initializeNewProviders } from './providerConfigs' -// 静态Provider映射 - 核心providers +const logger = loggerService.withContext('ProviderFactory') + +/** + * 初始化动态Provider系统 + * 在模块加载时自动注册新的providers + */ +;(async () => { + try { + await initializeNewProviders() + } catch (error) { + logger.warn('Failed to initialize new providers:', error as Error) + } +})() + +/** + * 静态Provider映射表 + * 处理Cherry Studio特有的provider ID到AI SDK标准ID的映射 + */ const STATIC_PROVIDER_MAPPING: Record = { - // anthropic: 'anthropic', - gemini: 'google', - 'azure-openai': 'azure', - 'openai-response': 'openai', - grok: 'xai' + gemini: 'google', // Google Gemini -> google + 'azure-openai': 'azure', // Azure OpenAI -> azure + 'openai-response': 'openai', // OpenAI Responses -> openai + grok: 'xai' // Grok -> xai } +/** + * 尝试解析provider标识符(支持静态映射和动态映射) + */ +function tryResolveProviderId(identifier: string): ProviderId | null { + // 1. 检查静态映射 + const staticMapping = STATIC_PROVIDER_MAPPING[identifier] + if (staticMapping) { + return staticMapping + } + + // 2. 检查动态映射 + const dynamicMapping = getProviderMapping(identifier) + if (dynamicMapping && dynamicMapping !== identifier) { + return dynamicMapping as ProviderId + } + + // 3. 检查AiCore是否直接支持 + if (AiCore.isSupported(identifier)) { + return identifier as ProviderId + } + + return null +} + +/** + * 获取AI SDK Provider ID + * 简化版:减少重复逻辑,利用通用解析函数 + */ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' { - // 1. 首先检查静态映射 - const staticProviderId = STATIC_PROVIDER_MAPPING[provider.id] - if (staticProviderId) { - return staticProviderId - } - // TODO - // 2. 检查动态注册的provider映射(使用aiCore的函数) - // const dynamicProviderId = getProviderMapping(provider.id) - // if (dynamicProviderId) { - // return dynamicProviderId as ProviderId - // } - - // 3. 检查provider.type的静态映射 - const staticProviderType = STATIC_PROVIDER_MAPPING[provider.type] - if (staticProviderType) { - return staticProviderType - } - // TODO - // 4. 检查provider.type的动态映射 - // const dynamicProviderType = getProviderMapping(provider.type) - // if (dynamicProviderType) { - // return dynamicProviderType as ProviderId - // } - - // 5. 检查AiCore是否直接支持 - if (AiCore.isSupported(provider.id)) { - return provider.id as ProviderId + // 1. 尝试解析provider.id + const resolvedFromId = tryResolveProviderId(provider.id) + if (resolvedFromId) { + return resolvedFromId } - // 6. 最后的fallback + // 2. 尝试解析provider.type + const resolvedFromType = tryResolveProviderId(provider.type) + if (resolvedFromType) { + return resolvedFromType + } + + // 3. 最后的fallback(通常会成为openai-compatible) return provider.id as ProviderId } diff --git a/src/renderer/src/aiCore/provider/providerConfigs.ts b/src/renderer/src/aiCore/provider/providerConfigs.ts index efb37d5685..719d023e71 100644 --- a/src/renderer/src/aiCore/provider/providerConfigs.ts +++ b/src/renderer/src/aiCore/provider/providerConfigs.ts @@ -1,4 +1,4 @@ -import { type ProviderConfig } from '@cherrystudio/ai-core' +import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core' import { loggerService } from '@logger' const logger = loggerService.withContext('ProviderConfigs') @@ -43,19 +43,29 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & { } ] as const -// TODO -// /** -// * 初始化新的Providers -// * 使用aiCore的动态注册功能 -// */ -// export async function initializeNewProviders(): Promise { -// try { -// const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS) +/** + * 初始化新的Providers + * 使用aiCore的动态注册功能 + */ +export async function initializeNewProviders(): Promise { + try { + logger.info('Starting to register new providers', { + providerCount: NEW_PROVIDER_CONFIGS.length, + providerIds: NEW_PROVIDER_CONFIGS.map((p) => p.id) + }) -// if (successCount < NEW_PROVIDER_CONFIGS.length) { -// logger.warn('Some providers failed to register. Check previous error logs.') -// } -// } catch (error) { -// logger.error('Failed to initialize new providers:', error as Error) -// } -// } + const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS) + + logger.info('Provider registration completed', { + successCount, + totalCount: NEW_PROVIDER_CONFIGS.length, + failedCount: NEW_PROVIDER_CONFIGS.length - successCount + }) + + if (successCount < NEW_PROVIDER_CONFIGS.length) { + logger.warn('Some providers failed to register. Check previous error logs.') + } + } catch (error) { + logger.error('Failed to initialize new providers:', error as Error) + } +} diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index 72c46a6705..d8b1e853cc 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -28,7 +28,6 @@ export function buildProviderOptions( } ): Record { const providerId = getAiSdkProviderId(actualProvider) - // 构建 provider 特定的选项 let providerSpecificOptions: Record = {} @@ -37,11 +36,8 @@ export function buildProviderOptions( case 'openai': case 'azure': providerSpecificOptions = { - ...buildOpenAIProviderOptions(assistant, model, capabilities), - // 函数内有对于真实provider.id的判断,应该不会影响原生provider - ...buildGenericProviderOptions(assistant, model, capabilities) + ...buildOpenAIProviderOptions(assistant, model, capabilities) } - break case 'anthropic':