diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts index 5419b70639..eb68da74ea 100644 --- a/src/renderer/src/aiCore/index.ts +++ b/src/renderer/src/aiCore/index.ts @@ -1,187 +1,16 @@ -import { loggerService } from '@logger' -import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' -import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' -import { isDedicatedImageGenerationModel } from '@renderer/config/models' -import { getProviderByModel } from '@renderer/services/AssistantService' -import { withSpanResult } from '@renderer/services/SpanManagerService' -import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' -import type { GenerateImageParams, Model, Provider } from '@renderer/types' -import type { RequestOptions, SdkModel } from '@renderer/types/sdk' -import { isPromptToolUse } from '@renderer/utils/mcp-tools' +/** + * Cherry Studio AI Core - 统一入口点 + * + * 这是新的统一入口,保持向后兼容性 + * 默认导出legacy AiProvider以保持现有代码的兼容性 + */ -import { AihubmixAPIClient } from './clients/AihubmixAPIClient' -import { VertexAPIClient } from './clients/gemini/VertexAPIClient' -import { NewAPIClient } from './clients/NewAPIClient' -import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient' -import { CompletionsMiddlewareBuilder } from './middleware/builder' -import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware' -import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware' -import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware' -import { applyCompletionsMiddlewares } from './middleware/composer' -import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware' -import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware' -import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware' -import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware' -import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware' -import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware' -import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware' -import { MiddlewareRegistry } from './middleware/register' -import type { CompletionsParams, CompletionsResult } from './middleware/schemas' +// 导出Legacy AiProvider作为默认导出(保持向后兼容) +export { default } from './legacy/index' -const logger = loggerService.withContext('AiProvider') +// 同时导出Modern AiProvider供新代码使用 +export { default as ModernAiProvider } from './index_new' -export default class AiProvider { - private apiClient: BaseApiClient - - constructor(provider: Provider) { - // Use the new ApiClientFactory to get a BaseApiClient instance - this.apiClient = ApiClientFactory.create(provider) - } - - public async completions(params: CompletionsParams, options?: RequestOptions): Promise { - // 1. 
根据模型识别正确的客户端 - const model = params.assistant.model - if (!model) { - return Promise.reject(new Error('Model is required')) - } - - // 根据client类型选择合适的处理方式 - let client: BaseApiClient - - if (this.apiClient instanceof AihubmixAPIClient) { - // AihubmixAPIClient: 根据模型选择合适的子client - client = this.apiClient.getClientForModel(model) - if (client instanceof OpenAIResponseAPIClient) { - client = client.getClient(model) as BaseApiClient - } - } else if (this.apiClient instanceof NewAPIClient) { - client = this.apiClient.getClientForModel(model) - if (client instanceof OpenAIResponseAPIClient) { - client = client.getClient(model) as BaseApiClient - } - } else if (this.apiClient instanceof OpenAIResponseAPIClient) { - // OpenAIResponseAPIClient: 根据模型特征选择API类型 - client = this.apiClient.getClient(model) as BaseApiClient - } else if (this.apiClient instanceof VertexAPIClient) { - client = this.apiClient.getClient(model) as BaseApiClient - } else { - // 其他client直接使用 - client = this.apiClient - } - - // 2. 构建中间件链 - const builder = CompletionsMiddlewareBuilder.withDefaults() - // images api - if (isDedicatedImageGenerationModel(model)) { - builder.clear() - builder - .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName]) - .add(MiddlewareRegistry[ErrorHandlerMiddlewareName]) - .add(MiddlewareRegistry[AbortHandlerMiddlewareName]) - .add(MiddlewareRegistry[ImageGenerationMiddlewareName]) - } else { - // Existing logic for other models - logger.silly('Builder Params', params) - // 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题 - const clientTypes = client.getClientCompatibilityType(model) - const isOpenAICompatible = - clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient') - if (!isOpenAICompatible) { - logger.silly('ThinkingTagExtractionMiddleware is removed') - builder.remove(ThinkingTagExtractionMiddlewareName) - } - - const isAnthropicOrOpenAIResponseCompatible = - clientTypes.includes('AnthropicAPIClient') || clientTypes.includes('OpenAIResponseAPIClient') - if (!isAnthropicOrOpenAIResponseCompatible) { - logger.silly('RawStreamListenerMiddleware is removed') - builder.remove(RawStreamListenerMiddlewareName) - } - if (!params.enableWebSearch) { - logger.silly('WebSearchMiddleware is removed') - builder.remove(WebSearchMiddlewareName) - } - if (!params.mcpTools?.length) { - builder.remove(ToolUseExtractionMiddlewareName) - logger.silly('ToolUseExtractionMiddleware is removed') - builder.remove(McpToolChunkMiddlewareName) - logger.silly('McpToolChunkMiddleware is removed') - } - if (!isPromptToolUse(params.assistant)) { - builder.remove(ToolUseExtractionMiddlewareName) - logger.silly('ToolUseExtractionMiddleware is removed') - } - if (params.callType !== 'chat') { - logger.silly('AbortHandlerMiddleware is removed') - builder.remove(AbortHandlerMiddlewareName) - } - if (params.callType === 'test') { - builder.remove(ErrorHandlerMiddlewareName) - logger.silly('ErrorHandlerMiddleware is removed') - builder.remove(FinalChunkConsumerMiddlewareName) - logger.silly('FinalChunkConsumerMiddleware is removed') - builder.insertBefore(ThinkChunkMiddlewareName, MiddlewareRegistry[ThinkingTagExtractionMiddlewareName]) - logger.silly('ThinkingTagExtractionMiddleware is inserted') - } - } - - const middlewares = builder.build() - logger.silly('middlewares', middlewares) - - // 3. Create the wrapped SDK method with middlewares - const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares) - - // 4. 
Execute the wrapped method with the original params - const result = wrappedCompletionMethod(params, options) - return result - } - - public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise { - const traceName = params.assistant.model?.name - ? `${params.assistant.model?.name}.${params.callType}` - : `LLM.${params.callType}` - - const traceParams: StartSpanParams = { - name: traceName, - tag: 'LLM', - topicId: params.topicId || '', - modelName: params.assistant.model?.name - } - - return await withSpanResult(this.completions.bind(this), traceParams, params, options) - } - - public async models(): Promise { - return this.apiClient.listModels() - } - - public async getEmbeddingDimensions(model: Model): Promise { - try { - // Use the SDK instance to test embedding capabilities - if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') { - this.apiClient = this.apiClient.getClient(model) as BaseApiClient - } - const dimensions = await this.apiClient.getEmbeddingDimensions(model) - return dimensions - } catch (error) { - logger.error('Error getting embedding dimensions:', error as Error) - throw error - } - } - - public async generateImage(params: GenerateImageParams): Promise { - if (this.apiClient instanceof AihubmixAPIClient) { - const client = this.apiClient.getClientForModel({ id: params.model } as Model) - return client.generateImage(params) - } - return this.apiClient.generateImage(params) - } - - public getBaseURL(): string { - return this.apiClient.getBaseURL() - } - - public getApiKey(): string { - return this.apiClient.getApiKey() - } -} +// 导出一些常用的类型和工具 +export * from './legacy/clients/types' +export * from './legacy/middleware/schemas' diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index bf369dc840..44b8c8e312 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -8,136 +8,17 @@ * 3. 
暂时保持接口兼容性 */ -import { - AiCore, - AiPlugin, - createExecutor, - generateImage, - ProviderConfigFactory, - type ProviderId, - type ProviderSettingsMap, - StreamTextParams -} from '@cherrystudio/ai-core' -import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins' -import { isDedicatedImageGenerationModel, isNotSupportedImageSizeModel } from '@renderer/config/models' -import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' -import { getProviderByModel } from '@renderer/services/AssistantService' +import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core' +import { isNotSupportedImageSizeModel } from '@renderer/config/models' import type { GenerateImageParams, Model, Provider } from '@renderer/types' import { ChunkType } from '@renderer/types/chunk' -import { formatApiHost } from '@renderer/utils/api' -import { cloneDeep } from 'lodash' import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter' -import LegacyAiProvider from './index' -import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder' -import { CompletionsResult } from './middleware/schemas' -import reasoningTimePlugin from './plugins/reasoningTimePlugin' -import { searchOrchestrationPlugin } from './plugins/searchOrchestrationPlugin' -import { createAihubmixProvider } from './provider/aihubmix' -import { getAiSdkProviderId } from './provider/factory' - -function getActualProvider(model: Model): Provider { - const provider = getProviderByModel(model) - // 如果是 vertexai 类型且没有 googleCredentials,转换为 VertexProvider - let actualProvider = cloneDeep(provider) - if (provider.type === 'vertexai' && !isVertexProvider(provider)) { - if (!isVertexAIConfigured()) { - throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') - } - actualProvider = createVertexProvider(provider) - } - - if (provider.id === 'aihubmix') { - actualProvider = createAihubmixProvider(model, actualProvider) - } - if (actualProvider.type === 'gemini') { - actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta') - } else { - actualProvider.apiHost = formatApiHost(actualProvider.apiHost) - } - return actualProvider -} - -/** - * 将 Provider 配置转换为新 AI SDK 格式 - */ -function providerToAiSdkConfig(actualProvider: Provider): { - providerId: ProviderId | 'openai-compatible' - options: ProviderSettingsMap[keyof ProviderSettingsMap] -} { - // console.log('actualProvider', actualProvider) - const aiSdkProviderId = getAiSdkProviderId(actualProvider) - // console.log('aiSdkProviderId', aiSdkProviderId) - // 如果provider是openai,则使用strict模式并且默认responses api - const actualProviderType = actualProvider.type - const openaiResponseOptions = - // 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses - actualProviderType === 'openai-response' - ? { - mode: 'responses' - } - : aiSdkProviderId === 'openai' - ? 
{ - mode: 'chat' - } - : undefined - console.log('openaiResponseOptions', openaiResponseOptions) - console.log('actualProvider', actualProvider) - console.log('aiSdkProviderId', aiSdkProviderId) - if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { - const options = ProviderConfigFactory.fromProvider( - aiSdkProviderId, - { - baseURL: actualProvider.apiHost, - apiKey: actualProvider.apiKey - }, - { ...openaiResponseOptions, headers: actualProvider.extra_headers } - ) - - return { - providerId: aiSdkProviderId as ProviderId, - options - } - } else { - console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`) - const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey) - - return { - providerId: 'openai-compatible', - options: { - ...options, - name: actualProvider.id - } - } - } -} - -/** - * 检查是否支持使用新的AI SDK - */ -function isModernSdkSupported(provider: Provider, model?: Model): boolean { - // 目前支持主要的providers - const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai'] - - // 检查provider类型 - if (!supportedProviders.includes(provider.type)) { - return false - } - - // 对于 vertexai,检查配置是否完整 - if (provider.type === 'vertexai' && !isVertexAIConfigured()) { - return false - } - - // 图像生成模型现在支持新的 AI SDK - // (但需要确保 provider 是支持的 - - if (model && isDedicatedImageGenerationModel(model)) { - return true - } - - return true -} +import LegacyAiProvider from './legacy/index' +import { CompletionsResult } from './legacy/middleware/schemas' +import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder' +import { buildPlugins } from './plugins/PluginBuilder' +import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor' export default class ModernAiProvider { private legacyProvider: LegacyAiProvider @@ -156,62 +37,6 @@ export default class ModernAiProvider { return this.actualProvider } - /** - * 根据条件构建插件数组 - */ - private buildPlugins(middlewareConfig: AiSdkMiddlewareConfig) { - const plugins: AiPlugin[] = [] - // 1. 总是添加通用插件 - // plugins.push(textPlugin) - if (middlewareConfig.enableWebSearch) { - // 内置了默认搜索参数,如果改的话可以传config进去 - plugins.push(webSearchPlugin()) - } - // 2. 支持工具调用时添加搜索插件 - if (middlewareConfig.isSupportedToolUse) { - plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant)) - } - - // 3. 推理模型时添加推理插件 - if (middlewareConfig.enableReasoning) { - plugins.push(reasoningTimePlugin) - } - - // 4. 
启用Prompt工具调用时添加工具插件 - if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) { - plugins.push( - createPromptToolUsePlugin({ - enabled: true, - createSystemMessage: (systemPrompt, params, context) => { - if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) { - if (context.isRecursiveCall) { - return null - } - params.messages = [ - { - role: 'assistant', - content: systemPrompt - }, - ...params.messages - ] - return null - } - return systemPrompt - } - }) - ) - } - - // if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) { - // plugins.push(createNativeToolUsePlugin()) - // } - console.log( - '最终插件列表:', - plugins.map((p) => p.name) - ) - return plugins - } - public async completions( modelId: string, params: StreamTextParams, @@ -236,7 +61,7 @@ export default class ModernAiProvider { ): Promise { // try { // 根据条件构建插件数组 - const plugins = this.buildPlugins(middlewareConfig) + const plugins = buildPlugins(middlewareConfig) console.log('this.config.providerId', this.config.providerId) console.log('this.config.options', this.config.options) console.log('plugins', plugins) diff --git a/src/renderer/src/aiCore/clients/AihubmixAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/AihubmixAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/AihubmixAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/AihubmixAPIClient.ts diff --git a/src/renderer/src/aiCore/clients/ApiClientFactory.ts b/src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts similarity index 100% rename from src/renderer/src/aiCore/clients/ApiClientFactory.ts rename to src/renderer/src/aiCore/legacy/clients/ApiClientFactory.ts diff --git a/src/renderer/src/aiCore/clients/BaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/BaseApiClient.ts rename to src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts diff --git a/src/renderer/src/aiCore/clients/MixedBaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/MixedBaseApiClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/MixedBaseApiClient.ts rename to src/renderer/src/aiCore/legacy/clients/MixedBaseApiClient.ts diff --git a/src/renderer/src/aiCore/clients/NewAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/NewAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/NewAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/NewAPIClient.ts diff --git a/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts b/src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts similarity index 100% rename from src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts rename to src/renderer/src/aiCore/legacy/clients/__tests__/ApiClientFactory.test.ts diff --git a/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts similarity index 99% rename from src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts index 6a978c7bb0..13e55deb1e 100644 --- a/src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts @@ -26,7 +26,6 @@ import { import { MessageStream } from 
'@anthropic-ai/sdk/resources/messages/messages' import AnthropicVertex from '@anthropic-ai/vertex-sdk' import { loggerService } from '@logger' -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models' import { getAssistantSettings } from '@renderer/services/AssistantService' @@ -71,6 +70,7 @@ import { } from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' +import { GenericChunk } from '../../middleware/schemas' import { BaseApiClient } from '../BaseApiClient' import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types' diff --git a/src/renderer/src/aiCore/clients/anthropic/AnthropicVertexClient.ts b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicVertexClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/anthropic/AnthropicVertexClient.ts rename to src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicVertexClient.ts diff --git a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/aws/AwsBedrockAPIClient.ts diff --git a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts similarity index 99% rename from src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts index fb7f00904e..54b41d5e8f 100644 --- a/src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/gemini/GeminiAPIClient.ts @@ -18,7 +18,6 @@ import { } from '@google/genai' import { loggerService } from '@logger' import { nanoid } from '@reduxjs/toolkit' -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { findTokenLimit, GEMINI_FLASH_MODEL_REGEX, @@ -61,6 +60,7 @@ import { import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { defaultTimeout, MB } from '@shared/config/constant' +import { GenericChunk } from '../../middleware/schemas' import { BaseApiClient } from '../BaseApiClient' import { RequestTransformer, ResponseChunkTransformer } from '../types' diff --git a/src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts diff --git a/src/renderer/src/aiCore/clients/index.ts b/src/renderer/src/aiCore/legacy/clients/index.ts similarity index 100% rename from src/renderer/src/aiCore/clients/index.ts rename to src/renderer/src/aiCore/legacy/clients/index.ts diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts rename to src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts 
b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts rename to src/renderer/src/aiCore/legacy/clients/openai/OpenAIBaseClient.ts diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts similarity index 99% rename from src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts index d2bb4f7b8b..f14294dddb 100644 --- a/src/renderer/src/aiCore/clients/openai/OpenAIResponseAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts @@ -1,5 +1,3 @@ -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' -import { CompletionsContext } from '@renderer/aiCore/middleware/types' import { isOpenAIChatCompletionOnlyModel, isOpenAILLMModel, @@ -42,6 +40,8 @@ import { isEmpty } from 'lodash' import OpenAI, { AzureOpenAI } from 'openai' import { ResponseInput } from 'openai/resources/responses/responses' +import { GenericChunk } from '../../middleware/schemas' +import { CompletionsContext } from '../../middleware/types' import { RequestTransformer, ResponseChunkTransformer } from '../types' import { OpenAIAPIClient } from './OpenAIApiClient' import { OpenAIBaseClient } from './OpenAIBaseClient' diff --git a/src/renderer/src/aiCore/clients/ppio/PPIOAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/ppio/PPIOAPIClient.ts similarity index 100% rename from src/renderer/src/aiCore/clients/ppio/PPIOAPIClient.ts rename to src/renderer/src/aiCore/legacy/clients/ppio/PPIOAPIClient.ts diff --git a/src/renderer/src/aiCore/clients/types.ts b/src/renderer/src/aiCore/legacy/clients/types.ts similarity index 100% rename from src/renderer/src/aiCore/clients/types.ts rename to src/renderer/src/aiCore/legacy/clients/types.ts diff --git a/src/renderer/src/aiCore/legacy/index.ts b/src/renderer/src/aiCore/legacy/index.ts new file mode 100644 index 0000000000..a4ae0e4550 --- /dev/null +++ b/src/renderer/src/aiCore/legacy/index.ts @@ -0,0 +1,187 @@ +import { loggerService } from '@logger' +import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory' +import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient' +import { isDedicatedImageGenerationModel } from '@renderer/config/models' +import { getProviderByModel } from '@renderer/services/AssistantService' +import { withSpanResult } from '@renderer/services/SpanManagerService' +import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' +import type { GenerateImageParams, Model, Provider } from '@renderer/types' +import type { RequestOptions, SdkModel } from '@renderer/types/sdk' +import { isPromptToolUse } from '@renderer/utils/mcp-tools' + +import { AihubmixAPIClient } from './clients/AihubmixAPIClient' +import { VertexAPIClient } from './clients/gemini/VertexAPIClient' +import { NewAPIClient } from './clients/NewAPIClient' +import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient' +import { CompletionsMiddlewareBuilder } from './middleware/builder' +import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware' +import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware' +import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from 
'./middleware/common/FinalChunkConsumerMiddleware' +import { applyCompletionsMiddlewares } from './middleware/composer' +import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware' +import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware' +import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware' +import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware' +import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware' +import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware' +import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware' +import { MiddlewareRegistry } from './middleware/register' +import type { CompletionsParams, CompletionsResult } from './middleware/schemas' + +const logger = loggerService.withContext('AiProvider') + +export default class AiProvider { + private apiClient: BaseApiClient + + constructor(provider: Provider) { + // Use the new ApiClientFactory to get a BaseApiClient instance + this.apiClient = ApiClientFactory.create(provider) + } + + public async completions(params: CompletionsParams, options?: RequestOptions): Promise { + // 1. 根据模型识别正确的客户端 + const model = params.assistant.model + if (!model) { + return Promise.reject(new Error('Model is required')) + } + + // 根据client类型选择合适的处理方式 + let client: BaseApiClient + + if (this.apiClient instanceof AihubmixAPIClient) { + // AihubmixAPIClient: 根据模型选择合适的子client + client = this.apiClient.getClientForModel(model) + if (client instanceof OpenAIResponseAPIClient) { + client = client.getClient(model) as BaseApiClient + } + } else if (this.apiClient instanceof NewAPIClient) { + client = this.apiClient.getClientForModel(model) + if (client instanceof OpenAIResponseAPIClient) { + client = client.getClient(model) as BaseApiClient + } + } else if (this.apiClient instanceof OpenAIResponseAPIClient) { + // OpenAIResponseAPIClient: 根据模型特征选择API类型 + client = this.apiClient.getClient(model) as BaseApiClient + } else if (this.apiClient instanceof VertexAPIClient) { + client = this.apiClient.getClient(model) as BaseApiClient + } else { + // 其他client直接使用 + client = this.apiClient + } + + // 2. 
构建中间件链 + const builder = CompletionsMiddlewareBuilder.withDefaults() + // images api + if (isDedicatedImageGenerationModel(model)) { + builder.clear() + builder + .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName]) + .add(MiddlewareRegistry[ErrorHandlerMiddlewareName]) + .add(MiddlewareRegistry[AbortHandlerMiddlewareName]) + .add(MiddlewareRegistry[ImageGenerationMiddlewareName]) + } else { + // Existing logic for other models + logger.silly('Builder Params', params) + // 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题 + const clientTypes = client.getClientCompatibilityType(model) + const isOpenAICompatible = + clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient') + if (!isOpenAICompatible) { + logger.silly('ThinkingTagExtractionMiddleware is removed') + builder.remove(ThinkingTagExtractionMiddlewareName) + } + + const isAnthropicOrOpenAIResponseCompatible = + clientTypes.includes('AnthropicAPIClient') || clientTypes.includes('OpenAIResponseAPIClient') + if (!isAnthropicOrOpenAIResponseCompatible) { + logger.silly('RawStreamListenerMiddleware is removed') + builder.remove(RawStreamListenerMiddlewareName) + } + if (!params.enableWebSearch) { + logger.silly('WebSearchMiddleware is removed') + builder.remove(WebSearchMiddlewareName) + } + if (!params.mcpTools?.length) { + builder.remove(ToolUseExtractionMiddlewareName) + logger.silly('ToolUseExtractionMiddleware is removed') + builder.remove(McpToolChunkMiddlewareName) + logger.silly('McpToolChunkMiddleware is removed') + } + if (!isPromptToolUse(params.assistant)) { + builder.remove(ToolUseExtractionMiddlewareName) + logger.silly('ToolUseExtractionMiddleware is removed') + } + if (params.callType !== 'chat') { + logger.silly('AbortHandlerMiddleware is removed') + builder.remove(AbortHandlerMiddlewareName) + } + if (params.callType === 'test') { + builder.remove(ErrorHandlerMiddlewareName) + logger.silly('ErrorHandlerMiddleware is removed') + builder.remove(FinalChunkConsumerMiddlewareName) + logger.silly('FinalChunkConsumerMiddleware is removed') + builder.insertBefore(ThinkChunkMiddlewareName, MiddlewareRegistry[ThinkingTagExtractionMiddlewareName]) + logger.silly('ThinkingTagExtractionMiddleware is inserted') + } + } + + const middlewares = builder.build() + logger.silly('middlewares', middlewares) + + // 3. Create the wrapped SDK method with middlewares + const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares) + + // 4. Execute the wrapped method with the original params + const result = wrappedCompletionMethod(params, options) + return result + } + + public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise { + const traceName = params.assistant.model?.name + ? 
`${params.assistant.model?.name}.${params.callType}` + : `LLM.${params.callType}` + + const traceParams: StartSpanParams = { + name: traceName, + tag: 'LLM', + topicId: params.topicId || '', + modelName: params.assistant.model?.name + } + + return await withSpanResult(this.completions.bind(this), traceParams, params, options) + } + + public async models(): Promise { + return this.apiClient.listModels() + } + + public async getEmbeddingDimensions(model: Model): Promise { + try { + // Use the SDK instance to test embedding capabilities + if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') { + this.apiClient = this.apiClient.getClient(model) as BaseApiClient + } + const dimensions = await this.apiClient.getEmbeddingDimensions(model) + return dimensions + } catch (error) { + logger.error('Error getting embedding dimensions:', error as Error) + throw error + } + } + + public async generateImage(params: GenerateImageParams): Promise { + if (this.apiClient instanceof AihubmixAPIClient) { + const client = this.apiClient.getClientForModel({ id: params.model } as Model) + return client.generateImage(params) + } + return this.apiClient.generateImage(params) + } + + public getBaseURL(): string { + return this.apiClient.getBaseURL() + } + + public getApiKey(): string { + return this.apiClient.getApiKey() + } +} diff --git a/src/renderer/src/aiCore/middleware/BUILDER_USAGE.md b/src/renderer/src/aiCore/legacy/middleware/BUILDER_USAGE.md similarity index 100% rename from src/renderer/src/aiCore/middleware/BUILDER_USAGE.md rename to src/renderer/src/aiCore/legacy/middleware/BUILDER_USAGE.md diff --git a/src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md b/src/renderer/src/aiCore/legacy/middleware/MIDDLEWARE_SPECIFICATION.md similarity index 100% rename from src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md rename to src/renderer/src/aiCore/legacy/middleware/MIDDLEWARE_SPECIFICATION.md diff --git a/src/renderer/src/aiCore/middleware/builder.ts b/src/renderer/src/aiCore/legacy/middleware/builder.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/builder.ts rename to src/renderer/src/aiCore/legacy/middleware/builder.ts diff --git a/src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/AbortHandlerMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/common/AbortHandlerMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/common/AbortHandlerMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/common/ErrorHandlerMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/common/ErrorHandlerMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/FinalChunkConsumerMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/common/FinalChunkConsumerMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/common/FinalChunkConsumerMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/common/LoggingMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/common/LoggingMiddleware.ts rename to 
src/renderer/src/aiCore/legacy/middleware/common/LoggingMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/composer.ts b/src/renderer/src/aiCore/legacy/middleware/composer.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/composer.ts rename to src/renderer/src/aiCore/legacy/middleware/composer.ts diff --git a/src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/McpToolChunkMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/McpToolChunkMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/McpToolChunkMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/RawStreamListenerMiddleware.ts similarity index 94% rename from src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/RawStreamListenerMiddleware.ts index 25d0e358c6..e97a2ce96c 100644 --- a/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts +++ b/src/renderer/src/aiCore/legacy/middleware/core/RawStreamListenerMiddleware.ts @@ -1,4 +1,4 @@ -import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient' +import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient' import { isAnthropicModel } from '@renderer/config/models' import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk' diff --git a/src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/ResponseTransformMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/ResponseTransformMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/ResponseTransformMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/StreamAdapterMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/StreamAdapterMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/StreamAdapterMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/TextChunkMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/TextChunkMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/ThinkChunkMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/ThinkChunkMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/TransformCoreToSdkParamsMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/TransformCoreToSdkParamsMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/core/TransformCoreToSdkParamsMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/core/WebSearchMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/core/WebSearchMiddleware.ts 
rename to src/renderer/src/aiCore/legacy/middleware/core/WebSearchMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts similarity index 98% rename from src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts index ceb8d791d7..0f89e8aca8 100644 --- a/src/renderer/src/aiCore/middleware/feat/ImageGenerationMiddleware.ts +++ b/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts @@ -1,4 +1,3 @@ -import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' import { isDedicatedImageGenerationModel } from '@renderer/config/models' import FileManager from '@renderer/services/FileManager' import { ChunkType } from '@renderer/types/chunk' @@ -7,6 +6,7 @@ import { defaultTimeout } from '@shared/config/constant' import OpenAI from 'openai' import { toFile } from 'openai/uploads' +import { BaseApiClient } from '../../clients/BaseApiClient' import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' import { CompletionsContext, CompletionsMiddleware } from '../types' diff --git a/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ThinkingTagExtractionMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/feat/ThinkingTagExtractionMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ToolUseExtractionMiddleware.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/feat/ToolUseExtractionMiddleware.ts rename to src/renderer/src/aiCore/legacy/middleware/feat/ToolUseExtractionMiddleware.ts diff --git a/src/renderer/src/aiCore/middleware/index.ts b/src/renderer/src/aiCore/legacy/middleware/index.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/index.ts rename to src/renderer/src/aiCore/legacy/middleware/index.ts diff --git a/src/renderer/src/aiCore/middleware/register.ts b/src/renderer/src/aiCore/legacy/middleware/register.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/register.ts rename to src/renderer/src/aiCore/legacy/middleware/register.ts diff --git a/src/renderer/src/aiCore/middleware/schemas.ts b/src/renderer/src/aiCore/legacy/middleware/schemas.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/schemas.ts rename to src/renderer/src/aiCore/legacy/middleware/schemas.ts diff --git a/src/renderer/src/aiCore/middleware/types.ts b/src/renderer/src/aiCore/legacy/middleware/types.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/types.ts rename to src/renderer/src/aiCore/legacy/middleware/types.ts diff --git a/src/renderer/src/aiCore/middleware/utils.ts b/src/renderer/src/aiCore/legacy/middleware/utils.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/utils.ts rename to src/renderer/src/aiCore/legacy/middleware/utils.ts diff --git a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts rename to src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts 
diff --git a/src/renderer/src/aiCore/middleware/aisdk/README.md b/src/renderer/src/aiCore/middleware/README.md similarity index 100% rename from src/renderer/src/aiCore/middleware/aisdk/README.md rename to src/renderer/src/aiCore/middleware/README.md diff --git a/src/renderer/src/aiCore/plugins/PluginBuilder.ts b/src/renderer/src/aiCore/plugins/PluginBuilder.ts new file mode 100644 index 0000000000..f54dec923d --- /dev/null +++ b/src/renderer/src/aiCore/plugins/PluginBuilder.ts @@ -0,0 +1,62 @@ +import { AiPlugin } from '@cherrystudio/ai-core' +import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins' + +import { AiSdkMiddlewareConfig } from '../middleware/AiSdkMiddlewareBuilder' +import reasoningTimePlugin from './reasoningTimePlugin' +import { searchOrchestrationPlugin } from './searchOrchestrationPlugin' + +/** + * 根据条件构建插件数组 + */ +export function buildPlugins(middlewareConfig: AiSdkMiddlewareConfig): AiPlugin[] { + const plugins: AiPlugin[] = [] + // 1. 总是添加通用插件 + // plugins.push(textPlugin) + if (middlewareConfig.enableWebSearch) { + // 内置了默认搜索参数,如果改的话可以传config进去 + plugins.push(webSearchPlugin()) + } + // 2. 支持工具调用时添加搜索插件 + if (middlewareConfig.isSupportedToolUse) { + plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant)) + } + + // 3. 推理模型时添加推理插件 + if (middlewareConfig.enableReasoning) { + plugins.push(reasoningTimePlugin) + } + + // 4. 启用Prompt工具调用时添加工具插件 + if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) { + plugins.push( + createPromptToolUsePlugin({ + enabled: true, + createSystemMessage: (systemPrompt, params, context) => { + if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) { + if (context.isRecursiveCall) { + return null + } + params.messages = [ + { + role: 'assistant', + content: systemPrompt + }, + ...params.messages + ] + return null + } + return systemPrompt + } + }) + ) + } + + // if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) { + // plugins.push(createNativeToolUsePlugin()) + // } + console.log( + '最终插件列表:', + plugins.map((p) => p.name) + ) + return plugins +} diff --git a/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts b/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts new file mode 100644 index 0000000000..35776188f9 --- /dev/null +++ b/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts @@ -0,0 +1,113 @@ +import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core' +import { isDedicatedImageGenerationModel } from '@renderer/config/models' +import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' +import { getProviderByModel } from '@renderer/services/AssistantService' +import type { Model, Provider } from '@renderer/types' +import { formatApiHost } from '@renderer/utils/api' +import { cloneDeep } from 'lodash' + +import { createAihubmixProvider } from './aihubmix' +import { getAiSdkProviderId } from './factory' + +export function getActualProvider(model: Model): Provider { + const provider = getProviderByModel(model) + // 如果是 vertexai 类型且没有 googleCredentials,转换为 VertexProvider + let actualProvider = cloneDeep(provider) + if (provider.type === 'vertexai' && !isVertexProvider(provider)) { + if (!isVertexAIConfigured()) { + throw new Error('VertexAI is not configured. 
Please configure project, location and service account credentials.') + } + actualProvider = createVertexProvider(provider) + } + + if (provider.id === 'aihubmix') { + actualProvider = createAihubmixProvider(model, actualProvider) + } + if (actualProvider.type === 'gemini') { + actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta') + } else { + actualProvider.apiHost = formatApiHost(actualProvider.apiHost) + } + return actualProvider +} + +/** + * 将 Provider 配置转换为新 AI SDK 格式 + */ +export function providerToAiSdkConfig(actualProvider: Provider): { + providerId: ProviderId | 'openai-compatible' + options: ProviderSettingsMap[keyof ProviderSettingsMap] +} { + // console.log('actualProvider', actualProvider) + const aiSdkProviderId = getAiSdkProviderId(actualProvider) + // console.log('aiSdkProviderId', aiSdkProviderId) + // 如果provider是openai,则使用strict模式并且默认responses api + const actualProviderType = actualProvider.type + const openaiResponseOptions = + // 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses + actualProviderType === 'openai-response' + ? { + mode: 'responses' + } + : aiSdkProviderId === 'openai' + ? { + mode: 'chat' + } + : undefined + console.log('openaiResponseOptions', openaiResponseOptions) + console.log('actualProvider', actualProvider) + console.log('aiSdkProviderId', aiSdkProviderId) + if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { + const options = ProviderConfigFactory.fromProvider( + aiSdkProviderId, + { + baseURL: actualProvider.apiHost, + apiKey: actualProvider.apiKey + }, + { ...openaiResponseOptions, headers: actualProvider.extra_headers } + ) + + return { + providerId: aiSdkProviderId as ProviderId, + options + } + } else { + console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`) + const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey) + + return { + providerId: 'openai-compatible', + options: { + ...options, + name: actualProvider.id + } + } + } +} + +/** + * 检查是否支持使用新的AI SDK + */ +export function isModernSdkSupported(provider: Provider, model?: Model): boolean { + // 目前支持主要的providers + const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai'] + + // 检查provider类型 + if (!supportedProviders.includes(provider.type)) { + return false + } + + // 对于 vertexai,检查配置是否完整 + if (provider.type === 'vertexai' && !isVertexAIConfigured()) { + return false + } + + // 图像生成模型现在支持新的 AI SDK + // (但需要确保 provider 是支持的 + + if (model && isDedicatedImageGenerationModel(model)) { + return true + } + + return true +} diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index 739f5624e4..c1651cdcf0 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -13,8 +13,6 @@ import { TextPart, UserModelMessage } from '@cherrystudio/ai-core' -import AiProvider from '@renderer/aiCore' -import { CompletionsParams } from '@renderer/aiCore/middleware/schemas' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { isGenerateImageModel, @@ -49,6 +47,8 @@ import { import { defaultTimeout } from '@shared/config/constant' import { isEmpty } from 'lodash' +import AiProvider from './legacy/index' +import { CompletionsParams } from './legacy/middleware/schemas' // import { webSearchTool } from './tools/WebSearchTool' // import { jsonSchemaToZod } from 'json-schema-to-zod' import { setupToolsConfig } from 
'./utils/mcp' diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 46ddceb4c8..3d644663c4 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -4,8 +4,8 @@ import { StreamTextParams } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import AiProvider from '@renderer/aiCore' -import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder' -import { CompletionsParams } from '@renderer/aiCore/middleware/schemas' +import { CompletionsParams } from '@renderer/aiCore/legacy/middleware/schemas' +import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder' import { buildStreamTextParams } from '@renderer/aiCore/transformParameters' import { isDedicatedImageGenerationModel, diff --git a/src/renderer/src/services/__tests__/ApiService.test.ts b/src/renderer/src/services/__tests__/ApiService.test.ts index 249dcd01f6..27d57f3299 100644 --- a/src/renderer/src/services/__tests__/ApiService.test.ts +++ b/src/renderer/src/services/__tests__/ApiService.test.ts @@ -9,13 +9,13 @@ import { import { FinishReason, MediaModality } from '@google/genai' import { FunctionCall } from '@google/genai' import AiProvider from '@renderer/aiCore' -import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/clients' -import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient' -import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' -import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' -import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient' -import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient' -import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/legacy/clients' +import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient' +import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory' +import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient' +import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient' +import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient' +import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas' import { isVisionModel } from '@renderer/config/models' import { Assistant, MCPCallToolResponse, MCPToolResponse, Model, Provider, WebSearchSource } from '@renderer/types' import {
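
A minimal usage sketch of how call sites are expected to resolve after this refactor, assuming the `@renderer/aiCore` alias still points at `src/renderer/src/aiCore/index.ts`. The re-exports are taken from the new `index.ts` in this diff; the `ModernAiProvider` constructor signature is not shown here and is only assumed to mirror the legacy `(provider: Provider)` one.

import AiProvider, { ModernAiProvider } from '@renderer/aiCore'
import type { Provider } from '@renderer/types'

// Existing call sites keep working: the default export is re-exported from
// src/renderer/src/aiCore/legacy/index.ts, so this still constructs the
// middleware-based legacy provider.
const createLegacy = (provider: Provider) => new AiProvider(provider)

// New code can opt into the AI-SDK-based provider through the named export
// from index_new.ts. Assumption: its constructor also accepts a Provider,
// which this diff does not show.
const createModern = (provider: Provider) => new ModernAiProvider(provider)

// Shared types such as CompletionsParams/CompletionsResult stay importable
// from the entry point via `export * from './legacy/middleware/schemas'`,
// so consumers like ApiService.ts only need the path updates shown above.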