diff --git a/package.json b/package.json index 17ad9a71db..fca46dc9f9 100644 --- a/package.json +++ b/package.json @@ -272,7 +272,6 @@ "winston": "^3.17.0", "winston-daily-rotate-file": "^5.0.0", "word-extractor": "^1.0.4", - "zhipu-ai-provider": "0.2.0-beta.1", "zipread": "^1.3.3", "zod": "^3.25.74" }, diff --git a/packages/aiCore/src/core/providers/factory.ts b/packages/aiCore/src/core/providers/factory.ts index de8bf71310..831526c192 100644 --- a/packages/aiCore/src/core/providers/factory.ts +++ b/packages/aiCore/src/core/providers/factory.ts @@ -37,19 +37,6 @@ const configHandlers: { resourceName: azureProvider.resourceName }) } - // 'google-vertex': (builder, provider) => { - // const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'> - // const vertexProvider = provider as CompleteProviderConfig<'google-vertex'> - // vertexBuilder - // .withGoogleVertexConfig({ - // project: vertexProvider.project, - // location: vertexProvider.location - // }) - // .withGoogleCredentials({ - // clientEmail: vertexProvider.googleCredentials?.clientEmail || '', - // privateKey: vertexProvider.googleCredentials?.privateKey || '' - // }) - // } } export class ProviderConfigBuilder { @@ -117,42 +104,6 @@ export class ProviderConfigBuilder { return this } - /** - * Google 特定配置 - */ - // withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never - // withGoogleVertexConfig(options: any): any { - // if (this.providerId === 'google-vertex') { - // const googleConfig = this.config as CompleteProviderConfig<'google-vertex'> - // if (options.project) { - // googleConfig.project = options.project - // } - // if (options.location) { - // googleConfig.location = options.location - // if (options.location === 'global') { - // googleConfig.baseURL = 'https://aiplatform.googleapis.com' - // } - // } - // } - // return this - // } - - withGoogleCredentials(credentials: { - clientEmail: string - privateKey: string - }): T extends 'google-vertex' ? this : never - withGoogleCredentials(): any { - // withGoogleCredentials(credentials: any): any { - // if (this.providerId === 'google-vertex') { - // const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'> - // vertexConfig.googleCredentials = { - // clientEmail: credentials.clientEmail, - // privateKey: formatPrivateKey(credentials.privateKey) - // } - // } - return this - } - /** * 设置自定义参数 */ diff --git a/packages/aiCore/src/core/providers/registry.ts b/packages/aiCore/src/core/providers/registry.ts index 556863324d..fddc3c8aa1 100644 --- a/packages/aiCore/src/core/providers/registry.ts +++ b/packages/aiCore/src/core/providers/registry.ts @@ -16,6 +16,9 @@ import { type ProviderConfig } from './types' export class AiProviderRegistry { private static instance: AiProviderRegistry private registry = new Map() + // 动态注册扩展 + private dynamicMappings = new Map() + private dynamicProviders = new Set() private constructor() { this.initializeProviders() @@ -30,11 +33,10 @@ export class AiProviderRegistry { /** * 初始化所有支持的 Providers - * 基于 AI SDK 官方文档: https://ai-sdk.dev/providers/ai-sdk-providers + * 基于 AI SDK 官方文档: https://v5.ai-sdk.dev/providers/ai-sdk-providers */ private initializeProviders(): void { const providers: ProviderConfig[] = [ - // 官方 AI SDK Providers (19个) { id: 'openai', name: 'OpenAI', @@ -67,20 +69,6 @@ export class AiProviderRegistry { creator: createGoogleGenerativeAI, supportsImageGeneration: true }, - // { - // id: 'google-vertex', - // name: 'Google Vertex AI', - // import: () => import('@ai-sdk/google-vertex/edge'), - // creatorFunctionName: 'createVertex', - // supportsImageGeneration: true - // }, - // { - // id: 'mistral', - // name: 'Mistral AI', - // import: () => import('@ai-sdk/mistral'), - // creatorFunctionName: 'createMistral', - // supportsImageGeneration: false - // }, { id: 'xai', name: 'xAI (Grok)', @@ -93,111 +81,12 @@ export class AiProviderRegistry { creator: createAzure, supportsImageGeneration: true }, - // { - // id: 'bedrock', - // name: 'Amazon Bedrock', - // import: () => import('@ai-sdk/amazon-bedrock'), - // creatorFunctionName: 'createAmazonBedrock', - // supportsImageGeneration: false - // }, - // { - // id: 'cohere', - // name: 'Cohere', - // import: () => import('@ai-sdk/cohere'), - // creatorFunctionName: 'createCohere', - // supportsImageGeneration: false - // }, - // { - // id: 'groq', - // name: 'Groq', - // import: () => import('@ai-sdk/groq'), - // creatorFunctionName: 'createGroq', - // supportsImageGeneration: false - // }, - // { - // id: 'together', - // name: 'Together.ai', - // import: () => import('@ai-sdk/togetherai'), - // creatorFunctionName: 'createTogetherAI', - // supportsImageGeneration: true - // }, - // { - // id: 'fireworks', - // name: 'Fireworks', - // import: () => import('@ai-sdk/fireworks'), - // creatorFunctionName: 'createFireworks', - // supportsImageGeneration: true - // }, { id: 'deepseek', name: 'DeepSeek', creator: createDeepSeek, supportsImageGeneration: false } - // { - // id: 'cerebras', - // name: 'Cerebras', - // import: () => import('@ai-sdk/cerebras'), - // creatorFunctionName: 'createCerebras', - // supportsImageGeneration: false - // }, - // { - // id: 'deepinfra', - // name: 'DeepInfra', - // import: () => import('@ai-sdk/deepinfra'), - // creatorFunctionName: 'createDeepInfra', - // supportsImageGeneration: false - // }, - // { - // id: 'replicate', - // name: 'Replicate', - // import: () => import('@ai-sdk/replicate'), - // creatorFunctionName: 'createReplicate', - // supportsImageGeneration: true - // }, - // { - // id: 'perplexity', - // name: 'Perplexity', - // import: () => import('@ai-sdk/perplexity'), - // creatorFunctionName: 'createPerplexity', - // supportsImageGeneration: false - // }, - // { - // id: 'fal', - // name: 'Fal AI', - // import: () => import('@ai-sdk/fal'), - // creatorFunctionName: 'createFal', - // supportsImageGeneration: false - // }, - // { - // id: 'vercel', - // name: 'Vercel', - // import: () => import('@ai-sdk/vercel'), - // creatorFunctionName: 'createVercel' - // }, - - // 社区 Providers (5个) - // { - // id: 'ollama', - // name: 'Ollama', - // import: () => import('ollama-ai-provider'), - // creatorFunctionName: 'createOllama', - // supportsImageGeneration: false - // }, - // { - // id: 'anthropic-vertex', - // name: 'Anthropic Vertex AI', - // import: () => import('anthropic-vertex-ai'), - // creatorFunctionName: 'createAnthropicVertex', - // supportsImageGeneration: false - // }, - // { - // id: 'openrouter', - // name: 'OpenRouter', - // import: () => import('@openrouter/ai-sdk-provider'), - // creatorFunctionName: 'createOpenRouter', - // supportsImageGeneration: false - // } ] providers.forEach((config) => { @@ -243,28 +132,92 @@ export class AiProviderRegistry { this.registry.set(config.id, config) } + /** + * 动态注册Provider并支持映射关系 + */ + public registerDynamicProvider( + config: ProviderConfig & { + mappings?: Record + } + ): boolean { + try { + // 验证配置 + if (!config.id || config.id.trim() === '') { + console.error('Provider ID cannot be empty') + return false + } + + // 注册provider + this.registerProvider(config) + + // 记录为动态provider + this.dynamicProviders.add(config.id) + + // 添加映射关系(如果提供) + if (config.mappings) { + Object.entries(config.mappings).forEach(([key, value]) => { + this.dynamicMappings.set(key, value) + }) + } + + return true + } catch (error) { + console.error(`Failed to register provider ${config.id}:`, error) + return false + } + } + + /** + * 批量注册多个动态Providers + */ + public registerMultipleProviders( + configs: (ProviderConfig & { + mappings?: Record + })[] + ): number { + let successCount = 0 + configs.forEach((config) => { + if (this.registerDynamicProvider(config)) { + successCount++ + } + }) + return successCount + } + + /** + * 获取Provider映射(包括动态映射) + */ + public getProviderMapping(providerId: string): string | undefined { + return this.dynamicMappings.get(providerId) || (this.dynamicProviders.has(providerId) ? providerId : undefined) + } + + /** + * 检查是否为动态注册的Provider + */ + public isDynamicProvider(providerId: string): boolean { + return this.dynamicProviders.has(providerId) + } + + /** + * 获取所有动态Provider映射 + */ + public getAllDynamicMappings(): Record { + return Object.fromEntries(this.dynamicMappings) + } + + /** + * 获取所有动态注册的Providers + */ + public getDynamicProviders(): string[] { + return Array.from(this.dynamicProviders) + } + /** * 清理资源 */ public cleanup(): void { this.registry.clear() } - - // /** - // * 获取兼容现有实现的注册表格式 - // */ - // public getCompatibleRegistry(): Record Promise; creatorFunctionName: string }> { - // const compatibleRegistry: Record Promise; creatorFunctionName: string }> = {} - - // this.getAllProviders().forEach((provider) => { - // compatibleRegistry[provider.id] = { - // import: provider.import, - // creatorFunctionName: provider.creatorFunctionName - // } - // }) - - // return compatibleRegistry - // } } // 导出单例实例 @@ -275,5 +228,16 @@ export const getProvider = (id: string) => aiProviderRegistry.getProvider(id) export const getAllProviders = () => aiProviderRegistry.getAllProviders() export const isProviderSupported = (id: string) => aiProviderRegistry.isSupported(id) export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config) + +// 动态注册相关便捷函数 +export const registerDynamicProvider = (config: ProviderConfig & { mappings?: Record }) => + aiProviderRegistry.registerDynamicProvider(config) +export const registerMultipleProviders = (configs: (ProviderConfig & { mappings?: Record })[]) => + aiProviderRegistry.registerMultipleProviders(configs) +export const getProviderMapping = (providerId: string) => aiProviderRegistry.getProviderMapping(providerId) +export const isDynamicProvider = (providerId: string) => aiProviderRegistry.isDynamicProvider(providerId) +export const getAllDynamicMappings = () => aiProviderRegistry.getAllDynamicMappings() +export const getDynamicProviders = () => aiProviderRegistry.getDynamicProviders() + // 兼容现有实现的导出 // export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry() diff --git a/packages/aiCore/src/core/providers/types.ts b/packages/aiCore/src/core/providers/types.ts index ec9bdf82a4..4f8eafb51e 100644 --- a/packages/aiCore/src/core/providers/types.ts +++ b/packages/aiCore/src/core/providers/types.ts @@ -6,11 +6,30 @@ import { type OpenAIProviderSettings } from '@ai-sdk/openai' import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible' import { type XaiProviderSettings } from '@ai-sdk/xai' +export interface ExtensibleProviderSettingsMap { + // 基础的静态providers + openai: OpenAIProviderSettings + 'openai-responses': OpenAIProviderSettings + 'openai-compatible': OpenAICompatibleProviderSettings + anthropic: AnthropicProviderSettings + google: GoogleGenerativeAIProviderSettings + xai: XaiProviderSettings + azure: AzureOpenAIProviderSettings + deepseek: DeepSeekProviderSettings +} + +// 动态扩展的provider类型注册表 +export interface DynamicProviderRegistry { + [key: string]: any +} + +// 合并基础和动态provider类型 +export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry + /** * Provider 相关核心类型定义 * 只定义必要的接口,其他类型直接使用 AI SDK */ -export type ProviderId = keyof ProviderSettingsMap & string // Provider 配置接口 - 支持灵活的创建方式 export interface ProviderConfig { @@ -45,57 +64,23 @@ export class ProviderError extends Error { } } -// 类型安全的 Provider Settings 映射 -export type ProviderSettingsMap = { - openai: OpenAIProviderSettings - 'openai-responses': OpenAIProviderSettings - 'openai-compatible': OpenAICompatibleProviderSettings - // openrouter: OpenRouterProviderSettings - anthropic: AnthropicProviderSettings - google: GoogleGenerativeAIProviderSettings - // 'google-vertex': GoogleVertexProviderSettings - // mistral: MistralProviderSettings - xai: XaiProviderSettings - azure: AzureOpenAIProviderSettings - // bedrock: AmazonBedrockProviderSettings - // cohere: CohereProviderSettings - // groq: GroqProviderSettings - // together: TogetherAIProviderSettings - // fireworks: FireworksProviderSettings - deepseek: DeepSeekProviderSettings - // cerebras: CerebrasProviderSettings - // deepinfra: DeepInfraProviderSettings - // replicate: ReplicateProviderSettings - // perplexity: PerplexityProviderSettings - // fal: FalProviderSettings - // vercel: VercelProviderSettings - // ollama: OllamaProviderSettings - // 'anthropic-vertex': AnthropicVertexProviderSettings +// 动态ProviderId类型 - 支持运行时扩展 +export type ProviderId = keyof ExtensibleProviderSettingsMap | string + +// Provider类型注册工具 +export interface ProviderTypeRegistrar { + registerProviderType(providerId: T, settingsType: S): void + getProviderSettings(providerId: T): any } // 重新导出所有类型供外部使用 export type { - // AmazonBedrockProviderSettings, AnthropicProviderSettings, - // AnthropicVertexProviderSettings, AzureOpenAIProviderSettings, - // CerebrasProviderSettings, - // CohereProviderSettings, - // DeepInfraProviderSettings, DeepSeekProviderSettings, - // FalProviderSettings, - // FireworksProviderSettings, GoogleGenerativeAIProviderSettings, - // GoogleVertexProviderSettings, - // GroqProviderSettings, - // MistralProviderSettings, - // OllamaProviderSettings, OpenAICompatibleProviderSettings, OpenAIProviderSettings, - // OpenRouterProviderSettings, - // PerplexityProviderSettings, - // ReplicateProviderSettings, - // TogetherAIProviderSettings, - // VercelProviderSettings, XaiProviderSettings } +// 新的provider类型已经在上面直接export,不需要重复导出 diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index 1b9b5b153a..9dab6c12f6 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -120,7 +120,19 @@ export { } from './core/options' // ==================== 工具函数 ==================== -export { getAllProviders, getProvider, isProviderSupported, registerProvider } from './core/providers/registry' +export { + getAllDynamicMappings, + getAllProviders, + getDynamicProviders, + getProvider, + getProviderMapping, + isDynamicProvider, + isProviderSupported, + // 动态注册功能 + registerDynamicProvider, + registerMultipleProviders, + registerProvider +} from './core/providers/registry' // ==================== Provider 配置工厂 ==================== export { diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts index af8b024947..f392d5d160 100644 --- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -114,12 +114,12 @@ export class AiSdkToChunkAdapter { } break case 'reasoning-delta': + final.reasoningContent += chunk.text || '' this.onChunk({ type: ChunkType.THINKING_DELTA, text: final.reasoningContent || '', thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0 }) - final.reasoningContent += chunk.text || '' break case 'reasoning-end': this.onChunk({ diff --git a/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts b/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts index 2de808ab87..8b93b5482b 100644 --- a/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts +++ b/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts @@ -4,8 +4,8 @@ * 提供工具调用相关的处理API,每个交互使用一个新的实例 */ -import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core' -import Logger from '@renderer/config/logger' +import { ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core' +import { loggerService } from '@logger' import { BaseTool, MCPToolResponse, ToolCallResponse } from '@renderer/types' import { Chunk, ChunkType } from '@renderer/types/chunk' import { type ProviderMetadata } from 'ai' @@ -14,6 +14,8 @@ import { type ProviderMetadata } from 'ai' // WebSearchPluginConfig // } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin' +const logger = loggerService.withContext('ToolCallChunkHandler') + /** * 工具调用处理器类 */ @@ -84,7 +86,7 @@ export class ToolCallChunkHandler { case 'tool-input-delta': { const toolCall = this.activeToolCalls.get(chunk.id) if (!toolCall) { - Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) + logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) return } toolCall.args += chunk.delta @@ -94,7 +96,7 @@ export class ToolCallChunkHandler { const toolCall = this.activeToolCalls.get(chunk.id) this.activeToolCalls.delete(chunk.id) if (!toolCall) { - Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) + logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) return } const toolResponse: ToolCallResponse = { @@ -104,7 +106,7 @@ export class ToolCallChunkHandler { status: 'pending', toolCallId: toolCall.toolCallId } - console.log('toolResponse', toolResponse) + logger.debug('toolResponse', toolResponse) this.onChunk({ type: ChunkType.MCP_TOOL_PENDING, responses: [toolResponse] @@ -134,12 +136,12 @@ export class ToolCallChunkHandler { public handleToolCall( chunk: { type: 'tool-call' - } & ToolCallUnion + } & TypedToolCall ): void { const { toolCallId, toolName, input: args, providerExecuted } = chunk if (!toolCallId || !toolName) { - Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`) + logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`) return } @@ -148,7 +150,7 @@ export class ToolCallChunkHandler { // 根据 providerExecuted 标志区分处理逻辑 if (providerExecuted) { // 如果是 Provider 执行的工具(如 web_search) - Logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`) + logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`) tool = { id: toolCallId, name: toolName, @@ -157,7 +159,7 @@ export class ToolCallChunkHandler { } } else if (toolName.startsWith('builtin_')) { // 如果是内置工具,沿用现有逻辑 - Logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`) + logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`) tool = { id: toolCallId, name: toolName, @@ -166,10 +168,10 @@ export class ToolCallChunkHandler { } } else { // 如果是客户端执行的 MCP 工具,沿用现有逻辑 - Logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`) + logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`) const mcpTool = this.mcpTools.find((t) => t.name === toolName) if (!mcpTool) { - Logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`) + logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`) return } tool = mcpTool @@ -207,19 +209,19 @@ export class ToolCallChunkHandler { public handleToolResult( chunk: { type: 'tool-result' - } & ToolResultUnion + } & TypedToolResult ): void { const { toolCallId, output, input } = chunk if (!toolCallId) { - Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`) + logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`) return } // 查找对应的工具调用信息 const toolCallInfo = this.activeToolCalls.get(toolCallId) if (!toolCallInfo) { - Logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`) + logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`) return } diff --git a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts b/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts index 6117dffa18..9eb253b287 100644 --- a/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts +++ b/src/renderer/src/aiCore/clients/aws/AwsBedrockAPIClient.ts @@ -37,7 +37,7 @@ import { import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils' import { awsBedrockToolUseToMcpTool, - isEnabledToolUse, + isSupportedToolUse, mcpToolCallResponseToAwsBedrockMessage, mcpToolsToAwsBedrockTools } from '@renderer/utils/mcp-tools' @@ -393,7 +393,7 @@ export class AwsBedrockAPIClient extends BaseApiClient< const { tools } = this.setupToolsConfig({ mcpTools: mcpTools, model, - enableToolUse: isEnabledToolUse(assistant) + enableToolUse: isSupportedToolUse(assistant) }) // 3. 处理消息 diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index e8458e7bb2..bf369dc840 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -69,10 +69,10 @@ function providerToAiSdkConfig(actualProvider: Provider): { const aiSdkProviderId = getAiSdkProviderId(actualProvider) // console.log('aiSdkProviderId', aiSdkProviderId) // 如果provider是openai,则使用strict模式并且默认responses api - const actualProviderId = actualProvider.type + const actualProviderType = actualProvider.type const openaiResponseOptions = // 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses - actualProviderId === 'openai-response' + actualProviderType === 'openai-response' ? { mode: 'responses' } diff --git a/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts b/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts index d4105a136a..8b521abc91 100644 --- a/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts +++ b/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts @@ -13,14 +13,13 @@ export default definePlugin({ return new TransformStream, TextStreamPart>({ transform(chunk: TextStreamPart, controller: TransformStreamDefaultController>) { // === 处理 reasoning 类型 === - if (chunk.type === 'reasoning') { - if (!hasStartedThinking) { - hasStartedThinking = true - thinkingStartTime = performance.now() - reasoningBlockId = chunk.id - } + if (chunk.type === 'reasoning-start') { + controller.enqueue(chunk) + hasStartedThinking = true + thinkingStartTime = performance.now() + reasoningBlockId = chunk.id + } else if (chunk.type === 'reasoning-delta') { accumulatedThinkingContent += chunk.text - controller.enqueue({ ...chunk, providerMetadata: { @@ -32,7 +31,7 @@ export default definePlugin({ } } }) - } else if (hasStartedThinking) { + } else if (chunk.type === 'reasoning-end' && hasStartedThinking) { controller.enqueue({ type: 'reasoning-end', id: reasoningBlockId, @@ -47,28 +46,8 @@ export default definePlugin({ hasStartedThinking = false thinkingStartTime = 0 reasoningBlockId = '' - if (chunk.type !== 'reasoning-end') { - controller.enqueue(chunk) - } } else { - if (chunk.type !== 'reasoning-end') { - controller.enqueue(chunk) - } - } - }, - - flush(controller) { - if (hasStartedThinking) { - controller.enqueue({ - type: 'reasoning-end', - id: reasoningBlockId, - providerMetadata: { - metadata: { - thinking_millsec: performance.now() - thinkingStartTime, - thinking_content: accumulatedThinkingContent - } - } - }) + controller.enqueue(chunk) } } }) diff --git a/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts b/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts new file mode 100644 index 0000000000..e26597e2d1 --- /dev/null +++ b/src/renderer/src/aiCore/provider/__tests__/integratedRegistry.test.ts @@ -0,0 +1,103 @@ +import type { Provider } from '@renderer/types' +import { describe, expect, it, vi } from 'vitest' + +import { getAiSdkProviderId } from '../factory' + +// Mock the external dependencies +vi.mock('@cherrystudio/ai-core', () => ({ + registerMultipleProviders: vi.fn(() => 4), // Mock successful registration of 4 providers + getProviderMapping: vi.fn((id: string) => { + // Mock dynamic mappings + const mappings: Record = { + openrouter: 'openrouter', + 'google-vertex': 'google-vertex', + vertexai: 'google-vertex', + bedrock: 'bedrock', + 'aws-bedrock': 'bedrock', + zhipu: 'zhipu' + } + return mappings[id] + }), + AiCore: { + isSupported: vi.fn(() => true) + } +})) + +// Mock the provider configs +vi.mock('../providerConfigs', () => ({ + initializeNewProviders: vi.fn() +})) + +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn() + }) + } +})) + +function createTestProvider(id: string, type: string): Provider { + return { + id, + type, + name: `Test ${id}`, + apiKey: 'test-key', + apiHost: 'test-host' + } as Provider +} + +describe('Integrated Provider Registry', () => { + describe('Provider ID Resolution', () => { + it('should resolve openrouter provider correctly', () => { + const provider = createTestProvider('openrouter', 'openrouter') + const result = getAiSdkProviderId(provider) + expect(result).toBe('openrouter') + }) + + it('should resolve google-vertex provider correctly', () => { + const provider = createTestProvider('google-vertex', 'vertexai') + const result = getAiSdkProviderId(provider) + expect(result).toBe('google-vertex') + }) + + it('should resolve bedrock provider correctly', () => { + const provider = createTestProvider('bedrock', 'aws-bedrock') + const result = getAiSdkProviderId(provider) + expect(result).toBe('bedrock') + }) + + it('should resolve zhipu provider correctly', () => { + const provider = createTestProvider('zhipu', 'zhipu') + const result = getAiSdkProviderId(provider) + expect(result).toBe('zhipu') + }) + + it('should resolve provider type mapping correctly', () => { + const provider = createTestProvider('vertex-test', 'vertexai') + const result = getAiSdkProviderId(provider) + expect(result).toBe('google-vertex') + }) + + it('should handle static provider mappings', () => { + const geminiProvider = createTestProvider('gemini', 'gemini') + const result = getAiSdkProviderId(geminiProvider) + expect(result).toBe('google') + }) + + it('should fallback to provider.id for unknown providers', () => { + const unknownProvider = createTestProvider('unknown-provider', 'unknown-type') + const result = getAiSdkProviderId(unknownProvider) + expect(result).toBe('unknown-provider') + }) + }) + + describe('Backward Compatibility', () => { + it('should maintain compatibility with existing providers', () => { + const grokProvider = createTestProvider('grok', 'grok') + const result = getAiSdkProviderId(grokProvider) + expect(result).toBe('xai') + }) + }) +}) diff --git a/src/renderer/src/aiCore/provider/aihubmix.ts b/src/renderer/src/aiCore/provider/aihubmix.ts index 0a38d3ea1e..14b9a468e0 100644 --- a/src/renderer/src/aiCore/provider/aihubmix.ts +++ b/src/renderer/src/aiCore/provider/aihubmix.ts @@ -10,9 +10,7 @@ export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'opena return 'anthropic' } // TODO:暂时注释,不清楚为什么排除,webSearch时会导致gemini模型走openai的逻辑 - // if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) { - - if (id.startsWith('gemini') || id.startsWith('imagen')) { + if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) { return 'google' } diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts index 9fe9804a2a..bf44a8b02c 100644 --- a/src/renderer/src/aiCore/provider/factory.ts +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -1,35 +1,50 @@ -import { AiCore, ProviderId } from '@cherrystudio/ai-core' +import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core' import { Provider } from '@renderer/types' -const PROVIDER_MAPPING: Record = { +import { initializeNewProviders } from './providerConfigs' + +// 初始化新的Provider注册系统 +initializeNewProviders() + +// 静态Provider映射 - 核心providers +const STATIC_PROVIDER_MAPPING: Record = { // anthropic: 'anthropic', gemini: 'google', - vertexai: 'google-vertex', 'azure-openai': 'azure', 'openai-response': 'openai', grok: 'xai' } export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' { - const providerId = PROVIDER_MAPPING[provider.id] - - if (providerId) { - return providerId + // 1. 首先检查静态映射 + const staticProviderId = STATIC_PROVIDER_MAPPING[provider.id] + if (staticProviderId) { + return staticProviderId } - const providerType = PROVIDER_MAPPING[provider.type] // 有些第三方需要映射到aicore对应sdk - - if (providerType) { - return providerType + // 2. 检查动态注册的provider映射(使用aiCore的函数) + const dynamicProviderId = getProviderMapping(provider.id) + if (dynamicProviderId) { + return dynamicProviderId as ProviderId } + // 3. 检查provider.type的静态映射 + const staticProviderType = STATIC_PROVIDER_MAPPING[provider.type] + if (staticProviderType) { + return staticProviderType + } + + // 4. 检查provider.type的动态映射 + const dynamicProviderType = getProviderMapping(provider.type) + if (dynamicProviderType) { + return dynamicProviderType as ProviderId + } + + // 5. 检查AiCore是否直接支持 if (AiCore.isSupported(provider.id)) { return provider.id as ProviderId } - // 先注释掉,会影响获取providerOptions - // if (AiCore.isSupported(provider.type)) { - // return provider.type as ProviderId - // } + // 6. 最后的fallback return provider.id as ProviderId } diff --git a/src/renderer/src/aiCore/provider/providerConfigs.ts b/src/renderer/src/aiCore/provider/providerConfigs.ts new file mode 100644 index 0000000000..39bdce9551 --- /dev/null +++ b/src/renderer/src/aiCore/provider/providerConfigs.ts @@ -0,0 +1,63 @@ +import type { ProviderConfig } from '@cherrystudio/ai-core' +import { loggerService } from '@logger' + +const logger = loggerService.withContext('ProviderConfigs') + +/** + * 新Provider配置定义 + * 定义了需要动态注册的AI Providers + */ +export const NEW_PROVIDER_CONFIGS: (ProviderConfig & { + mappings?: Record +})[] = [ + { + id: 'openrouter', + name: 'OpenRouter', + import: () => import('@openrouter/ai-sdk-provider'), + creatorFunctionName: 'createOpenRouter', + supportsImageGeneration: true, + mappings: { + openrouter: 'openrouter' + } + }, + { + id: 'google-vertex', + name: 'Google Vertex AI', + import: () => import('@ai-sdk/google-vertex'), + creatorFunctionName: 'createGoogleVertex', + supportsImageGeneration: true, + mappings: { + 'google-vertex': 'google-vertex', + vertexai: 'google-vertex' + } + }, + { + id: 'bedrock', + name: 'Amazon Bedrock', + import: () => import('@ai-sdk/amazon-bedrock'), + creatorFunctionName: 'createAmazonBedrock', + supportsImageGeneration: true, + mappings: { + 'aws-bedrock': 'bedrock' + } + } +] as const + +/** + * 初始化新的Providers + * 使用aiCore的动态注册功能 + */ +export async function initializeNewProviders(): Promise { + try { + // 动态导入以避免循环依赖 + const { registerMultipleProviders } = await import('@cherrystudio/ai-core') + + const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS) + + if (successCount < NEW_PROVIDER_CONFIGS.length) { + logger.warn('Some providers failed to register. Check previous error logs.') + } + } catch (error) { + logger.error('Failed to initialize new providers:', error as Error) + } +} diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index 1ef053b4ff..739f5624e4 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -46,7 +46,6 @@ import { findThinkingBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' -import { buildSystemPrompt } from '@renderer/utils/prompt' import { defaultTimeout } from '@shared/config/constant' import { isEmpty } from 'lodash' @@ -79,17 +78,6 @@ export function getTimeout(model: Model): number { return defaultTimeout } -/** - * 构建系统提示词 - */ -export async function buildSystemPromptWithTools( - prompt: string, - mcpTools?: MCPTool[], - assistant?: Assistant -): Promise { - return await buildSystemPrompt(prompt, mcpTools, assistant) -} - /** * 提取文件内容 */ diff --git a/src/renderer/src/config/constant.ts b/src/renderer/src/config/constant.ts index 37619a7484..c1e6f426d9 100644 --- a/src/renderer/src/config/constant.ts +++ b/src/renderer/src/config/constant.ts @@ -5,7 +5,6 @@ export const SYSTEM_PROMPT_THRESHOLD = 128 export const DEFAULT_KNOWLEDGE_DOCUMENT_COUNT = 6 export const DEFAULT_KNOWLEDGE_THRESHOLD = 0.0 export const DEFAULT_WEBSEARCH_RAG_DOCUMENT_COUNT = 1 -export const SYSTEM_PROMPT_THRESHOLD = 128 export const platform = window.electron?.process?.platform export const isMac = platform === 'darwin' diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 19622f207f..46ddceb4c8 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -3,6 +3,7 @@ */ import { StreamTextParams } from '@cherrystudio/ai-core' import { loggerService } from '@logger' +import AiProvider from '@renderer/aiCore' import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder' import { CompletionsParams } from '@renderer/aiCore/middleware/schemas' import { buildStreamTextParams } from '@renderer/aiCore/transformParameters' @@ -26,7 +27,7 @@ import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt' import { isEmpty, takeRight } from 'lodash' -import AiProvider from '../aiCore' +import AiProviderNew from '../aiCore/index_new' import { // getAssistantProvider, // getAssistantSettings, @@ -395,11 +396,8 @@ export async function fetchChatCompletion({ }) { console.log('fetchChatCompletion', messages, assistant) - const provider = getAssistantProvider(assistant) - const AI = new AiProvider(provider) - - // Make sure that 'Clear Context' works for all scenarios including external tool and normal chat. - messages = filterContextMessages(messages) + const AI = new AiProviderNew(assistant.model || getDefaultModel()) + const provider = AI.getActualProvider() const mcpTools: MCPTool[] = [] diff --git a/src/renderer/src/utils/mcp-tools.ts b/src/renderer/src/utils/mcp-tools.ts index 4eb043a6ac..59db8a8c14 100644 --- a/src/renderer/src/utils/mcp-tools.ts +++ b/src/renderer/src/utils/mcp-tools.ts @@ -8,6 +8,7 @@ import store from '@renderer/store' import { addMCPServer } from '@renderer/store/mcp' import { Assistant, + BaseTool, MCPCallToolResponse, MCPServer, MCPTool, @@ -492,12 +493,13 @@ export function getMcpServerByTool(tool: MCPTool) { return servers.find((s) => s.id === tool.serverId) } -export function isToolAutoApproved(tool: MCPTool, server?: MCPServer): boolean { - if (tool.isBuiltIn) { +export function isToolAutoApproved(tool: BaseTool, server?: MCPServer): boolean { + if (tool.type === 'builtin') { return true } - const effectiveServer = server ?? getMcpServerByTool(tool) - return effectiveServer ? !effectiveServer.disabledAutoApproveTools?.includes(tool.name) : false + const mcpTool = tool as MCPTool + const effectiveServer = server ?? getMcpServerByTool(mcpTool) + return effectiveServer ? !effectiveServer.disabledAutoApproveTools?.includes(mcpTool.name) : false } export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: number = 0): ToolUseResponse[] {