From a5e7aa1342f9b0f51dce1a42f698038f7383e872 Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 15:30:19 +0800 Subject: [PATCH 01/53] feat: Implement shared provider utilities and API host formatting - Added provider API host formatting utilities to handle differences between Cherry Studio and AI SDK. - Introduced functions for formatting provider API hosts, including support for Azure OpenAI and Vertex AI. - Created a simple API key rotator for managing API key rotation. - Developed shared provider initialization and mapping utilities for resolving provider IDs. - Implemented AI SDK configuration utilities for converting Cherry Studio providers to AI SDK configurations. - Added support for various providers including OpenRouter, Google Vertex AI, and Amazon Bedrock. - Enhanced error handling and logging in the unified messages service for better debugging. - Introduced functions for streaming and generating unified messages using AI SDK. --- .../shared/adapters/AiSdkToAnthropicSSE.ts | 593 ++++++++++++++++++ packages/shared/adapters/index.ts | 13 + packages/shared/api/index.ts | 173 +++++ .../shared}/provider/config/aihubmix.ts | 29 +- .../shared/provider/config/azure-anthropic.ts | 22 + packages/shared/provider/config/helper.ts | 32 + packages/shared/provider/config/index.ts | 6 + .../shared}/provider/config/newApi.ts | 24 +- packages/shared/provider/config/types.ts | 9 + .../provider/config/vertex-anthropic.ts | 19 + packages/shared/provider/detection.ts | 100 +++ packages/shared/provider/format.ts | 136 ++++ packages/shared/provider/index.ts | 48 ++ packages/shared/provider/initialization.ts | 107 ++++ packages/shared/provider/mapping.ts | 95 +++ packages/shared/provider/resolve.ts | 44 ++ packages/shared/provider/sdk-config.ts | 240 +++++++ packages/shared/provider/types.ts | 174 +++++ packages/shared/utils/index.ts | 1 + packages/shared/utils/naming.ts | 31 + src/main/apiServer/routes/messages.ts | 241 ++++++- src/main/apiServer/services/models.ts | 27 +- .../apiServer/services/unified-messages.ts | 455 ++++++++++++++ .../agents/services/claudecode/index.ts | 29 +- .../legacy/clients/gemini/VertexAPIClient.ts | 2 +- .../aiCore/provider/config/azure-anthropic.ts | 22 - .../src/aiCore/provider/config/helper.ts | 22 - .../src/aiCore/provider/config/index.ts | 10 +- .../src/aiCore/provider/config/types.ts | 9 - .../provider/config/vertext-anthropic.ts | 19 - src/renderer/src/aiCore/provider/factory.ts | 63 +- .../src/aiCore/provider/providerConfig.ts | 227 ++----- .../aiCore/provider/providerInitialization.ts | 112 +--- src/renderer/src/types/index.ts | 4 +- src/renderer/src/types/provider.ts | 158 +---- src/renderer/src/utils/api.ts | 178 +----- src/renderer/src/utils/naming.ts | 34 +- src/renderer/src/utils/provider.ts | 69 +- 38 files changed, 2681 insertions(+), 896 deletions(-) create mode 100644 packages/shared/adapters/AiSdkToAnthropicSSE.ts create mode 100644 packages/shared/adapters/index.ts create mode 100644 packages/shared/api/index.ts rename {src/renderer/src/aiCore => packages/shared}/provider/config/aihubmix.ts (53%) create mode 100644 packages/shared/provider/config/azure-anthropic.ts create mode 100644 packages/shared/provider/config/helper.ts create mode 100644 packages/shared/provider/config/index.ts rename {src/renderer/src/aiCore => packages/shared}/provider/config/newApi.ts (52%) create mode 100644 packages/shared/provider/config/types.ts create mode 100644 packages/shared/provider/config/vertex-anthropic.ts create mode 100644 
packages/shared/provider/detection.ts create mode 100644 packages/shared/provider/format.ts create mode 100644 packages/shared/provider/index.ts create mode 100644 packages/shared/provider/initialization.ts create mode 100644 packages/shared/provider/mapping.ts create mode 100644 packages/shared/provider/resolve.ts create mode 100644 packages/shared/provider/sdk-config.ts create mode 100644 packages/shared/provider/types.ts create mode 100644 packages/shared/utils/index.ts create mode 100644 packages/shared/utils/naming.ts create mode 100644 src/main/apiServer/services/unified-messages.ts delete mode 100644 src/renderer/src/aiCore/provider/config/azure-anthropic.ts delete mode 100644 src/renderer/src/aiCore/provider/config/helper.ts delete mode 100644 src/renderer/src/aiCore/provider/config/types.ts delete mode 100644 src/renderer/src/aiCore/provider/config/vertext-anthropic.ts
diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
new file mode 100644
index 0000000000..38fab703ac
--- /dev/null
+++ b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
@@ -0,0 +1,593 @@
+/**
+ * AI SDK to Anthropic SSE Adapter
+ *
+ * Converts AI SDK's fullStream (TextStreamPart) events to Anthropic Messages API SSE format.
+ * This enables any AI provider supported by AI SDK to be exposed via Anthropic-compatible API.
+ *
+ * Anthropic SSE Event Flow:
+ * 1. message_start - Initial message with metadata
+ * 2. content_block_start - Begin a content block (text, tool_use, thinking)
+ * 3. content_block_delta - Incremental content updates
+ * 4. content_block_stop - End a content block
+ * 5. message_delta - Updates to overall message (stop_reason, usage)
+ * 6. message_stop - Stream complete
+ *
+ * @see https://docs.anthropic.com/en/api/messages-streaming
+ */
+
+import type {
+  ContentBlock,
+  InputJSONDelta,
+  Message,
+  MessageDeltaUsage,
+  RawContentBlockDeltaEvent,
+  RawContentBlockStartEvent,
+  RawContentBlockStopEvent,
+  RawMessageDeltaEvent,
+  RawMessageStartEvent,
+  RawMessageStopEvent,
+  RawMessageStreamEvent,
+  StopReason,
+  TextBlock,
+  TextDelta,
+  ThinkingBlock,
+  ThinkingDelta,
+  ToolUseBlock,
+  Usage
+} from '@anthropic-ai/sdk/resources/messages'
+import { loggerService } from '@logger'
+import type { TextStreamPart, ToolSet } from 'ai'
+
+const logger = loggerService.withContext('AiSdkToAnthropicSSE')
+
+interface ContentBlockState {
+  type: 'text' | 'tool_use' | 'thinking'
+  index: number
+  started: boolean
+  content: string
+  // For tool_use blocks
+  toolId?: string
+  toolName?: string
+  toolInput?: string
+}
+
+interface AdapterState {
+  messageId: string
+  model: string
+  inputTokens: number
+  outputTokens: number
+  currentBlockIndex: number
+  blocks: Map<number, ContentBlockState>
+  textBlockIndex: number | null
+  thinkingBlockIndex: number | null
+  toolBlocks: Map<string, number> // toolCallId -> blockIndex
+  stopReason: StopReason | null
+  hasEmittedMessageStart: boolean
+}
+
+// ============================================================================
+// Adapter Class
+// ============================================================================
+
+export type SSEEventCallback = (event: RawMessageStreamEvent) => void
+
+export interface AiSdkToAnthropicSSEOptions {
+  model: string
+  messageId?: string
+  inputTokens?: number
+  onEvent: SSEEventCallback
+}
+
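+// Illustrative usage sketch (assumes an AI SDK `streamText` result and an
+// Express-style `res`; neither is defined in this file):
+//
+//   const adapter = new AiSdkToAnthropicSSE({
+//     model: 'claude-sonnet',
+//     onEvent: (event) => res.write(formatSSEEvent(event))
+//   })
+//   await adapter.processStream(result.fullStream)
+//   res.write(formatSSEDone())
+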
+/**
+ * Adapter that converts AI SDK fullStream events to Anthropic SSE events
+ */
+export class AiSdkToAnthropicSSE {
+  private state: AdapterState
+  private onEvent: SSEEventCallback
+
+  constructor(options: AiSdkToAnthropicSSEOptions) {
+    this.onEvent = options.onEvent
+    this.state = {
+      messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
+      model: options.model,
+      inputTokens: options.inputTokens || 0,
+      outputTokens: 0,
+      currentBlockIndex: 0,
+      blocks: new Map(),
+      textBlockIndex: null,
+      thinkingBlockIndex: null,
+      toolBlocks: new Map(),
+      stopReason: null,
+      hasEmittedMessageStart: false
+    }
+  }
+
+  /**
+   * Process the AI SDK stream and emit Anthropic SSE events
+   */
+  async processStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>): Promise<void> {
+    const reader = fullStream.getReader()
+
+    try {
+      // Emit message_start at the beginning
+      this.emitMessageStart()
+
+      while (true) {
+        const { done, value } = await reader.read()
+
+        if (done) {
+          break
+        }
+
+        this.processChunk(value)
+      }
+
+      // Ensure all blocks are closed and emit final events
+      this.finalize()
+    } finally {
+      reader.releaseLock()
+    }
+  }
+
+  /**
+   * Process a single AI SDK chunk and emit corresponding Anthropic events
+   */
+  private processChunk(chunk: TextStreamPart<ToolSet>): void {
+    logger.silly('AiSdkToAnthropicSSE - Processing chunk:', chunk)
+    switch (chunk.type) {
+      // === Text Events ===
+      case 'text-start':
+        this.startTextBlock()
+        break
+
+      case 'text-delta':
+        this.emitTextDelta(chunk.text || '')
+        break
+
+      case 'text-end':
+        this.stopTextBlock()
+        break
+
+      // === Reasoning/Thinking Events ===
+      case 'reasoning-start':
+        this.startThinkingBlock()
+        break
+
+      case 'reasoning-delta':
+        this.emitThinkingDelta(chunk.text || '')
+        break
+
+      case 'reasoning-end':
+        this.stopThinkingBlock()
+        break
+
+      // === Tool Events ===
+      case 'tool-call':
+        this.handleToolCall({
+          type: 'tool-call',
+          toolCallId: chunk.toolCallId,
+          toolName: chunk.toolName,
+          // AI SDK uses 'args' in some versions and 'input' in others
+          args: 'args' in chunk ? 
chunk.args : (chunk as any).input + }) + break + + case 'tool-result': + // Tool results are handled separately in Anthropic API + // They come from user messages, not assistant stream + break + + // === Completion Events === + case 'finish-step': + if (chunk.finishReason === 'tool-calls') { + this.state.stopReason = 'tool_use' + } + break + + case 'finish': + this.handleFinish(chunk) + break + + // === Error Events === + case 'error': + // Anthropic doesn't have a standard error event in the stream + // Errors are typically sent as separate HTTP responses + // For now, we'll just log and continue + break + + // Ignore other event types + default: + break + } + } + + private emitMessageStart(): void { + if (this.state.hasEmittedMessageStart) return + + this.state.hasEmittedMessageStart = true + + const usage: Usage = { + input_tokens: this.state.inputTokens, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + server_tool_use: null + } + + const message: Message = { + id: this.state.messageId, + type: 'message', + role: 'assistant', + content: [], + model: this.state.model, + stop_reason: null, + stop_sequence: null, + usage + } + + const event: RawMessageStartEvent = { + type: 'message_start', + message + } + + this.onEvent(event) + } + + private startTextBlock(): void { + // If we already have a text block, don't create another + if (this.state.textBlockIndex !== null) return + + const index = this.state.currentBlockIndex++ + this.state.textBlockIndex = index + this.state.blocks.set(index, { + type: 'text', + index, + started: true, + content: '' + }) + + const contentBlock: TextBlock = { + type: 'text', + text: '', + citations: null + } + + const event: RawContentBlockStartEvent = { + type: 'content_block_start', + index, + content_block: contentBlock + } + + this.onEvent(event) + } + + private emitTextDelta(text: string): void { + if (!text) return + + // Auto-start text block if not started + if (this.state.textBlockIndex === null) { + this.startTextBlock() + } + + const index = this.state.textBlockIndex! + const block = this.state.blocks.get(index) + if (block) { + block.content += text + } + + const delta: TextDelta = { + type: 'text_delta', + text + } + + const event: RawContentBlockDeltaEvent = { + type: 'content_block_delta', + index, + delta + } + + this.onEvent(event) + } + + private stopTextBlock(): void { + if (this.state.textBlockIndex === null) return + + const index = this.state.textBlockIndex + + const event: RawContentBlockStopEvent = { + type: 'content_block_stop', + index + } + + this.onEvent(event) + this.state.textBlockIndex = null + } + + private startThinkingBlock(): void { + if (this.state.thinkingBlockIndex !== null) return + + const index = this.state.currentBlockIndex++ + this.state.thinkingBlockIndex = index + this.state.blocks.set(index, { + type: 'thinking', + index, + started: true, + content: '' + }) + + const contentBlock: ThinkingBlock = { + type: 'thinking', + thinking: '', + signature: '' + } + + const event: RawContentBlockStartEvent = { + type: 'content_block_start', + index, + content_block: contentBlock + } + + this.onEvent(event) + } + + private emitThinkingDelta(text: string): void { + if (!text) return + + // Auto-start thinking block if not started + if (this.state.thinkingBlockIndex === null) { + this.startThinkingBlock() + } + + const index = this.state.thinkingBlockIndex! 
+ const block = this.state.blocks.get(index) + if (block) { + block.content += text + } + + const delta: ThinkingDelta = { + type: 'thinking_delta', + thinking: text + } + + const event: RawContentBlockDeltaEvent = { + type: 'content_block_delta', + index, + delta + } + + this.onEvent(event) + } + + private stopThinkingBlock(): void { + if (this.state.thinkingBlockIndex === null) return + + const index = this.state.thinkingBlockIndex + + const event: RawContentBlockStopEvent = { + type: 'content_block_stop', + index + } + + this.onEvent(event) + this.state.thinkingBlockIndex = null + } + + private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void { + const { toolCallId, toolName, args } = chunk + + // Check if we already have this tool call + if (this.state.toolBlocks.has(toolCallId)) { + return + } + + const index = this.state.currentBlockIndex++ + this.state.toolBlocks.set(toolCallId, index) + + const inputJson = JSON.stringify(args) + + this.state.blocks.set(index, { + type: 'tool_use', + index, + started: true, + content: inputJson, + toolId: toolCallId, + toolName, + toolInput: inputJson + }) + + // Emit content_block_start for tool_use + const contentBlock: ToolUseBlock = { + type: 'tool_use', + id: toolCallId, + name: toolName, + input: {} + } + + const startEvent: RawContentBlockStartEvent = { + type: 'content_block_start', + index, + content_block: contentBlock + } + + this.onEvent(startEvent) + + // Emit the full input as a delta (Anthropic streams JSON incrementally) + const delta: InputJSONDelta = { + type: 'input_json_delta', + partial_json: inputJson + } + + const deltaEvent: RawContentBlockDeltaEvent = { + type: 'content_block_delta', + index, + delta + } + + this.onEvent(deltaEvent) + + // Emit content_block_stop + const stopEvent: RawContentBlockStopEvent = { + type: 'content_block_stop', + index + } + + this.onEvent(stopEvent) + + // Mark that we have tool use + this.state.stopReason = 'tool_use' + } + + private handleFinish(chunk: { + type: 'finish' + finishReason?: string + totalUsage?: { + inputTokens?: number + outputTokens?: number + } + }): void { + // Update usage + if (chunk.totalUsage) { + this.state.inputTokens = chunk.totalUsage.inputTokens || 0 + this.state.outputTokens = chunk.totalUsage.outputTokens || 0 + } + + // Determine finish reason + if (!this.state.stopReason) { + switch (chunk.finishReason) { + case 'stop': + case 'end_turn': + this.state.stopReason = 'end_turn' + break + case 'length': + case 'max_tokens': + this.state.stopReason = 'max_tokens' + break + case 'tool-calls': + this.state.stopReason = 'tool_use' + break + default: + this.state.stopReason = 'end_turn' + } + } + } + + private finalize(): void { + // Close any open blocks + if (this.state.textBlockIndex !== null) { + this.stopTextBlock() + } + if (this.state.thinkingBlockIndex !== null) { + this.stopThinkingBlock() + } + + // Emit message_delta with final stop reason and usage + const usage: MessageDeltaUsage = { + output_tokens: this.state.outputTokens, + input_tokens: null, + cache_creation_input_tokens: null, + cache_read_input_tokens: null, + server_tool_use: null + } + + const messageDeltaEvent: RawMessageDeltaEvent = { + type: 'message_delta', + delta: { + stop_reason: this.state.stopReason || 'end_turn', + stop_sequence: null + }, + usage + } + + this.onEvent(messageDeltaEvent) + + // Emit message_stop + const messageStopEvent: RawMessageStopEvent = { + type: 'message_stop' + } + + this.onEvent(messageStopEvent) + } + + /** + * 
Set input token count (typically from prompt)
+   */
+  setInputTokens(count: number): void {
+    this.state.inputTokens = count
+  }
+
+  /**
+   * Get the current message ID
+   */
+  getMessageId(): string {
+    return this.state.messageId
+  }
+
+  /**
+   * Build a complete Message object for non-streaming responses
+   */
+  buildNonStreamingResponse(): Message {
+    const content: ContentBlock[] = []
+
+    // Collect all content blocks in order
+    const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index)
+
+    for (const block of sortedBlocks) {
+      switch (block.type) {
+        case 'text':
+          content.push({
+            type: 'text',
+            text: block.content,
+            citations: null
+          } as TextBlock)
+          break
+        case 'thinking':
+          content.push({
+            type: 'thinking',
+            thinking: block.content
+          } as ThinkingBlock)
+          break
+        case 'tool_use':
+          content.push({
+            type: 'tool_use',
+            id: block.toolId!,
+            name: block.toolName!,
+            input: JSON.parse(block.toolInput || '{}')
+          } as ToolUseBlock)
+          break
+      }
+    }
+
+    return {
+      id: this.state.messageId,
+      type: 'message',
+      role: 'assistant',
+      content,
+      model: this.state.model,
+      stop_reason: this.state.stopReason || 'end_turn',
+      stop_sequence: null,
+      usage: {
+        input_tokens: this.state.inputTokens,
+        output_tokens: this.state.outputTokens,
+        cache_creation_input_tokens: 0,
+        cache_read_input_tokens: 0,
+        server_tool_use: null
+      }
+    }
+  }
+}
+
+/**
+ * Format an Anthropic SSE event for HTTP streaming
+ */
+export function formatSSEEvent(event: RawMessageStreamEvent): string {
+  return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
+}
+
+/**
+ * Create a done marker for SSE stream
+ */
+export function formatSSEDone(): string {
+  return 'data: [DONE]\n\n'
+}
+
+export default AiSdkToAnthropicSSE
diff --git a/packages/shared/adapters/index.ts b/packages/shared/adapters/index.ts
new file mode 100644
index 0000000000..a19db9594e
--- /dev/null
+++ b/packages/shared/adapters/index.ts
@@ -0,0 +1,13 @@
+/**
+ * Shared Adapters
+ *
+ * This module exports adapters for converting between different AI API formats.
+ */
+
+export {
+  AiSdkToAnthropicSSE,
+  type AiSdkToAnthropicSSEOptions,
+  formatSSEDone,
+  formatSSEEvent,
+  type SSEEventCallback
+} from './AiSdkToAnthropicSSE'
diff --git a/packages/shared/api/index.ts b/packages/shared/api/index.ts
new file mode 100644
index 0000000000..0cf652b427
--- /dev/null
+++ b/packages/shared/api/index.ts
@@ -0,0 +1,173 @@
+/**
+ * Shared API Utilities
+ *
+ * Common utilities for API URL formatting and validation.
+ * Used by both main process (API Server) and renderer.
+ */
+
+import type { MinimalProvider } from '@shared/provider'
+import { trim } from 'lodash'
+
+// Supported endpoints for routing
+export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
+export const SUPPORTED_ENDPOINT_LIST = [
+  'chat/completions',
+  'responses',
+  'messages',
+  'generateContent',
+  'streamGenerateContent',
+  ...SUPPORTED_IMAGE_ENDPOINT_LIST
+] as const
+
+/**
+ * Removes the trailing slash from a URL string if it exists.
+ */
+export function withoutTrailingSlash<T extends string>(url: T): T {
+  return url.replace(/\/$/, '') as T
+}
+
+/**
+ * Checks if the host path contains a version string (e.g., /v1, /v2beta).
+ */
+export function hasAPIVersion(host?: string): boolean {
+  if (!host) return false
+
+  const versionRegex = /\/v\d+(?:alpha|beta)?(?=\/|$)/i
+
+  try {
+    const url = new URL(host)
+    return versionRegex.test(url.pathname)
+  } catch {
+    return versionRegex.test(host)
+  }
+}
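+
+// Illustrative examples (based on the version regex above):
+//   hasAPIVersion('https://api.openai.com/v1')         // true
+//   hasAPIVersion('https://example.com/v2beta/models') // true
+//   hasAPIVersion('https://api.openai.com')            // false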
+
+/**
+ * Formats the Azure OpenAI API host address.
+ */
+export function formatAzureOpenAIApiHost(host: string): string {
+  const normalizedHost = withoutTrailingSlash(host)
+    ?.replace(/\/v1$/, '')
+    .replace(/\/openai$/, '')
+  // NOTE: the AI SDK will append `v1` itself
+  return formatApiHost(normalizedHost + '/openai', false)
+}
+
+export function formatVertexApiHost(provider: MinimalProvider, project: string, location: string): string {
+  const { apiHost } = provider
+  const trimmedHost = withoutTrailingSlash(trim(apiHost))
+  if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
+    const host =
+      location == 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
+    return `${formatApiHost(host)}/projects/${project}/locations/${location}`
+  }
+  return formatApiHost(trimmedHost)
+}
+
+/**
+ * Formats an API host URL by normalizing it and optionally appending an API version.
+ *
+ * @param host - The API host URL to format
+ * @param isSupportedAPIVersion - Whether the API version is supported. Defaults to `true`.
+ * @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
+ *
+ * @example
+ * formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
+ * formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
+ * formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
+ */
+export function formatApiHost(host?: string, isSupportedAPIVersion: boolean = true, apiVersion: string = 'v1'): string {
+  const normalizedHost = withoutTrailingSlash((host || '').trim())
+  if (!normalizedHost) {
+    return ''
+  }
+
+  if (normalizedHost.endsWith('#') || !isSupportedAPIVersion || hasAPIVersion(normalizedHost)) {
+    return normalizedHost
+  }
+  return `${normalizedHost}/${apiVersion}`
+}
+
+/**
+ * Converts an API host URL into separate base URL and endpoint components.
+ *
+ * This function extracts endpoint information from a composite API host string.
+ * If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
+ * + * @param apiHost - The API host string to parse + * @returns An object containing: + * - `baseURL`: The base URL without the endpoint suffix + * - `endpoint`: The matched endpoint identifier, or empty string if no match found + * + * @example + * routeToEndpoint('https://api.example.com/openai/chat/completions#') + * // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' } + * + * @example + * routeToEndpoint('https://api.example.com/v1') + * // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' } + */ +export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } { + const trimmedHost = (apiHost || '').trim() + if (!trimmedHost.endsWith('#')) { + return { baseURL: trimmedHost, endpoint: '' } + } + // Remove trailing # + const host = trimmedHost.slice(0, -1) + const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint)) + if (!endpointMatch) { + const baseURL = withoutTrailingSlash(host) + return { baseURL, endpoint: '' } + } + const baseSegment = host.slice(0, host.length - endpointMatch.length) + const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // Remove trailing colon (gemini special case) + return { baseURL, endpoint: endpointMatch } +} + +/** + * Gets the AI SDK compatible base URL from a provider's apiHost. + * + * AI SDK expects baseURL WITH version suffix (e.g., /v1). + * This function: + * 1. Handles '#' endpoint routing format + * 2. Ensures the URL has a version suffix (adds /v1 if missing) + * + * @param apiHost - The provider's apiHost value (may or may not have /v1) + * @param apiVersion - The API version to use if missing. Defaults to 'v1'. + * @returns The baseURL suitable for AI SDK (with version suffix) + * + * @example + * getAiSdkBaseUrl('https://api.openai.com') // 'https://api.openai.com/v1' + * getAiSdkBaseUrl('https://api.openai.com/v1') // 'https://api.openai.com/v1' + * getAiSdkBaseUrl('https://api.example.com/chat/completions#') // 'https://api.example.com' + */ +export function getAiSdkBaseUrl(apiHost: string, apiVersion: string = 'v1'): string { + // First handle '#' endpoint routing format + const { baseURL } = routeToEndpoint(apiHost) + + // If already has version, return as-is + if (hasAPIVersion(baseURL)) { + return withoutTrailingSlash(baseURL) + } + + // Add version suffix + return `${withoutTrailingSlash(baseURL)}/${apiVersion}` +} + +/** + * Validates an API host address. 
+ *
+ * @param apiHost - The API host address to validate
+ * @returns true if valid URL with http/https protocol, false otherwise
+ */
+export function validateApiHost(apiHost: string): boolean {
+  if (!apiHost || !apiHost.trim()) {
+    return true // Allow empty
+  }
+  try {
+    const url = new URL(apiHost.trim())
+    return url.protocol === 'http:' || url.protocol === 'https:'
+  } catch {
+    return false
+  }
+}
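+
+// Illustrative examples:
+//   validateApiHost('https://api.example.com') // true
+//   validateApiHost('')                        // true (empty is allowed)
+//   validateApiHost('not a url')               // false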
diff --git a/src/renderer/src/aiCore/provider/config/aihubmix.ts b/packages/shared/provider/config/aihubmix.ts
similarity index 53%
rename from src/renderer/src/aiCore/provider/config/aihubmix.ts
rename to packages/shared/provider/config/aihubmix.ts
index 8feed89909..5214e8d06a 100644
--- a/src/renderer/src/aiCore/provider/config/aihubmix.ts
+++ b/packages/shared/provider/config/aihubmix.ts
@@ -1,13 +1,13 @@
 /**
  * AiHubMix rule set
  */
-import { isOpenAILLMModel } from '@renderer/config/models'
-import type { Provider } from '@renderer/types'
+import { getLowerBaseModelName } from '@shared/utils/naming'
+import type { MinimalModel, MinimalProvider } from '../types'
 import { provider2Provider, startsWith } from './helper'
 import type { RuleSet } from './types'
 
-const extraProviderConfig = (provider: Provider) => {
+const extraProviderConfig = <P extends MinimalProvider>(provider: P) => {
+  return {
+    ...provider,
+    extra_headers: {
@@ -17,11 +17,23 @@ const extraProviderConfig = (provider: Provider) => {
   }
 }
+
+function isOpenAILLMModel<M extends MinimalModel>(model: M): boolean {
+  const modelId = getLowerBaseModelName(model.id)
+  const reasonings = ['o1', 'o3', 'o4', 'gpt-oss']
+  if (reasonings.some((r) => modelId.includes(r))) {
+    return true
+  }
+  if (modelId.includes('gpt')) {
+    return true
+  }
+  return false
+}
+
 const AIHUBMIX_RULES: RuleSet = {
   rules: [
     {
       match: startsWith('claude'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return extraProviderConfig({
           ...provider,
           type: 'anthropic'
@@ -34,7 +46,7 @@ const AIHUBMIX_RULES: RuleSet = {
         !model.id.endsWith('-nothink') &&
         !model.id.endsWith('-search') &&
         !model.id.includes('embedding'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return extraProviderConfig({
           ...provider,
           type: 'gemini',
@@ -44,7 +56,7 @@ const AIHUBMIX_RULES: RuleSet = {
     },
     {
       match: isOpenAILLMModel,
-      provider: (provider: Provider) => {
+      provider: (provider) => {
        return extraProviderConfig({
          ...provider,
          type: 'openai-response'
@@ -52,7 +64,8 @@ const AIHUBMIX_RULES: RuleSet = {
      }
    }
  ],
-  fallbackRule: (provider: Provider) => extraProviderConfig(provider)
+  fallbackRule: (provider) => extraProviderConfig(provider)
 }
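+// e.g. (illustrative): a model id starting with 'claude' matches the first rule,
+// so the provider is rewritten to type 'anthropic' with the shared extra headers applied.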
-export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)
+export const aihubmixProviderCreator = <P extends MinimalProvider>
(model: MinimalModel, provider: P): P => + provider2Provider(AIHUBMIX_RULES, model, provider) diff --git a/packages/shared/provider/config/azure-anthropic.ts b/packages/shared/provider/config/azure-anthropic.ts new file mode 100644 index 0000000000..e176614df3 --- /dev/null +++ b/packages/shared/provider/config/azure-anthropic.ts @@ -0,0 +1,22 @@ +import type { MinimalModel, MinimalProvider, ProviderType } from '../types' +import { provider2Provider, startsWith } from './helper' +import type { RuleSet } from './types' + +// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry +const AZURE_ANTHROPIC_RULES: RuleSet = { + rules: [ + { + match: startsWith('claude'), + provider: (provider: MinimalProvider) => ({ + ...provider, + type: 'anthropic' as ProviderType, + apiHost: provider.apiHost + 'anthropic/v1', + id: 'azure-anthropic' + }) + } + ], + fallbackRule: (provider: MinimalProvider) => provider +} + +export const azureAnthropicProviderCreator =
<P extends MinimalProvider>
(model: MinimalModel, provider: P): P =>
+  provider2Provider(AZURE_ANTHROPIC_RULES, model, provider)
diff --git a/packages/shared/provider/config/helper.ts b/packages/shared/provider/config/helper.ts
new file mode 100644
index 0000000000..4e821a6c8f
--- /dev/null
+++ b/packages/shared/provider/config/helper.ts
@@ -0,0 +1,32 @@
+import type { MinimalModel, MinimalProvider } from '../types'
+import type { RuleSet } from './types'
+
+export const startsWith =
+  (prefix: string) =>
+  <M extends MinimalModel>(model: M) =>
+    model.id.toLowerCase().startsWith(prefix.toLowerCase())
+
+export const endpointIs =
+  (type: string) =>
+  <M extends MinimalModel>(model: M) =>
+    model.endpoint_type === type
+
+/**
+ * Resolve the provider that a model maps to
+ * @param ruleSet the rule set
+ * @param model the model
+ * @param provider the original provider
+ * @returns the resolved provider
+ */
+export function provider2Provider<
+  M extends MinimalModel,
+  R extends MinimalProvider,
+  P extends R = R
+>(ruleSet: RuleSet<M, R>, model: M, provider: P): P {
+  for (const rule of ruleSet.rules) {
+    if (rule.match(model)) {
+      return rule.provider(provider) as P
+    }
+  }
+  return ruleSet.fallbackRule(provider) as P
+}
diff --git a/packages/shared/provider/config/index.ts b/packages/shared/provider/config/index.ts
new file mode 100644
index 0000000000..1273319ecd
--- /dev/null
+++ b/packages/shared/provider/config/index.ts
@@ -0,0 +1,6 @@
+export { aihubmixProviderCreator } from './aihubmix'
+export { azureAnthropicProviderCreator } from './azure-anthropic'
+export { endpointIs, provider2Provider, startsWith } from './helper'
+export { newApiResolverCreator } from './newApi'
+export type { RuleSet } from './types'
+export { vertexAnthropicProviderCreator } from './vertex-anthropic'
diff --git a/src/renderer/src/aiCore/provider/config/newApi.ts b/packages/shared/provider/config/newApi.ts
similarity index 52%
rename from src/renderer/src/aiCore/provider/config/newApi.ts
rename to packages/shared/provider/config/newApi.ts
index 97de62597d..fd1b74085f 100644
--- a/src/renderer/src/aiCore/provider/config/newApi.ts
+++ b/packages/shared/provider/config/newApi.ts
@@ -1,8 +1,7 @@
 /**
  * NewAPI rule set
  */
-import type { Provider } from '@renderer/types'
-
+import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
 import { endpointIs, provider2Provider } from './helper'
 import type { RuleSet } from './types'
@@ -10,42 +9,43 @@
 const NEWAPI_RULES: RuleSet = {
   rules: [
     {
       match: endpointIs('anthropic'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return {
           ...provider,
-          type: 'anthropic'
+          type: 'anthropic' as ProviderType
         }
       }
     },
     {
       match: endpointIs('gemini'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return {
           ...provider,
-          type: 'gemini'
+          type: 'gemini' as ProviderType
        }
      }
    },
    {
      match: endpointIs('openai-response'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
        return {
          ...provider,
-          type: 'openai-response'
+          type: 'openai-response' as ProviderType
        }
      }
    },
    {
      match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
        return {
          ...provider,
-          type: 'openai'
+          type: 'openai' as ProviderType
        }
      }
    }
  ],
-  fallbackRule: (provider: Provider) => provider
+  fallbackRule: (provider) => provider
 }
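+// e.g. (illustrative): a model whose endpoint_type is 'anthropic' rewrites the
+// provider's type to 'anthropic'; models with no matching rule pass through unchanged.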
-export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)
+export const newApiResolverCreator = <P extends MinimalProvider>
(model: MinimalModel, provider: P): P =>
+  provider2Provider(NEWAPI_RULES, model, provider)
diff --git a/packages/shared/provider/config/types.ts b/packages/shared/provider/config/types.ts
new file mode 100644
index 0000000000..fdb1309869
--- /dev/null
+++ b/packages/shared/provider/config/types.ts
@@ -0,0 +1,9 @@
+import type { MinimalModel, MinimalProvider } from '../types'
+
+export interface RuleSet<M extends MinimalModel = MinimalModel, P extends MinimalProvider = MinimalProvider> {
+  rules: Array<{
+    match: (model: M) => boolean
+    provider: (provider: P) => P
+  }>
+  fallbackRule: (provider: P) => P
+}
diff --git a/packages/shared/provider/config/vertex-anthropic.ts b/packages/shared/provider/config/vertex-anthropic.ts
new file mode 100644
index 0000000000..242ba2a9f5
--- /dev/null
+++ b/packages/shared/provider/config/vertex-anthropic.ts
@@ -0,0 +1,19 @@
+import type { MinimalModel, MinimalProvider } from '../types'
+import { provider2Provider, startsWith } from './helper'
+import type { RuleSet } from './types'
+
+const VERTEX_ANTHROPIC_RULES: RuleSet = {
+  rules: [
+    {
+      match: startsWith('claude'),
+      provider: (provider: MinimalProvider) => ({
+        ...provider,
+        id: 'google-vertex-anthropic'
+      })
+    }
+  ],
+  fallbackRule: (provider: MinimalProvider) => provider
+}
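+
+// e.g. (illustrative): on Vertex AI, a 'claude-*' model is re-routed to the
+// 'google-vertex-anthropic' AI SDK provider id; everything else passes through.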
+
+export const vertexAnthropicProviderCreator = <P extends MinimalProvider>
(model: MinimalModel, provider: P): P => + provider2Provider(VERTEX_ANTHROPIC_RULES, model, provider) diff --git a/packages/shared/provider/detection.ts b/packages/shared/provider/detection.ts new file mode 100644 index 0000000000..19fff2dff9 --- /dev/null +++ b/packages/shared/provider/detection.ts @@ -0,0 +1,100 @@ +/** + * Provider Type Detection Utilities + * + * Functions to detect provider types based on provider configuration. + * These are pure functions that only depend on provider.type and provider.id. + * + * NOTE: These functions should match the logic in @renderer/utils/provider.ts + */ + +import type { MinimalProvider } from './types' + +/** + * Check if provider is Anthropic type + */ +export function isAnthropicProvider
<P extends MinimalProvider>
(provider: P): boolean { + return provider.type === 'anthropic' +} + +/** + * Check if provider is OpenAI Response type (openai-response) + * NOTE: This matches isOpenAIProvider in renderer/utils/provider.ts + */ +export function isOpenAIProvider
<P extends MinimalProvider>
(provider: P): boolean { + return provider.type === 'openai-response' +} + +/** + * Check if provider is Gemini type + */ +export function isGeminiProvider
<P extends MinimalProvider>
(provider: P): boolean { + return provider.type === 'gemini' +} + +/** + * Check if provider is Azure OpenAI type + */ +export function isAzureOpenAIProvider
<P extends MinimalProvider>
(provider: P): boolean { + return provider.type === 'azure-openai' +} + +/** + * Check if provider is Vertex AI type + */ +export function isVertexProvider
<P extends MinimalProvider>
(provider: P): boolean { + return provider.type === 'vertexai' +} + +/** + * Check if provider is AWS Bedrock type + */ +export function isAwsBedrockProvider
<P extends MinimalProvider>
(provider: P): boolean { + return provider.type === 'aws-bedrock' +} + +/** + * Check if provider is AI Gateway type + */ +export function isAIGatewayProvider
<P extends MinimalProvider>
(provider: P): boolean { + return provider.type === 'ai-gateway' +} + +/** + * Check if Azure OpenAI provider uses responses endpoint + * Matches isAzureResponsesEndpoint in renderer/utils/provider.ts + */ +export function isAzureResponsesEndpoint
<P extends MinimalProvider>
(provider: P): boolean { + return provider.apiVersion === 'preview' || provider.apiVersion === 'v1' +} + +/** + * Check if provider is Cherry AI type + * Matches isCherryAIProvider in renderer/utils/provider.ts + */ +export function isCherryAIProvider
<P extends MinimalProvider>
(provider: P): boolean { + return provider.id === 'cherryai' +} + +/** + * Check if provider is Perplexity type + * Matches isPerplexityProvider in renderer/utils/provider.ts + */ +export function isPerplexityProvider
<P extends MinimalProvider>
(provider: P): boolean { + return provider.id === 'perplexity' +} + +/** + * Check if provider is new-api type (supports multiple backends) + * Matches isNewApiProvider in renderer/utils/provider.ts + */ +export function isNewApiProvider
<P extends MinimalProvider>
(provider: P): boolean { + return ['new-api', 'cherryin'].includes(provider.id) || provider.type === ('new-api' as string) +} + +/** + * Check if provider is OpenAI compatible + * Matches isOpenAICompatibleProvider in renderer/utils/provider.ts + */ +export function isOpenAICompatibleProvider
<P extends MinimalProvider>
(provider: P): boolean {
+  return ['openai', 'new-api', 'mistral'].includes(provider.type)
+}
diff --git a/packages/shared/provider/format.ts b/packages/shared/provider/format.ts
new file mode 100644
index 0000000000..72e768d9b3
--- /dev/null
+++ b/packages/shared/provider/format.ts
@@ -0,0 +1,136 @@
+/**
+ * Provider API Host Formatting
+ *
+ * Utilities for formatting provider API hosts to work with AI SDK.
+ * These handle the differences between how Cherry Studio stores API hosts
+ * and how AI SDK expects them.
+ */
+
+import {
+  formatApiHost,
+  formatAzureOpenAIApiHost,
+  formatVertexApiHost,
+  routeToEndpoint,
+  withoutTrailingSlash
+} from '../api'
+import {
+  isAnthropicProvider,
+  isAzureOpenAIProvider,
+  isCherryAIProvider,
+  isGeminiProvider,
+  isPerplexityProvider,
+  isVertexProvider
+} from './detection'
+import type { MinimalProvider } from './types'
+import { SystemProviderIds } from './types'
+
+/**
+ * Interface for environment-specific implementations
+ * Renderer and Main process can provide their own implementations
+ */
+export interface ProviderFormatContext {
+  vertex: {
+    project: string
+    location: string
+  }
+}
+
+/**
+ * Default Azure OpenAI API host formatter
+ */
+export function defaultFormatAzureOpenAIApiHost(host: string): string {
+  const normalizedHost = withoutTrailingSlash(host)
+    ?.replace(/\/v1$/, '')
+    .replace(/\/openai$/, '')
+  // AI SDK will add /v1
+  return formatApiHost(normalizedHost + '/openai', false)
+}
+
+/**
+ * Format provider API host for AI SDK
+ *
+ * This function normalizes the apiHost to work with AI SDK.
+ * Different providers have different requirements:
+ * - Most providers: add /v1 suffix
+ * - Gemini: add /v1beta suffix
+ * - Some providers: no suffix needed
+ *
+ * @param provider - The provider to format
+ * @param context - Optional context with environment-specific implementations
+ * @returns Provider with formatted apiHost (and anthropicApiHost if applicable)
+ */
+export function formatProviderApiHost<T extends MinimalProvider>(provider: T, context: ProviderFormatContext): T {
+  const formatted = { ...provider }
+
+  // Format anthropicApiHost if present
+  if (formatted.anthropicApiHost) {
+    formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
+  }
+
+  // Format based on provider type
+  if (isAnthropicProvider(provider)) {
+    const baseHost = formatted.anthropicApiHost || formatted.apiHost
+    // AI SDK needs /v1 in baseURL
+    formatted.apiHost = formatApiHost(baseHost)
+    if (!formatted.anthropicApiHost) {
+      formatted.anthropicApiHost = formatted.apiHost
+    }
+  } else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, false)
+  } else if (isGeminiProvider(formatted)) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
+  } else if (isAzureOpenAIProvider(formatted)) {
+    formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
+  } else if (isVertexProvider(formatted)) {
+    formatted.apiHost = formatVertexApiHost(formatted, context.vertex.project, context.vertex.location)
+  } else if (isCherryAIProvider(formatted)) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, false)
+  } else if (isPerplexityProvider(formatted)) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, false)
+  } else {
+    formatted.apiHost = formatApiHost(formatted.apiHost)
+  }
+
+  return formatted
+}
+
+/**
+ * Get the base URL for AI SDK from a formatted provider
+ *
+ * This extracts the baseURL that AI SDK expects, handling
+ * the '#' 
endpoint routing format if present. + * + * @param formattedApiHost - The formatted apiHost (after formatProviderApiHost) + * @returns The baseURL for AI SDK + */ +export function getBaseUrlForAiSdk(formattedApiHost: string): string { + const { baseURL } = routeToEndpoint(formattedApiHost) + return baseURL +} + +/** + * Get rotated API key from comma-separated keys + * + * This is the interface for API key rotation. The actual implementation + * depends on the environment (renderer uses window.keyv, main uses its own storage). + */ +export interface ApiKeyRotator { + /** + * Get the next API key in rotation + * @param providerId - The provider ID for tracking rotation + * @param keys - Comma-separated API keys + * @returns The next API key to use + */ + getRotatedKey(providerId: string, keys: string): string +} + +/** + * Simple API key rotator that always returns the first key + * Use this when rotation is not needed + */ +export const simpleKeyRotator: ApiKeyRotator = { + getRotatedKey(_providerId: string, keys: string): string { + const keyList = keys.split(',').map((k) => k.trim()) + return keyList[0] || keys + } +} diff --git a/packages/shared/provider/index.ts b/packages/shared/provider/index.ts new file mode 100644 index 0000000000..f0b9b11d10 --- /dev/null +++ b/packages/shared/provider/index.ts @@ -0,0 +1,48 @@ +/** + * Shared Provider Utilities + * + * This module exports utilities for working with AI providers + * that can be shared between main process and renderer process. + */ + +// Type definitions +export type { MinimalProvider, ProviderType, SystemProviderId } from './types' +export { SystemProviderIds } from './types' + +// Provider type detection +export { + isAIGatewayProvider, + isAnthropicProvider, + isAwsBedrockProvider, + isAzureOpenAIProvider, + isAzureResponsesEndpoint, + isCherryAIProvider, + isGeminiProvider, + isNewApiProvider, + isOpenAICompatibleProvider, + isOpenAIProvider, + isPerplexityProvider, + isVertexProvider +} from './detection' + +// API host formatting +export type { ApiKeyRotator, ProviderFormatContext } from './format' +export { + defaultFormatAzureOpenAIApiHost, + formatProviderApiHost, + getBaseUrlForAiSdk, + simpleKeyRotator +} from './format' + +// Provider ID mapping +export { getAiSdkProviderId, STATIC_PROVIDER_MAPPING, tryResolveProviderId } from './mapping' + +// AI SDK configuration +export type { AiSdkConfig, AiSdkConfigContext } from './sdk-config' +export { providerToAiSdkConfig } from './sdk-config' + +// Provider resolution +export { resolveActualProvider } from './resolve' + +// Provider initialization +export { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from './initialization' diff --git a/packages/shared/provider/initialization.ts b/packages/shared/provider/initialization.ts new file mode 100644 index 0000000000..fbb5fba54f --- /dev/null +++ b/packages/shared/provider/initialization.ts @@ -0,0 +1,107 @@ +import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider' + +type ProviderInitializationLogger = { + warn?: (message: string) => void + error?: (message: string, error: Error) => void +} + +export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [ + { + id: 'openrouter', + name: 'OpenRouter', + import: () => import('@openrouter/ai-sdk-provider'), + creatorFunctionName: 'createOpenRouter', + supportsImageGeneration: true, + aliases: ['openrouter'] + }, + { + id: 'google-vertex', + name: 'Google Vertex AI', + import: () => import('@ai-sdk/google-vertex/edge'), + 
creatorFunctionName: 'createVertex', + supportsImageGeneration: true, + aliases: ['vertexai'] + }, + { + id: 'google-vertex-anthropic', + name: 'Google Vertex AI Anthropic', + import: () => import('@ai-sdk/google-vertex/anthropic/edge'), + creatorFunctionName: 'createVertexAnthropic', + supportsImageGeneration: true, + aliases: ['vertexai-anthropic'] + }, + { + id: 'azure-anthropic', + name: 'Azure AI Anthropic', + import: () => import('@ai-sdk/anthropic'), + creatorFunctionName: 'createAnthropic', + supportsImageGeneration: false, + aliases: ['azure-anthropic'] + }, + { + id: 'github-copilot-openai-compatible', + name: 'GitHub Copilot OpenAI Compatible', + import: () => import('@opeoginni/github-copilot-openai-compatible'), + creatorFunctionName: 'createGitHubCopilotOpenAICompatible', + supportsImageGeneration: false, + aliases: ['copilot', 'github-copilot'] + }, + { + id: 'bedrock', + name: 'Amazon Bedrock', + import: () => import('@ai-sdk/amazon-bedrock'), + creatorFunctionName: 'createAmazonBedrock', + supportsImageGeneration: true, + aliases: ['aws-bedrock'] + }, + { + id: 'perplexity', + name: 'Perplexity', + import: () => import('@ai-sdk/perplexity'), + creatorFunctionName: 'createPerplexity', + supportsImageGeneration: false, + aliases: ['perplexity'] + }, + { + id: 'mistral', + name: 'Mistral', + import: () => import('@ai-sdk/mistral'), + creatorFunctionName: 'createMistral', + supportsImageGeneration: false, + aliases: ['mistral'] + }, + { + id: 'huggingface', + name: 'HuggingFace', + import: () => import('@ai-sdk/huggingface'), + creatorFunctionName: 'createHuggingFace', + supportsImageGeneration: true, + aliases: ['hf', 'hugging-face'] + }, + { + id: 'ai-gateway', + name: 'AI Gateway', + import: () => import('@ai-sdk/gateway'), + creatorFunctionName: 'createGateway', + supportsImageGeneration: true, + aliases: ['gateway'] + }, + { + id: 'cerebras', + name: 'Cerebras', + import: () => import('@ai-sdk/cerebras'), + creatorFunctionName: 'createCerebras', + supportsImageGeneration: false + } +] as const + +export function initializeSharedProviders(logger?: ProviderInitializationLogger): void { + try { + const successCount = registerMultipleProviderConfigs(SHARED_PROVIDER_CONFIGS) + if (successCount < SHARED_PROVIDER_CONFIGS.length) { + logger?.warn?.('Some providers failed to register. Check previous error logs.') + } + } catch (error) { + logger?.error?.('Failed to initialize shared providers', error as Error) + } +} diff --git a/packages/shared/provider/mapping.ts b/packages/shared/provider/mapping.ts new file mode 100644 index 0000000000..20e2e10c3f --- /dev/null +++ b/packages/shared/provider/mapping.ts @@ -0,0 +1,95 @@ +/** + * Provider ID Mapping + * + * Maps Cherry Studio provider IDs/types to AI SDK provider IDs. 
+ * This logic should match @renderer/aiCore/provider/factory.ts
+ */
+
+import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
+
+import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from './detection'
+import type { MinimalProvider } from './types'
+
+/**
+ * Static mapping from Cherry Studio provider ID/type to AI SDK provider ID
+ * Matches STATIC_PROVIDER_MAPPING in @renderer/aiCore/provider/factory.ts
+ */
+export const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
+  gemini: 'google', // Google Gemini -> google
+  'azure-openai': 'azure', // Azure OpenAI -> azure
+  'openai-response': 'openai', // OpenAI Responses -> openai
+  grok: 'xai', // Grok -> xai
+  copilot: 'github-copilot-openai-compatible'
+}
+
+/**
+ * Try to resolve a provider identifier to an AI SDK provider ID
+ * Matches tryResolveProviderId in @renderer/aiCore/provider/factory.ts
+ *
+ * @param identifier - The provider ID or type to resolve
+ * @returns The resolved AI SDK provider ID, or null if not found
+ */
+export function tryResolveProviderId(identifier: string): ProviderId | null {
+  // 1. Check the static mapping
+  const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
+  if (staticMapping) {
+    return staticMapping
+  }
+
+  // 2. Check whether AiCore supports it (including alias support)
+  if (hasProviderConfigByAlias(identifier)) {
+    // Resolve to the real provider ID
+    return resolveProviderConfigId(identifier) as ProviderId
+  }
+
+  return null
+}
+
+/**
+ * Get the AI SDK Provider ID for a Cherry Studio provider
+ * Matches getAiSdkProviderId in @renderer/aiCore/provider/factory.ts
+ *
+ * Logic:
+ * 1. Handle Azure OpenAI specially (check responses endpoint)
+ * 2. Try to resolve from provider.id
+ * 3. Try to resolve from provider.type (but not for generic 'openai' type)
+ * 4. Check for OpenAI API host pattern
+ * 5. Fallback to provider's own ID
+ *
+ * @param provider - The Cherry Studio provider
+ * @returns The AI SDK provider ID to use
+ */
+export function getAiSdkProviderId(provider: MinimalProvider): ProviderId {
+  // 1. Handle Azure OpenAI specially - check this FIRST before other resolution
+  if (isAzureOpenAIProvider(provider)) {
+    if (isAzureResponsesEndpoint(provider)) {
+      return 'azure-responses'
+    }
+    return 'azure'
+  }
+
+  // 2. Try to resolve from provider.id
+  const resolvedFromId = tryResolveProviderId(provider.id)
+  if (resolvedFromId) {
+    return resolvedFromId
+  }
+
+  // 3. Try to resolve from provider.type
+  // This would resolve every custom provider of type 'openai' to the AI SDK openaiProvider
+  if (provider.type !== 'openai') {
+    const resolvedFromType = tryResolveProviderId(provider.type)
+    if (resolvedFromType) {
+      return resolvedFromType
+    }
+  }
+
+  // 4. Check for OpenAI API host pattern
+  if (provider.apiHost.includes('api.openai.com')) {
+    return 'openai-chat'
+  }
+
+  // 5. Final fallback (use the provider's own id)
+  return provider.id
+}
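+
+// Illustrative examples:
+//   getAiSdkProviderId({ id: 'gemini', type: 'gemini', ... })  // 'google' (static mapping)
+//   getAiSdkProviderId({ id: 'my-proxy', type: 'openai', apiHost: 'https://api.openai.com/v1', ... })
+//   // -> 'openai-chat' (falls through to the API-host check)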
diff --git a/packages/shared/provider/resolve.ts b/packages/shared/provider/resolve.ts
new file mode 100644
index 0000000000..9055a36c6e
--- /dev/null
+++ b/packages/shared/provider/resolve.ts
@@ -0,0 +1,44 @@
+import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
+import { azureAnthropicProviderCreator } from './config/azure-anthropic'
+import { isAzureOpenAIProvider, isNewApiProvider } from './detection'
+import type { MinimalModel, MinimalProvider } from './types'
+
+export interface ResolveActualProviderOptions<P extends MinimalProvider>
{ + isSystemProvider?: (provider: P) => boolean +} + +const defaultIsSystemProvider =
<P extends MinimalProvider>
(provider: P): boolean => {
+  if ('isSystem' in provider) {
+    return Boolean((provider as unknown as { isSystem?: boolean }).isSystem)
+  }
+  return false
+}
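+
+// e.g. (illustrative): for a system 'aihubmix' provider and a 'claude-*' model,
+// resolveActualProvider (below) returns an anthropic-typed copy of the provider;
+// 'new-api' providers are resolved by the model's endpoint_type first.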
+
+export function resolveActualProvider<P extends MinimalProvider, M extends MinimalModel>(
+  provider: P,
+  model: M,
+  options: ResolveActualProviderOptions<P> = {}
+): P {
+  let resolvedProvider = provider
+
+  if (isNewApiProvider(resolvedProvider)) {
+    resolvedProvider = newApiResolverCreator(model, resolvedProvider)
+  }
+
+  const isSystemProvider =
+    options.isSystemProvider?.(resolvedProvider) ?? defaultIsSystemProvider(resolvedProvider)
+
+  if (isSystemProvider && resolvedProvider.id === 'aihubmix') {
+    resolvedProvider = aihubmixProviderCreator(model, resolvedProvider)
+  }
+
+  if (isSystemProvider && resolvedProvider.id === 'vertexai') {
+    resolvedProvider = vertexAnthropicProviderCreator(model, resolvedProvider)
+  }
+
+  if (isAzureOpenAIProvider(resolvedProvider)) {
+    resolvedProvider = azureAnthropicProviderCreator(model, resolvedProvider)
+  }
+
+  return resolvedProvider
+}
diff --git a/packages/shared/provider/sdk-config.ts b/packages/shared/provider/sdk-config.ts
new file mode 100644
index 0000000000..a03b3b1417
--- /dev/null
+++ b/packages/shared/provider/sdk-config.ts
@@ -0,0 +1,240 @@
+/**
+ * AI SDK Configuration
+ *
+ * Shared utilities for converting Cherry Studio Provider to AI SDK configuration.
+ * Environment-specific logic (renderer/main) is injected via context interfaces.
+ */
+
+import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
+
+import { routeToEndpoint } from '../api'
+import { getAiSdkProviderId } from './mapping'
+import type { MinimalProvider } from './types'
+import { SystemProviderIds } from './types'
+
+/**
+ * AI SDK configuration result
+ */
+export interface AiSdkConfig {
+  providerId: string
+  options: Record<string, unknown>
+}
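+
+// e.g. (illustrative): for a plain Anthropic provider this ends up roughly as
+// { providerId: 'anthropic', options: { baseURL: 'https://api.anthropic.com/v1', apiKey: '...' } }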
+
+/**
+ * Context for environment-specific implementations
+ */
+export interface AiSdkConfigContext {
+  /**
+   * Get the rotated API key (for multi-key support)
+   * Default: returns first key
+   */
+  getRotatedApiKey?: (provider: MinimalProvider) => string
+
+  /**
+   * Check if a model uses chat completion only (for OpenAI response mode)
+   * Default: returns false
+   */
+  isOpenAIChatCompletionOnlyModel?: (modelId: string) => boolean
+
+  /**
+   * Get Copilot default headers (constants)
+   * Default: returns empty object
+   */
+  getCopilotDefaultHeaders?: () => Record<string, string>
+
+  /**
+   * Get Copilot stored headers from state
+   * Default: returns empty object
+   */
+  getCopilotStoredHeaders?: () => Record<string, string>
+
+  /**
+   * Get AWS Bedrock configuration
+   * Default: returns undefined (not configured)
+   */
+  getAwsBedrockConfig?: () =>
+    | {
+        authType: 'apiKey' | 'iam'
+        region: string
+        apiKey?: string
+        accessKeyId?: string
+        secretAccessKey?: string
+      }
+    | undefined
+
+  /**
+   * Get Vertex AI configuration
+   * Default: returns undefined (not configured)
+   */
+  getVertexConfig?: (provider: MinimalProvider) =>
+    | {
+        project: string
+        location: string
+        googleCredentials: {
+          privateKey: string
+          clientEmail: string
+        }
+      }
+    | undefined
+
+  /**
+   * Get endpoint type for cherryin provider
+   */
+  getEndpointType?: (modelId: string) => string | undefined
+
+  /**
+   * Custom fetch implementation
+   * Main process: use Electron net.fetch
+   * Renderer process: use browser fetch (default)
+   */
+  fetch?: typeof globalThis.fetch
+}
+
+/**
+ * Default simple key rotator - returns first key
+ */
+function defaultGetRotatedApiKey(provider: MinimalProvider): string {
+  const keys = provider.apiKey.split(',').map((k) => k.trim())
+  return keys[0] || provider.apiKey
+}
+
+/**
+ * Convert Cherry Studio Provider to AI SDK configuration
+ *
+ * @param provider - The formatted provider (after formatProviderApiHost)
+ * @param modelId - The model ID to use
+ * @param context - Environment-specific implementations
+ * @returns AI SDK configuration
+ */
+export function providerToAiSdkConfig(
+  provider: MinimalProvider,
+  modelId: string,
+  context: AiSdkConfigContext = {}
+): AiSdkConfig {
+  const getRotatedApiKey = context.getRotatedApiKey || defaultGetRotatedApiKey
+  const isOpenAIChatCompletionOnlyModel = context.isOpenAIChatCompletionOnlyModel || (() => false)
+
+  const aiSdkProviderId = getAiSdkProviderId(provider)
+
+  // Build base config
+  const { baseURL, endpoint } = routeToEndpoint(provider.apiHost)
+  const baseConfig = {
+    baseURL,
+    apiKey: getRotatedApiKey(provider)
+  }
+
+  // Handle Copilot specially
+  if (provider.id === SystemProviderIds.copilot) {
+    const defaultHeaders = context.getCopilotDefaultHeaders?.() ?? {}
+    const storedHeaders = context.getCopilotStoredHeaders?.() ?? {}
+    const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
+      headers: {
+        ...defaultHeaders,
+        ...storedHeaders,
+        ...provider.extra_headers
+      },
+      name: provider.id,
+      includeUsage: true
+    })
+
+    return {
+      providerId: 'github-copilot-openai-compatible',
+      options
+    }
+  }
+
+  // Build extra options
+  const extraOptions: Record<string, unknown> = {}
+  if (endpoint) {
+    extraOptions.endpoint = endpoint
+  }
+
+  // Handle OpenAI mode
+  if (provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(modelId)) {
+    extraOptions.mode = 'responses'
+  } else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && provider.type === 'openai')) {
+    extraOptions.mode = 'chat'
+  }
+
+  // Add extra headers
+  if (provider.extra_headers) {
+    extraOptions.headers = provider.extra_headers
+    if (aiSdkProviderId === 'openai') {
+      extraOptions.headers = {
+        ...(extraOptions.headers as Record<string, string>),
+        'HTTP-Referer': 'https://cherry-ai.com',
+        'X-Title': 'Cherry Studio',
+        'X-Api-Key': baseConfig.apiKey
+      }
+    }
+  }
+
+  // Handle Azure modes
+  if (aiSdkProviderId === 'azure-responses') {
+    extraOptions.mode = 'responses'
+  } else if (aiSdkProviderId === 'azure') {
+    extraOptions.mode = 'chat'
+  }
+
+  // Handle AWS Bedrock
+  if (aiSdkProviderId === 'bedrock') {
+    const bedrockConfig = context.getAwsBedrockConfig?.()
+    if (bedrockConfig) {
+      extraOptions.region = bedrockConfig.region
+      if (bedrockConfig.authType === 'apiKey') {
+        extraOptions.apiKey = bedrockConfig.apiKey
+      } else {
+        extraOptions.accessKeyId = bedrockConfig.accessKeyId
+        extraOptions.secretAccessKey = bedrockConfig.secretAccessKey
+      }
+    }
+  }
+
+  // Handle Vertex AI
+  if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
+    const vertexConfig = context.getVertexConfig?.(provider)
+    if (vertexConfig) {
+      extraOptions.project = vertexConfig.project
+      extraOptions.location = vertexConfig.location
+      extraOptions.googleCredentials = {
+        ...vertexConfig.googleCredentials,
+        privateKey: formatPrivateKey(vertexConfig.googleCredentials.privateKey)
+      }
+      baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
+    }
+  }
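+
+  // e.g. (illustrative): for google-vertex the final baseURL looks like
+  // 'https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/google'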
+
+  // Handle cherryin endpoint type
+  if (aiSdkProviderId === 'cherryin') {
+    const endpointType = context.getEndpointType?.(modelId)
+    if (endpointType) {
+      extraOptions.endpointType = endpointType
+    }
+  }
+
+  // Inject custom fetch if provided
+  if (context.fetch) {
+    extraOptions.fetch = context.fetch
+  }
+
+  // Check if AI SDK supports this provider natively
+  if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
+    const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
+    return {
+      providerId: aiSdkProviderId,
+      options
+    }
+  }
+
+  // Fallback to openai-compatible
+  const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
+  return {
+    providerId: 'openai-compatible',
+    options: {
+      ...options,
+      name: provider.id,
+      ...extraOptions,
+      includeUsage: true
+    }
+  }
+}
diff --git a/packages/shared/provider/types.ts b/packages/shared/provider/types.ts
new file mode 100644
index 0000000000..b9745f9d3a
--- /dev/null
+++ b/packages/shared/provider/types.ts
@@ -0,0 +1,174 @@
+import * as z from 'zod'
+
+export const ProviderTypeSchema = z.enum([
+  'openai',
+  'openai-response',
+  'anthropic',
+  'gemini',
+  'azure-openai',
+  'vertexai',
+  'mistral',
+  'aws-bedrock',
+  'vertex-anthropic',
+  'new-api',
+  'ai-gateway'
+])
+
+export type ProviderType = z.infer<typeof ProviderTypeSchema>
+
+/**
+ * Minimal provider interface for shared utilities
+ * This is the subset of Provider that shared code needs
+ */
+export type MinimalProvider = {
+  id: string
+  type: ProviderType
+  apiKey: string
+  apiHost: string
+  anthropicApiHost?: string
+  apiVersion?: string
+  extra_headers?: Record<string, string>
+}
+
+/**
+ * Minimal model interface for shared utilities
+ * This is the subset of Model that shared code needs
+ */
+export type MinimalModel = {
+  id: string
+  endpoint_type?: string
+}
+
+export const SystemProviderIdSchema = z.enum([
+  'cherryin',
+  'silicon',
+  'aihubmix',
+  'ocoolai',
+  'deepseek',
+  'ppio',
+  'alayanew',
+  'qiniu',
+  'dmxapi',
+  'burncloud',
+  'tokenflux',
+  '302ai',
+  'cephalon',
+  'lanyun',
+  'ph8',
+  'openrouter',
+  'ollama',
+  'ovms',
+  'new-api',
+  'lmstudio',
+  'anthropic',
+  'openai',
+  'azure-openai',
+  'gemini',
+  'vertexai',
+  'github',
+  'copilot',
+  'zhipu',
+  'yi',
+  'moonshot',
+  'baichuan',
+  'dashscope',
+  'stepfun',
+  'doubao',
+  'infini',
+  'minimax',
+  'groq',
+  'together',
+  'fireworks',
+  'nvidia',
+  'grok',
+  'hyperbolic',
+  'mistral',
+  'jina',
+  'perplexity',
+  'modelscope',
+  'xirang',
+  'hunyuan',
+  'tencent-cloud-ti',
+  'baidu-cloud',
+  'gpustack',
+  'voyageai',
+  'aws-bedrock',
+  'poe',
+  'aionly',
+  'longcat',
+  'huggingface',
+  'sophnet',
+  'ai-gateway',
+  'cerebras'
+])
+
+export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
+
+export const isSystemProviderId = (id: string): id is SystemProviderId => {
+  return SystemProviderIdSchema.safeParse(id).success
+}
+
+export const SystemProviderIds = {
+  cherryin: 'cherryin',
+  silicon: 'silicon',
+  aihubmix: 'aihubmix',
+  ocoolai: 'ocoolai',
+  deepseek: 'deepseek',
+  ppio: 'ppio',
+  alayanew: 'alayanew',
+  qiniu: 'qiniu',
+  dmxapi: 'dmxapi',
+  burncloud: 'burncloud',
+  tokenflux: 'tokenflux',
+  '302ai': '302ai',
+  cephalon: 'cephalon',
+  lanyun: 'lanyun',
+  ph8: 'ph8',
+  sophnet: 'sophnet',
+  openrouter: 'openrouter',
+  ollama: 'ollama',
+  ovms: 'ovms',
+  'new-api': 'new-api',
+  lmstudio: 'lmstudio',
+  anthropic: 'anthropic',
+  openai: 'openai',
+  'azure-openai': 'azure-openai',
+  gemini: 'gemini',
+  vertexai: 'vertexai',
+  github: 'github',
+  copilot: 'copilot',
+  zhipu: 'zhipu',
+  yi: 'yi',
+  moonshot: 'moonshot',
+  baichuan: 'baichuan',
+  dashscope: 'dashscope',
+  stepfun: 'stepfun',
+  doubao: 'doubao',
+  infini: 'infini',
+  minimax: 'minimax',
+  groq: 'groq',
+  together: 'together',
+  fireworks: 'fireworks',
+  nvidia: 'nvidia',
+  grok: 'grok',
+  hyperbolic: 'hyperbolic',
+  mistral: 'mistral',
+  jina: 'jina',
+  perplexity: 'perplexity',
+  modelscope: 'modelscope',
+  xirang: 'xirang',
+  hunyuan: 'hunyuan',
+  'tencent-cloud-ti': 'tencent-cloud-ti',
+  'baidu-cloud': 'baidu-cloud',
+  gpustack: 'gpustack',
+  voyageai: 'voyageai',
+  'aws-bedrock': 'aws-bedrock',
+  poe: 'poe',
+  aionly: 'aionly',
+  longcat: 'longcat',
+  huggingface: 'huggingface',
+  'ai-gateway': 'ai-gateway',
+  cerebras: 'cerebras'
+} as const satisfies Record<SystemProviderId, SystemProviderId>
+
+export type SystemProviderIdTypeMap = typeof SystemProviderIds
diff --git a/packages/shared/utils/index.ts b/packages/shared/utils/index.ts
new file mode 100644
index 0000000000..838c28e6c0
--- /dev/null
+++ b/packages/shared/utils/index.ts
@@ -0,0 +1 @@
+export { getBaseModelName, getLowerBaseModelName } from './naming'
diff --git a/packages/shared/utils/naming.ts b/packages/shared/utils/naming.ts
new file mode 100644
index 0000000000..a8b4f5501d
--- /dev/null
+++ b/packages/shared/utils/naming.ts
@@ -0,0 +1,31 @@
+/**
+ * Extracts the base name from a model ID.
+ * For example:
+ * - 'deepseek/deepseek-r1' => 'deepseek-r1'
+ * - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
+ * @param {string} id the model ID
+ * @param {string} [delimiter='/'] the delimiter, defaults to '/'
+ * @returns {string} the base name
+ */
+export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
+  const parts = id.split(delimiter)
+  return parts[parts.length - 1]
+}
+
+/**
+ * Extracts the base name from a model ID and converts it to lowercase.
+ * For example:
+ * - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
+ * - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
+ * @param {string} id the model ID
+ * @param {string} [delimiter='/'] the delimiter, defaults to '/'
+ * @returns {string} the lowercased base name
+ */
+export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
+  const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
+  // for openrouter
+  if (baseModelName.endsWith(':free')) {
+    return baseModelName.replace(':free', '')
+  }
+  return baseModelName
+}
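+
+// Illustrative examples:
+//   getBaseModelName('deepseek/deepseek-r1')           // 'deepseek-r1'
+//   getLowerBaseModelName('deepseek/DeepSeek-R1:free') // 'deepseek-r1'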
+  vertexai: 'vertexai',
+  github: 'github',
+  copilot: 'copilot',
+  zhipu: 'zhipu',
+  yi: 'yi',
+  moonshot: 'moonshot',
+  baichuan: 'baichuan',
+  dashscope: 'dashscope',
+  stepfun: 'stepfun',
+  doubao: 'doubao',
+  infini: 'infini',
+  minimax: 'minimax',
+  groq: 'groq',
+  together: 'together',
+  fireworks: 'fireworks',
+  nvidia: 'nvidia',
+  grok: 'grok',
+  hyperbolic: 'hyperbolic',
+  mistral: 'mistral',
+  jina: 'jina',
+  perplexity: 'perplexity',
+  modelscope: 'modelscope',
+  xirang: 'xirang',
+  hunyuan: 'hunyuan',
+  'tencent-cloud-ti': 'tencent-cloud-ti',
+  'baidu-cloud': 'baidu-cloud',
+  gpustack: 'gpustack',
+  voyageai: 'voyageai',
+  'aws-bedrock': 'aws-bedrock',
+  poe: 'poe',
+  aionly: 'aionly',
+  longcat: 'longcat',
+  huggingface: 'huggingface',
+  'ai-gateway': 'ai-gateway',
+  cerebras: 'cerebras'
+} as const satisfies Record<SystemProviderId, SystemProviderId>
+
+export type SystemProviderIdTypeMap = typeof SystemProviderIds
diff --git a/packages/shared/utils/index.ts b/packages/shared/utils/index.ts
new file mode 100644
index 0000000000..838c28e6c0
--- /dev/null
+++ b/packages/shared/utils/index.ts
@@ -0,0 +1 @@
+export { getBaseModelName, getLowerBaseModelName } from './naming'
diff --git a/packages/shared/utils/naming.ts b/packages/shared/utils/naming.ts
new file mode 100644
index 0000000000..a8b4f5501d
--- /dev/null
+++ b/packages/shared/utils/naming.ts
@@ -0,0 +1,31 @@
+/**
+ * Extract the base name from a model ID.
+ * For example:
+ * - 'deepseek/deepseek-r1' => 'deepseek-r1'
+ * - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
+ * @param {string} id Model ID
+ * @param {string} [delimiter='/'] Delimiter, defaults to '/'
+ * @returns {string} Base name
+ */
+export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
+  const parts = id.split(delimiter)
+  return parts[parts.length - 1]
+}
+
+/**
+ * Extract the base name from a model ID and convert it to lowercase.
+ * For example:
+ * - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
+ * - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
+ * @param {string} id Model ID
+ * @param {string} [delimiter='/'] Delimiter, defaults to '/'
+ * @returns {string} Lowercase base name
+ */
+export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
+  const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
+  // for openrouter
+  if (baseModelName.endsWith(':free')) {
+    return baseModelName.replace(':free', '')
+  }
+  return baseModelName
+}
diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts
index 02ce0544e8..1ce42c46ea 100644
--- a/src/main/apiServer/routes/messages.ts
+++ b/src/main/apiServer/routes/messages.ts
@@ -5,6 +5,7 @@ import type { Request, Response } from 'express'
 import express from 'express'
 
 import { messagesService } from '../services/messages'
+import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
 import { getProviderById, validateModelId } from '../utils'
 
 const logger = loggerService.withContext('ApiServerMessagesRoutes')
@@ -33,21 +34,35 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
 }
 
 interface HandleMessageProcessingOptions {
-  req: Request
   res: Response
   provider: Provider
   request: MessageCreateParams
   modelId?: string
 }
 
+/**
+ * Handle message processing using unified AI SDK
+ * All providers (including Anthropic) are handled through AI SDK:
+ * - Anthropic providers use @ai-sdk/anthropic which outputs native Anthropic SSE
+ * - Other providers use their respective AI SDK adapters, with output converted to Anthropic SSE
+ */
 async function handleMessageProcessing({
-  req,
   res,
   provider,
   request,
   modelId
 }: HandleMessageProcessingOptions): Promise<void> {
+  const actualModelId = modelId || request.model
+
+  logger.info('Processing message via unified AI SDK', {
+    providerId: provider.id,
+    providerType: provider.type,
+    modelId: actualModelId,
+    stream: !!request.stream
+  })
+
   try {
+    // Validate request
     const validation = messagesService.validateRequest(request)
     if (!validation.isValid) {
       res.status(400).json({
@@ -60,21 +75,23 @@ async function handleMessageProcessing({
       return
     }
 
-    const extraHeaders = messagesService.prepareHeaders(req.headers)
-    const { client, anthropicRequest } = await messagesService.processMessage({
-      provider,
-      request,
-      extraHeaders,
-      modelId
-    })
-
     if (request.stream) {
-      await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider)
-      return
+      await streamUnifiedMessages({
+        response: res,
+        provider,
+        modelId: actualModelId,
+        params: request,
+        onError: (error) => {
+          logger.error('Stream error', error as Error)
+        },
+        onComplete: () => {
+          logger.debug('Stream completed')
+        }
+      })
+    } else {
+      const response = await generateUnifiedMessage(provider, actualModelId, request)
+      res.json(response)
     }
-
-    const response = await client.messages.create(anthropicRequest)
-    res.json(response)
   } catch (error: any) {
     logger.error('Message processing error', { error })
     const { statusCode, errorResponse } = messagesService.transformError(error)
@@ -235,7 +252,7 @@ router.post('/', async (req: Request, res: Response) => {
     const provider = modelValidation.provider!
     const modelId = modelValidation.modelId!
 
-    return handleMessageProcessing({ req, res, provider, request, modelId })
+    return handleMessageProcessing({ res, provider, request, modelId })
   } catch (error: any) {
     logger.error('Message processing error', { error })
     const { statusCode, errorResponse } = messagesService.transformError(error)
@@ -393,7 +410,7 @@ providerRouter.post('/', async (req: Request, res: Response) => {
 
     const request: MessageCreateParams = req.body
 
-    return handleMessageProcessing({ req, res, provider, request })
+    return handleMessageProcessing({ res, provider, request })
   } catch (error: any) {
     logger.error('Message processing error', { error })
     const { statusCode, errorResponse } = messagesService.transformError(error)
@@ -401,4 +418,194 @@ providerRouter.post('/', async (req: Request, res: Response) => {
   }
 })
 
+/**
+ * @swagger
+ * /v1/messages/count_tokens:
+ *   post:
+ *     summary: Count tokens for messages
+ *     description: Count tokens for Anthropic Messages API format (required by Claude Code SDK)
+ *     tags: [Messages]
+ *     requestBody:
+ *       required: true
+ *       content:
+ *         application/json:
+ *           schema:
+ *             type: object
+ *             required:
+ *               - model
+ *               - messages
+ *             properties:
+ *               model:
+ *                 type: string
+ *                 description: Model ID
+ *               messages:
+ *                 type: array
+ *                 items:
+ *                   type: object
+ *               system:
+ *                 type: string
+ *                 description: System message
+ *     responses:
+ *       200:
+ *         description: Token count response
+ *         content:
+ *           application/json:
+ *             schema:
+ *               type: object
+ *               properties:
+ *                 input_tokens:
+ *                   type: integer
+ *       400:
+ *         description: Bad request
+ */
+router.post('/count_tokens', async (req: Request, res: Response) => {
+  try {
+    const { model, messages, system } = req.body
+
+    if (!model) {
+      return res.status(400).json({
+        type: 'error',
+        error: {
+          type: 'invalid_request_error',
+          message: 'model parameter is required'
+        }
+      })
+    }
+
+    if (!messages || !Array.isArray(messages)) {
+      return res.status(400).json({
+        type: 'error',
+        error: {
+          type: 'invalid_request_error',
+          message: 'messages parameter is required'
+        }
+      })
+    }
+
+    // Simple token estimation based on character count
+    // This is a rough approximation: ~4 characters per token for English text
+    let totalChars = 0
+
+    // Count system message tokens
+    if (system) {
+      if (typeof system === 'string') {
+        totalChars += system.length
+      } else if (Array.isArray(system)) {
+        for (const block of system) {
+          if (block.type === 'text' && block.text) {
+            totalChars += block.text.length
+          }
+        }
+      }
+    }
+
+    // Count message tokens
+    for (const msg of messages) {
+      if (typeof msg.content === 'string') {
+        totalChars += msg.content.length
+      } else if (Array.isArray(msg.content)) {
+        for (const block of msg.content) {
+          if (block.type === 'text' && block.text) {
+            totalChars += block.text.length
+          }
+        }
+      }
+      // Add overhead for role
+      totalChars += 10
+    }
+
+    // Estimate tokens (~4 chars per token, with some overhead)
+    const estimatedTokens = Math.ceil(totalChars / 4) + messages.length * 3
+
+    logger.debug('Token count estimated', {
+      model,
+      messageCount: messages.length,
+      totalChars,
+      estimatedTokens
+    })
+
+    return res.json({
+      input_tokens: estimatedTokens
+    })
+  } catch (error: any) {
+    logger.error('Token counting error', { error })
+    return res.status(500).json({
+      type: 'error',
+      error: {
+        type: 'api_error',
+        message: error.message || 'Internal server error'
+      }
+    })
+  }
+})
+
+/**
+ * Provider-specific count_tokens endpoint
+ */
+providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
+  try {
+    const { model, messages, system } = req.body
+
+    if (!messages || !Array.isArray(messages)) {
+      return res.status(400).json({
+        type: 'error',
+        error: {
+          type: 'invalid_request_error',
+          message: 'messages parameter is required'
+        }
+      })
+    }
+
+    // Simple token estimation
+    let totalChars = 0
+
+    if (system) {
+      if (typeof system === 'string') {
+        totalChars += system.length
+      } else if (Array.isArray(system)) {
+        for (const block of system) {
+          if (block.type === 'text' && block.text) {
+            totalChars += block.text.length
+          }
+        }
+      }
+    }
+
+    for (const msg of messages) {
+      if (typeof msg.content === 'string') {
+        totalChars += msg.content.length
+      } else if (Array.isArray(msg.content)) {
+        for (const block of msg.content) {
+          if (block.type === 'text' && block.text) {
+            totalChars += block.text.length
+          }
+        }
+      }
+      totalChars += 10
+    }
+
+    const estimatedTokens = Math.ceil(totalChars / 4) + messages.length * 3
+
+    logger.debug('Token count estimated (provider route)', {
+      providerId: req.params.provider,
+      model,
+      messageCount: messages.length,
+      estimatedTokens
+    })
+
+    return res.json({
+      input_tokens: estimatedTokens
+    })
+  } catch (error: any) {
+    logger.error('Token counting error', { error })
+    return res.status(500).json({
+      type: 'error',
+      error: {
+        type: 'api_error',
+        message: error.message || 'Internal server error'
+      }
+    })
+  }
+})
+
 export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
diff --git a/src/main/apiServer/services/models.ts b/src/main/apiServer/services/models.ts
index 52f0db857f..b72c21b1e1 100644
--- a/src/main/apiServer/services/models.ts
+++ b/src/main/apiServer/services/models.ts
@@ -1,13 +1,6 @@
-import { isEmpty } from 'lodash'
-
 import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
 import { loggerService } from '../../services/LoggerService'
-import {
-  getAvailableProviders,
-  getProviderAnthropicModelChecker,
-  listAllAvailableModels,
-  transformModelToOpenAI
-} from '../utils'
+import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
 
 const logger = loggerService.withContext('ModelsService')
 
@@ -20,11 +13,12 @@ export class ModelsService {
     try {
       logger.debug('Getting available models from providers', { filter })
 
-      let providers = await getAvailableProviders()
+      const providers = await getAvailableProviders()
 
-      if (filter.providerType === 'anthropic') {
-        providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim()))
-      }
+      // Note: When providerType === 'anthropic', we now return ALL available models
+      // because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert
+      // any provider's response to Anthropic SSE format. This enables Claude Code Agent
+      // to work with OpenAI, Gemini, and other providers transparently.
 
       const models = await listAllAvailableModels(providers)
       // Use Map to deduplicate models by their full ID (provider:model_id)
@@ -32,20 +26,11 @@
       for (const model of models) {
         const provider = providers.find((p) => p.id === model.provider)
-        // logger.debug(`Processing model ${model.id}`)
         if (!provider) {
           logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
           continue
         }
 
-        if (filter.providerType === 'anthropic') {
-          const checker = getProviderAnthropicModelChecker(provider.id)
-          if (!checker(model)) {
-            logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`)
-            continue
-          }
-        }
-
         const openAIModel = transformModelToOpenAI(model, provider)
         const fullModelId = openAIModel.id // This is already in format "provider:model_id"
diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
new file mode 100644
index 0000000000..d0acd70231
--- /dev/null
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -0,0 +1,455 @@
+import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
+import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
+import type { ImageBlockParam, MessageCreateParams, TextBlockParam } from '@anthropic-ai/sdk/resources/messages'
+import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
+import { loggerService } from '@logger'
+import { reduxService } from '@main/services/ReduxService'
+import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
+import {
+  type AiSdkConfig,
+  type AiSdkConfigContext,
+  formatProviderApiHost,
+  initializeSharedProviders,
+  type ProviderFormatContext,
+  providerToAiSdkConfig as sharedProviderToAiSdkConfig,
+  resolveActualProvider
+} from '@shared/provider'
+import { defaultAppHeaders } from '@shared/utils'
+import type { Provider } from '@types'
+import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart } from 'ai'
+import { stepCountIs, streamText } from 'ai'
+import { net } from 'electron'
+import type { Response } from 'express'
+
+const logger = loggerService.withContext('UnifiedMessagesService')
+
+initializeSharedProviders({
+  warn: (message) => logger.warn(message),
+  error: (message, error) => logger.error(message, error)
+})
+
+export interface UnifiedStreamConfig {
+  response: Response
+  provider: Provider
+  modelId: string
+  params: MessageCreateParams
+  onError?: (error: unknown) => void
+  onComplete?: () => void
+}
+
+// ============================================================================
+// Provider Factory
+// ============================================================================
+
+/**
+ * Main process format context for formatProviderApiHost
+ * Unlike renderer, main process doesn't have direct access to store getters, so use reduxService cache
+ */
+function getMainProcessFormatContext(): ProviderFormatContext {
+  const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
+  return {
+    vertex: {
+      project: vertexSettings?.projectId || 'default-project',
+      location: vertexSettings?.location || 'us-central1'
+    }
+  }
+}
+
+/**
+ * Main process context for providerToAiSdkConfig
+ * Main process doesn't have access to browser APIs like window.keyv
+ */
+const mainProcessSdkContext: AiSdkConfigContext = {
+  // Simple key rotation - just return first key (no persistent rotation in main process)
+  getRotatedApiKey: (provider) => {
+    const keys = provider.apiKey.split(',').map((k) => k.trim())
+    return keys[0] || provider.apiKey
+  },
+  fetch: net.fetch as typeof globalThis.fetch
+}
+
+/**
+ * Get actual provider configuration for a model
+ *
+ * For aggregated providers (new-api, aihubmix, vertexai, azure-openai),
+ * this resolves the actual provider type based on the model's characteristics.
+ */
+function getActualProvider(provider: Provider, modelId: string): Provider {
+  // Find the model in provider's models list
+  const model = provider.models?.find((m) => m.id === modelId)
+  if (!model) {
+    // If model not found, return provider as-is
+    return provider
+  }
+
+  // Resolve actual provider based on model
+  return resolveActualProvider(provider, model)
+}
+
+/**
+ * Convert Cherry Studio Provider to AI SDK config
+ * Uses shared implementation with main process context
+ */
+function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
+  // First resolve actual provider for aggregated providers
+  const actualProvider = getActualProvider(provider, modelId)
+
+  // Format the provider's apiHost for AI SDK
+  const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
+
+  // Use shared implementation
+  return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
+}
+
+/**
+ * Create an AI SDK provider from Cherry Studio provider configuration
+ */
+async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
+  try {
+    const provider = await createProviderCore(config.providerId, config.options)
+    logger.debug('AI SDK provider created', {
+      providerId: config.providerId,
+      hasOptions: !!config.options
+    })
+    return provider
+  } catch (error) {
+    logger.error('Failed to create AI SDK provider', error as Error, {
+      providerId: config.providerId
+    })
+    throw error
+  }
+}
+
+/**
+ * Create an AI SDK language model from a Cherry Studio provider configuration
+ * Uses shared provider utilities for consistent behavior with renderer
+ */
+async function createLanguageModel(provider: Provider, modelId: string): Promise<LanguageModel> {
+  logger.debug('Creating language model', {
+    providerId: provider.id,
+    providerType: provider.type,
+    modelId,
+    apiHost: provider.apiHost
+  })
+
+  // Convert provider config to AI SDK config
+  const config = providerToAiSdkConfig(provider, modelId)
+
+  // Create the AI SDK provider
+  const aiSdkProvider = await createAiSdkProvider(config)
+  if (!aiSdkProvider) {
+    throw new Error(`Failed to create AI SDK provider for ${provider.id}`)
+  }
+
+  // Get the language model
+  return aiSdkProvider.languageModel(modelId)
+}
+
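Editor's note: a minimal sketch (not part of the patch) of how the factory above is meant to be used. `createLanguageModel` is the helper defined in this file; the provider literal and model id below are illustrative placeholders only.

```ts
import { streamText } from 'ai'

// Hypothetical Cherry Studio provider record; in the app this comes from Redux state.
const provider = {
  id: 'openai',
  type: 'openai',
  apiKey: 'sk-placeholder',
  apiHost: 'https://api.openai.com',
  models: []
} as unknown as Provider

// createLanguageModel resolves aggregated providers, formats the apiHost,
// builds the AI SDK config, and returns an AI SDK LanguageModel instance.
const model = await createLanguageModel(provider, 'gpt-4o-mini')

// The model plugs straight into streamText; result.fullStream is the event
// stream that AiSdkToAnthropicSSE consumes to emit Anthropic SSE below.
const result = streamText({ model, messages: [{ role: 'user', content: 'ping' }] })
for await (const part of result.fullStream) {
  console.log(part.type)
}
```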
+function convertAnthropicToolResultToAiSdk(
+  content: string | Array<TextBlockParam | ImageBlockParam>
+): LanguageModelV2ToolResultOutput {
+  if (typeof content === 'string') {
+    return {
+      type: 'text',
+      value: content
+    }
+  } else {
+    const values: Array<
+      | { type: 'text'; text: string }
+      | {
+          type: 'media'
+          /** Base-64 encoded media data. */
+          data: string
+          /**
+           * IANA media type.
+           * @see https://www.iana.org/assignments/media-types/media-types.xhtml
+           */
+          mediaType: string
+        }
+    > = []
+    for (const block of content) {
+      if (block.type === 'text') {
+        values.push({
+          type: 'text',
+          text: block.text
+        })
+      } else if (block.type === 'image') {
+        values.push({
+          type: 'media',
+          data: block.source.type === 'base64' ? block.source.data : block.source.url,
+          mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
+        })
+      }
+    }
+    return {
+      type: 'content',
+      // Return the collected blocks; previously an empty array was returned here,
+      // which silently dropped all tool result content
+      value: values
+    }
+  }
+}
+
+/**
+ * Convert Anthropic MessageCreateParams to AI SDK message format
+ */
+function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
+  const messages: ModelMessage[] = []
+  // Map tool_use ids to tool names across messages, so tool_result blocks
+  // (which arrive in a later user message) can recover the tool name
+  const toolNameById = new Map<string, string>()
+
+  // Add system message if present
+  if (params.system) {
+    if (typeof params.system === 'string') {
+      messages.push({
+        role: 'system',
+        content: params.system
+      })
+    } else if (Array.isArray(params.system)) {
+      // Handle TextBlockParam array
+      const systemText = params.system
+        .filter((block) => block.type === 'text')
+        .map((block) => block.text)
+        .join('\n')
+      if (systemText) {
+        messages.push({
+          role: 'system',
+          content: systemText
+        })
+      }
+    }
+  }
+
+  // Convert user/assistant messages
+  for (const msg of params.messages) {
+    if (typeof msg.content === 'string') {
+      if (msg.role === 'user') {
+        messages.push({ role: 'user', content: msg.content })
+      } else {
+        messages.push({ role: 'assistant', content: msg.content })
+      }
+    } else if (Array.isArray(msg.content)) {
+      // Handle content blocks
+      const textParts: TextPart[] = []
+      const imageParts: ImagePart[] = []
+      const reasoningParts: ReasoningPart[] = []
+      const toolCallParts: ToolCallPart[] = []
+      const toolResultParts: ToolResultPart[] = []
+
+      for (const block of msg.content) {
+        if (block.type === 'text') {
+          textParts.push({ type: 'text', text: block.text })
+        } else if (block.type === 'thinking') {
+          reasoningParts.push({ type: 'reasoning', text: block.thinking })
+        } else if (block.type === 'redacted_thinking') {
+          reasoningParts.push({ type: 'reasoning', text: block.data })
+        } else if (block.type === 'image') {
+          const source = block.source
+          if (source.type === 'base64') {
+            imageParts.push({
+              type: 'image',
+              image: `data:${source.media_type};base64,${source.data}`
+            })
+          } else if (source.type === 'url') {
+            imageParts.push({
+              type: 'image',
+              image: source.url
+            })
+          }
+        } else if (block.type === 'tool_use') {
+          toolNameById.set(block.id, block.name)
+          toolCallParts.push({
+            type: 'tool-call',
+            toolName: block.name,
+            toolCallId: block.id,
+            input: block.input
+          })
+        } else if (block.type === 'tool_result') {
+          toolResultParts.push({
+            type: 'tool-result',
+            toolCallId: block.tool_use_id,
+            toolName: toolNameById.get(block.tool_use_id) || 'unknown',
+            output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
+          })
+        }
+      }
+
+      // Build the message based on role
+      if (msg.role === 'user') {
+        // Anthropic delivers tool_result blocks inside user messages, but the
+        // AI SDK expects them in a dedicated 'tool' role message
+        if (toolResultParts.length > 0) {
+          messages.push({ role: 'tool', content: toolResultParts })
+        }
+        if (textParts.length > 0 || imageParts.length > 0) {
+          messages.push({
+            role: 'user',
+            content: [...textParts, ...imageParts]
+          })
+        }
+      } else {
+        // Assistant messages may carry reasoning, text and tool-call parts
+        const assistantParts = [...reasoningParts, ...textParts, ...toolCallParts]
+        if (assistantParts.length > 0) {
+          messages.push({
+            role: 'assistant',
+            content: assistantParts
+          })
+        }
+      }
+    }
+  }
+
+  return messages
+}
+
+/**
+ * Stream a message request using AI SDK and convert to Anthropic SSE format
+ */
+export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
+  const { response, provider, modelId, params, onError, onComplete } = config
+
+  logger.info('Starting unified message stream', {
+    providerId: provider.id,
+    providerType: provider.type,
+    modelId,
+    stream: params.stream
+  })
+
+  try {
+    response.setHeader('Content-Type', 'text/event-stream')
+    response.setHeader('Cache-Control', 'no-cache')
+    response.setHeader('Connection', 'keep-alive')
+    response.setHeader('X-Accel-Buffering', 'no')
+
+    const model = await createLanguageModel(provider, modelId)
+
+    const coreMessages = convertAnthropicToAiMessages(params)
+
+    logger.debug('Converted messages', {
+      originalCount: params.messages.length,
+      convertedCount: coreMessages.length,
+      hasSystem: !!params.system
+    })
+
+    // Create the adapter
+    const adapter = new AiSdkToAnthropicSSE({
+      model: `${provider.id}:${modelId}`,
+      onEvent: (event) => {
+        const sseData = formatSSEEvent(event)
+        response.write(sseData)
+      }
+    })
+
+    // Start streaming
+    const result = streamText({
+      model,
+      messages: coreMessages,
+      maxOutputTokens: params.max_tokens,
+      temperature: params.temperature,
+      topP: params.top_p,
+      stopSequences: params.stop_sequences,
+      stopWhen: stepCountIs(100),
+      headers: defaultAppHeaders(),
+      providerOptions: {}
+    })
+
+    // Process the stream through the adapter
+    await adapter.processStream(result.fullStream)
+
+    // Send done marker
+    response.write(formatSSEDone())
+    response.end()
+
+    logger.info('Unified message stream completed', {
+      providerId: provider.id,
+      modelId
+    })
+
+    onComplete?.()
+  } catch (error) {
+    logger.error('Error in unified message stream', error as Error, {
+      providerId: provider.id,
+      modelId
+    })
+
+    // Try to send error event if response is still writable
+    if (!response.writableEnded) {
+      try {
+        const errorMessage = error instanceof Error ? error.message : 'Unknown error'
+        response.write(
+          `event: error\ndata: ${JSON.stringify({
+            type: 'error',
+            error: {
+              type: 'api_error',
+              message: errorMessage
+            }
+          })}\n\n`
+        )
+        response.end()
+      } catch {
+        // Response already ended
+      }
+    }
+
+    onError?.(error)
+    throw error
+  }
+}
+
+/**
+ * Generate a non-streaming message response
+ */
+export async function generateUnifiedMessage(
+  provider: Provider,
+  modelId: string,
+  params: MessageCreateParams
+): Promise<ReturnType<AiSdkToAnthropicSSE['buildNonStreamingResponse']>> {
+  logger.info('Starting unified message generation', {
+    providerId: provider.id,
+    providerType: provider.type,
+    modelId
+  })
+
+  try {
+    // Create language model (async - uses @cherrystudio/ai-core)
+    const model = await createLanguageModel(provider, modelId)
+
+    // Convert messages
+    const coreMessages = convertAnthropicToAiMessages(params)
+
+    // Create adapter to collect the response
+    let finalResponse: ReturnType<AiSdkToAnthropicSSE['buildNonStreamingResponse']> | null = null
+    const adapter = new AiSdkToAnthropicSSE({
+      model: `${provider.id}:${modelId}`,
+      onEvent: () => {
+        // We don't need to emit events for non-streaming
+      }
+    })
+
+    // Generate text
+    const result = streamText({
+      model,
+      messages: coreMessages,
+      maxOutputTokens: params.max_tokens,
+      temperature: params.temperature,
+      topP: params.top_p,
+      stopSequences: params.stop_sequences,
+      headers: defaultAppHeaders(),
+      stopWhen: stepCountIs(100)
+    })
+
+    // Process the stream to build the response
+    await adapter.processStream(result.fullStream)
+
+    // Get the final response
+    finalResponse = adapter.buildNonStreamingResponse()
+
+    logger.info('Unified message generation completed', {
+      providerId: provider.id,
+      modelId
+    })
+
+    return finalResponse
+  } catch (error) {
+    logger.error('Error in unified message generation', error as Error, {
+      providerId: provider.id,
+      modelId
+    })
+    throw error
+  }
+}
+
+export default {
+  streamUnifiedMessages,
+  generateUnifiedMessage
+}
diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts
index 53b318c5b2..261ff7c07e 100644
--- a/src/main/services/agents/services/claudecode/index.ts
+++ b/src/main/services/agents/services/claudecode/index.ts
@@ -84,18 +84,14 @@ class ClaudeCodeService implements AgentServiceInterface {
       })
       return aiStream
     }
-    if (
-      (modelInfo.provider?.type !== 'anthropic' &&
-        (modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
-      modelInfo.provider.apiKey === ''
-    ) {
-      logger.error('Anthropic provider configuration is missing', {
-        modelInfo
-      })
-
+    // Validate provider has required configuration
+    // Note: We no longer restrict to anthropic type only - the API Server's unified adapter
+    // handles format conversion for any provider type (OpenAI, Gemini, etc.)
+    if (!modelInfo.provider?.apiKey) {
+      logger.error('Provider API key is missing', { modelInfo })
       aiStream.emit('data', {
         type: 'error',
-        error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
+        error: new Error(`Provider '${modelInfo.provider?.id}' is missing API key configuration.`)
       })
       return aiStream
     }
@@ -106,15 +102,14 @@ class ClaudeCodeService implements AgentServiceInterface {
       Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
     ) as Record<string, string>
 
+    // Route through local API Server which handles format conversion via unified adapter
+    // This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.)
+ // The API Server converts AI SDK responses to Anthropic SSE format transparently const env = { ...loginShellEnvWithoutProxies, - // TODO: fix the proxy api server - // ANTHROPIC_API_KEY: apiConfig.apiKey, - // ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey, - // ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`, - ANTHROPIC_API_KEY: modelInfo.provider.apiKey, - ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey, - ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost, + ANTHROPIC_API_KEY: apiConfig.apiKey, + ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey, + ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`, ANTHROPIC_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId, diff --git a/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts index fb371d9ae5..3eaf2f0fb9 100644 --- a/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts @@ -24,7 +24,7 @@ export class VertexAPIClient extends GeminiAPIClient { this.anthropicVertexClient = new AnthropicVertexClient(provider) // 如果传入的是普通 Provider,转换为 VertexProvider if (isVertexProvider(provider)) { - this.vertexProvider = provider + this.vertexProvider = provider as VertexProvider } else { this.vertexProvider = createVertexProvider(provider) } diff --git a/src/renderer/src/aiCore/provider/config/azure-anthropic.ts b/src/renderer/src/aiCore/provider/config/azure-anthropic.ts deleted file mode 100644 index c6cb521386..0000000000 --- a/src/renderer/src/aiCore/provider/config/azure-anthropic.ts +++ /dev/null @@ -1,22 +0,0 @@ -import type { Provider } from '@renderer/types' - -import { provider2Provider, startsWith } from './helper' -import type { RuleSet } from './types' - -// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry -const AZURE_ANTHROPIC_RULES: RuleSet = { - rules: [ - { - match: startsWith('claude'), - provider: (provider: Provider) => ({ - ...provider, - type: 'anthropic', - apiHost: provider.apiHost + 'anthropic/v1', - id: 'azure-anthropic' - }) - } - ], - fallbackRule: (provider: Provider) => provider -} - -export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES) diff --git a/src/renderer/src/aiCore/provider/config/helper.ts b/src/renderer/src/aiCore/provider/config/helper.ts deleted file mode 100644 index 656911fc76..0000000000 --- a/src/renderer/src/aiCore/provider/config/helper.ts +++ /dev/null @@ -1,22 +0,0 @@ -import type { Model, Provider } from '@renderer/types' - -import type { RuleSet } from './types' - -export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase()) -export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type - -/** - * 解析模型对应的Provider - * @param ruleSet 规则集对象 - * @param model 模型对象 - * @param provider 原始provider对象 - * @returns 解析出的provider对象 - */ -export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider { - for (const rule of ruleSet.rules) { - if (rule.match(model)) { - return rule.provider(provider) - } - } - return ruleSet.fallbackRule(provider) -} diff --git a/src/renderer/src/aiCore/provider/config/index.ts b/src/renderer/src/aiCore/provider/config/index.ts index 
2f51234cec..b1d57d5a1a 100644 --- a/src/renderer/src/aiCore/provider/config/index.ts +++ b/src/renderer/src/aiCore/provider/config/index.ts @@ -1,3 +1,7 @@ -export { aihubmixProviderCreator } from './aihubmix' -export { newApiResolverCreator } from './newApi' -export { vertexAnthropicProviderCreator } from './vertext-anthropic' +// Re-export from shared config +export { + aihubmixProviderCreator, + azureAnthropicProviderCreator, + newApiResolverCreator, + vertexAnthropicProviderCreator +} from '@shared/provider/config' diff --git a/src/renderer/src/aiCore/provider/config/types.ts b/src/renderer/src/aiCore/provider/config/types.ts deleted file mode 100644 index f3938b84d1..0000000000 --- a/src/renderer/src/aiCore/provider/config/types.ts +++ /dev/null @@ -1,9 +0,0 @@ -import type { Model, Provider } from '@renderer/types' - -export interface RuleSet { - rules: Array<{ - match: (model: Model) => boolean - provider: (provider: Provider) => Provider - }> - fallbackRule: (provider: Provider) => Provider -} diff --git a/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts b/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts deleted file mode 100644 index 23c8b5185c..0000000000 --- a/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts +++ /dev/null @@ -1,19 +0,0 @@ -import type { Provider } from '@renderer/types' - -import { provider2Provider, startsWith } from './helper' -import type { RuleSet } from './types' - -const VERTEX_ANTHROPIC_RULES: RuleSet = { - rules: [ - { - match: startsWith('claude'), - provider: (provider: Provider) => ({ - ...provider, - id: 'google-vertex-anthropic' - }) - } - ], - fallbackRule: (provider: Provider) => provider -} - -export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES) diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts index 9760839389..97ab29db81 100644 --- a/src/renderer/src/aiCore/provider/factory.ts +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -1,8 +1,7 @@ -import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider' import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' import type { Provider } from '@renderer/types' -import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider' +import { getAiSdkProviderId as sharedGetAiSdkProviderId } from '@shared/provider' import type { Provider as AiSdkProvider } from 'ai' import type { AiSdkConfig } from '../types' @@ -22,68 +21,12 @@ const logger = loggerService.withContext('ProviderFactory') } })() -/** - * 静态Provider映射表 - * 处理Cherry Studio特有的provider ID到AI SDK标准ID的映射 - */ -const STATIC_PROVIDER_MAPPING: Record = { - gemini: 'google', // Google Gemini -> google - 'azure-openai': 'azure', // Azure OpenAI -> azure - 'openai-response': 'openai', // OpenAI Responses -> openai - grok: 'xai', // Grok -> xai - copilot: 'github-copilot-openai-compatible' -} - -/** - * 尝试解析provider标识符(支持静态映射和别名) - */ -function tryResolveProviderId(identifier: string): ProviderId | null { - // 1. 检查静态映射 - const staticMapping = STATIC_PROVIDER_MAPPING[identifier] - if (staticMapping) { - return staticMapping - } - - // 2. 
检查AiCore是否支持(包括别名支持) - if (hasProviderConfigByAlias(identifier)) { - // 解析为真实的Provider ID - return resolveProviderConfigId(identifier) as ProviderId - } - - return null -} - /** * 获取AI SDK Provider ID - * 简化版:减少重复逻辑,利用通用解析函数 + * Uses shared implementation with renderer-specific config checker */ export function getAiSdkProviderId(provider: Provider): string { - // 1. 尝试解析provider.id - const resolvedFromId = tryResolveProviderId(provider.id) - if (isAzureOpenAIProvider(provider)) { - if (isAzureResponsesEndpoint(provider)) { - return 'azure-responses' - } else { - return 'azure' - } - } - if (resolvedFromId) { - return resolvedFromId - } - - // 2. 尝试解析provider.type - // 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上 - if (provider.type !== 'openai') { - const resolvedFromType = tryResolveProviderId(provider.type) - if (resolvedFromType) { - return resolvedFromType - } - } - if (provider.apiHost.includes('api.openai.com')) { - return 'openai-chat' - } - // 3. 最后的fallback(使用provider本身的id) - return provider.id + return sharedGetAiSdkProviderId(provider) } export async function createAiSdkProvider(config: AiSdkConfig): Promise { diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index 53194d3506..8e1a63f5a0 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -1,4 +1,4 @@ -import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider' +import { hasProviderConfig } from '@cherrystudio/ai-core/provider' import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { getAwsBedrockAccessKeyId, @@ -10,22 +10,17 @@ import { import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI' import { getProviderByModel } from '@renderer/services/AssistantService' import store from '@renderer/store' -import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types' -import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api' +import { isSystemProvider, type Model, type Provider } from '@renderer/types' import { - isAnthropicProvider, - isAzureOpenAIProvider, - isCherryAIProvider, - isGeminiProvider, - isNewApiProvider, - isPerplexityProvider, - isVertexProvider -} from '@renderer/utils/provider' + type AiSdkConfigContext, + formatProviderApiHost as sharedFormatProviderApiHost, + type ProviderFormatContext, + providerToAiSdkConfig as sharedProviderToAiSdkConfig, + resolveActualProvider +} from '@shared/provider' import { cloneDeep } from 'lodash' import type { AiSdkConfig } from '../types' -import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config' -import { azureAnthropicProviderCreator } from './config/azure-anthropic' import { COPILOT_DEFAULT_HEADERS } from './constants' import { getAiSdkProviderId } from './factory' @@ -56,61 +51,51 @@ function getRotatedApiKey(provider: Provider): string { } /** - * 处理特殊provider的转换逻辑 + * Renderer-specific context for providerToAiSdkConfig + * Provides implementations using browser APIs, store, and hooks */ -function handleSpecialProviders(model: Model, provider: Provider): Provider { - if (isNewApiProvider(provider)) { - return newApiResolverCreator(model, provider) +function createRendererSdkContext(model: Model): AiSdkConfigContext { + return { + getRotatedApiKey: (provider) => getRotatedApiKey(provider as 
Provider), + isOpenAIChatCompletionOnlyModel: () => isOpenAIChatCompletionOnlyModel(model), + getCopilotDefaultHeaders: () => COPILOT_DEFAULT_HEADERS, + getCopilotStoredHeaders: () => store.getState().copilot.defaultHeaders ?? {}, + getAwsBedrockConfig: () => { + const authType = getAwsBedrockAuthType() + return { + authType, + region: getAwsBedrockRegion(), + apiKey: authType === 'apiKey' ? getAwsBedrockApiKey() : undefined, + accessKeyId: authType === 'iam' ? getAwsBedrockAccessKeyId() : undefined, + secretAccessKey: authType === 'iam' ? getAwsBedrockSecretAccessKey() : undefined + } + }, + getVertexConfig: (provider) => { + if (!isVertexAIConfigured()) { + return undefined + } + return createVertexProvider(provider as Provider) + }, + getEndpointType: () => model.endpoint_type } - - if (isSystemProvider(provider)) { - if (provider.id === 'aihubmix') { - return aihubmixProviderCreator(model, provider) - } - if (provider.id === 'vertexai') { - return vertexAnthropicProviderCreator(model, provider) - } - } - if (isAzureOpenAIProvider(provider)) { - return azureAnthropicProviderCreator(model, provider) - } - return provider } /** * 主要用来对齐AISdk的BaseURL格式 - * @param provider - * @returns + * Uses shared implementation with renderer-specific context */ -function formatProviderApiHost(provider: Provider): Provider { - const formatted = { ...provider } - if (formatted.anthropicApiHost) { - formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost) - } - - if (isAnthropicProvider(provider)) { - const baseHost = formatted.anthropicApiHost || formatted.apiHost - // AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient - formatted.apiHost = formatApiHost(baseHost) - if (!formatted.anthropicApiHost) { - formatted.anthropicApiHost = formatted.apiHost +function getRendererFormatContext(): ProviderFormatContext { + const vertexSettings = store.getState().llm.settings.vertexai + return { + vertex: { + project: vertexSettings.projectId || 'default-project', + location: vertexSettings.location || 'us-central1' } - } else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) { - formatted.apiHost = formatApiHost(formatted.apiHost, false) - } else if (isGeminiProvider(formatted)) { - formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta') - } else if (isAzureOpenAIProvider(formatted)) { - formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost) - } else if (isVertexProvider(formatted)) { - formatted.apiHost = formatVertexApiHost(formatted) - } else if (isCherryAIProvider(formatted)) { - formatted.apiHost = formatApiHost(formatted.apiHost, false) - } else if (isPerplexityProvider(formatted)) { - formatted.apiHost = formatApiHost(formatted.apiHost, false) - } else { - formatted.apiHost = formatApiHost(formatted.apiHost) } - return formatted +} + +function formatProviderApiHost(provider: Provider): Provider { + return sharedFormatProviderApiHost(provider, getRendererFormatContext()) } /** @@ -122,7 +107,9 @@ export function getActualProvider(model: Model): Provider { // 按顺序处理各种转换 let actualProvider = cloneDeep(baseProvider) - actualProvider = handleSpecialProviders(model, actualProvider) + actualProvider = resolveActualProvider(actualProvider, model, { + isSystemProvider + }) as Provider actualProvider = formatProviderApiHost(actualProvider) return actualProvider @@ -130,121 +117,11 @@ export function getActualProvider(model: Model): Provider { /** * 将 Provider 配置转换为新 AI SDK 格式 - * 简化版:利用新的别名映射系统 + * Uses shared 
implementation with renderer-specific context */ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig { - const aiSdkProviderId = getAiSdkProviderId(actualProvider) - - // 构建基础配置 - const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost) - const baseConfig = { - baseURL: baseURL, - apiKey: getRotatedApiKey(actualProvider) - } - - const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot - if (isCopilotProvider) { - const storedHeaders = store.getState().copilot.defaultHeaders ?? {} - const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, { - headers: { - ...COPILOT_DEFAULT_HEADERS, - ...storedHeaders, - ...actualProvider.extra_headers - }, - name: actualProvider.id, - includeUsage: true - }) - - return { - providerId: 'github-copilot-openai-compatible', - options - } - } - - // 处理OpenAI模式 - const extraOptions: any = {} - extraOptions.endpoint = endpoint - if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) { - extraOptions.mode = 'responses' - } else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) { - extraOptions.mode = 'chat' - } - - // 添加额外headers - if (actualProvider.extra_headers) { - extraOptions.headers = actualProvider.extra_headers - // copy from openaiBaseClient/openaiResponseApiClient - if (aiSdkProviderId === 'openai') { - extraOptions.headers = { - ...extraOptions.headers, - 'HTTP-Referer': 'https://cherry-ai.com', - 'X-Title': 'Cherry Studio', - 'X-Api-Key': baseConfig.apiKey - } - } - } - // azure - // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest - // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api - if (aiSdkProviderId === 'azure-responses') { - extraOptions.mode = 'responses' - } else if (aiSdkProviderId === 'azure') { - extraOptions.mode = 'chat' - } - - // bedrock - if (aiSdkProviderId === 'bedrock') { - const authType = getAwsBedrockAuthType() - extraOptions.region = getAwsBedrockRegion() - - if (authType === 'apiKey') { - extraOptions.apiKey = getAwsBedrockApiKey() - } else { - extraOptions.accessKeyId = getAwsBedrockAccessKeyId() - extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey() - } - } - // google-vertex - if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') { - if (!isVertexAIConfigured()) { - throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') - } - const { project, location, googleCredentials } = createVertexProvider(actualProvider) - extraOptions.project = project - extraOptions.location = location - extraOptions.googleCredentials = { - ...googleCredentials, - privateKey: formatPrivateKey(googleCredentials.privateKey) - } - baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? 
'/publishers/google' : '/publishers/anthropic/models' - } - - // cherryin - if (aiSdkProviderId === 'cherryin') { - if (model.endpoint_type) { - extraOptions.endpointType = model.endpoint_type - } - } - - if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { - const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions) - return { - providerId: aiSdkProviderId, - options - } - } - - // 否则fallback到openai-compatible - const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey) - return { - providerId: 'openai-compatible', - options: { - ...options, - name: actualProvider.id, - ...extraOptions, - includeUsage: true - } - } + const context = createRendererSdkContext(model) + return sharedProviderToAiSdkConfig(actualProvider, model.id, context) as AiSdkConfig } /** @@ -287,13 +164,13 @@ export async function prepareSpecialProviderConfig( break } case 'cherryai': { - config.options.fetch = async (url, options) => { + config.options.fetch = async (url: RequestInfo | URL, options: RequestInit) => { // 在这里对最终参数进行签名 const signature = await window.api.cherryai.generateSignature({ method: 'POST', path: '/chat/completions', query: '', - body: JSON.parse(options.body) + body: JSON.parse(options.body as string) }) return fetch(url, { ...options, diff --git a/src/renderer/src/aiCore/provider/providerInitialization.ts b/src/renderer/src/aiCore/provider/providerInitialization.ts index 2e4b9fced2..5254e78851 100644 --- a/src/renderer/src/aiCore/provider/providerInitialization.ts +++ b/src/renderer/src/aiCore/provider/providerInitialization.ts @@ -1,113 +1,13 @@ -import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' +import { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from '@shared/provider' const logger = loggerService.withContext('ProviderConfigs') -/** - * 新Provider配置定义 - * 定义了需要动态注册的AI Providers - */ -export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [ - { - id: 'openrouter', - name: 'OpenRouter', - import: () => import('@openrouter/ai-sdk-provider'), - creatorFunctionName: 'createOpenRouter', - supportsImageGeneration: true, - aliases: ['openrouter'] - }, - { - id: 'google-vertex', - name: 'Google Vertex AI', - import: () => import('@ai-sdk/google-vertex/edge'), - creatorFunctionName: 'createVertex', - supportsImageGeneration: true, - aliases: ['vertexai'] - }, - { - id: 'google-vertex-anthropic', - name: 'Google Vertex AI Anthropic', - import: () => import('@ai-sdk/google-vertex/anthropic/edge'), - creatorFunctionName: 'createVertexAnthropic', - supportsImageGeneration: true, - aliases: ['vertexai-anthropic'] - }, - { - id: 'azure-anthropic', - name: 'Azure AI Anthropic', - import: () => import('@ai-sdk/anthropic'), - creatorFunctionName: 'createAnthropic', - supportsImageGeneration: false, - aliases: ['azure-anthropic'] - }, - { - id: 'github-copilot-openai-compatible', - name: 'GitHub Copilot OpenAI Compatible', - import: () => import('@opeoginni/github-copilot-openai-compatible'), - creatorFunctionName: 'createGitHubCopilotOpenAICompatible', - supportsImageGeneration: false, - aliases: ['copilot', 'github-copilot'] - }, - { - id: 'bedrock', - name: 'Amazon Bedrock', - import: () => import('@ai-sdk/amazon-bedrock'), - creatorFunctionName: 'createAmazonBedrock', - supportsImageGeneration: true, - aliases: ['aws-bedrock'] - }, - { - id: 'perplexity', - name: 'Perplexity', - import: () => 
import('@ai-sdk/perplexity'), - creatorFunctionName: 'createPerplexity', - supportsImageGeneration: false, - aliases: ['perplexity'] - }, - { - id: 'mistral', - name: 'Mistral', - import: () => import('@ai-sdk/mistral'), - creatorFunctionName: 'createMistral', - supportsImageGeneration: false, - aliases: ['mistral'] - }, - { - id: 'huggingface', - name: 'HuggingFace', - import: () => import('@ai-sdk/huggingface'), - creatorFunctionName: 'createHuggingFace', - supportsImageGeneration: true, - aliases: ['hf', 'hugging-face'] - }, - { - id: 'ai-gateway', - name: 'AI Gateway', - import: () => import('@ai-sdk/gateway'), - creatorFunctionName: 'createGateway', - supportsImageGeneration: true, - aliases: ['gateway'] - }, - { - id: 'cerebras', - name: 'Cerebras', - import: () => import('@ai-sdk/cerebras'), - creatorFunctionName: 'createCerebras', - supportsImageGeneration: false - } -] as const +export const NEW_PROVIDER_CONFIGS = SHARED_PROVIDER_CONFIGS -/** - * 初始化新的Providers - * 使用aiCore的动态注册功能 - */ export async function initializeNewProviders(): Promise { - try { - const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS) - if (successCount < NEW_PROVIDER_CONFIGS.length) { - logger.warn('Some providers failed to register. Check previous error logs.') - } - } catch (error) { - logger.error('Failed to initialize new providers:', error as Error) - } + initializeSharedProviders({ + warn: (message) => logger.warn(message), + error: (message, error) => logger.error(message, error) + }) } diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index bff57185a7..7db6cbf4bd 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -7,6 +7,8 @@ import type { CSSProperties } from 'react' export * from './file' export * from './note' +import type { MinimalModel } from '@shared/provider/types' + import type { StreamTextParams } from './aiCoreTypes' import type { Chunk } from './chunk' import type { FileMetadata } from './file' @@ -256,7 +258,7 @@ export type ModelCapability = { isUserSelected?: boolean } -export type Model = { +export type Model = MinimalModel & { id: string provider: string name: string diff --git a/src/renderer/src/types/provider.ts b/src/renderer/src/types/provider.ts index 9d948f16d0..573e1e1007 100644 --- a/src/renderer/src/types/provider.ts +++ b/src/renderer/src/types/provider.ts @@ -1,24 +1,14 @@ import type OpenAI from '@cherrystudio/openai' +import type { MinimalProvider } from '@shared/provider' +import type { ProviderType, SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types' +import { isSystemProviderId, SystemProviderIds } from '@shared/provider/types' import type { Model } from '@types' -import * as z from 'zod' import type { OpenAIVerbosity } from './aiCoreTypes' -export const ProviderTypeSchema = z.enum([ - 'openai', - 'openai-response', - 'anthropic', - 'gemini', - 'azure-openai', - 'vertexai', - 'mistral', - 'aws-bedrock', - 'vertex-anthropic', - 'new-api', - 'ai-gateway' -]) - -export type ProviderType = z.infer +export type { ProviderType } from '@shared/provider' +export type { SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types' +export { isSystemProviderId, ProviderTypeSchema, SystemProviderIds } from '@shared/provider/types' // undefined is treated as supported, enabled by default export type ProviderApiOptions = { @@ -93,7 +83,7 @@ export function isAwsBedrockAuthType(type: string): type is AwsBedrockAuthType { return Object.hasOwn(AwsBedrockAuthTypes, 
type) } -export type Provider = { +export type Provider = MinimalProvider & { id: string type: ProviderType name: string @@ -128,140 +118,6 @@ export type Provider = { extra_headers?: Record } -export const SystemProviderIdSchema = z.enum([ - 'cherryin', - 'silicon', - 'aihubmix', - 'ocoolai', - 'deepseek', - 'ppio', - 'alayanew', - 'qiniu', - 'dmxapi', - 'burncloud', - 'tokenflux', - '302ai', - 'cephalon', - 'lanyun', - 'ph8', - 'openrouter', - 'ollama', - 'ovms', - 'new-api', - 'lmstudio', - 'anthropic', - 'openai', - 'azure-openai', - 'gemini', - 'vertexai', - 'github', - 'copilot', - 'zhipu', - 'yi', - 'moonshot', - 'baichuan', - 'dashscope', - 'stepfun', - 'doubao', - 'infini', - 'minimax', - 'groq', - 'together', - 'fireworks', - 'nvidia', - 'grok', - 'hyperbolic', - 'mistral', - 'jina', - 'perplexity', - 'modelscope', - 'xirang', - 'hunyuan', - 'tencent-cloud-ti', - 'baidu-cloud', - 'gpustack', - 'voyageai', - 'aws-bedrock', - 'poe', - 'aionly', - 'longcat', - 'huggingface', - 'sophnet', - 'ai-gateway', - 'cerebras' -]) - -export type SystemProviderId = z.infer - -export const isSystemProviderId = (id: string): id is SystemProviderId => { - return SystemProviderIdSchema.safeParse(id).success -} - -export const SystemProviderIds = { - cherryin: 'cherryin', - silicon: 'silicon', - aihubmix: 'aihubmix', - ocoolai: 'ocoolai', - deepseek: 'deepseek', - ppio: 'ppio', - alayanew: 'alayanew', - qiniu: 'qiniu', - dmxapi: 'dmxapi', - burncloud: 'burncloud', - tokenflux: 'tokenflux', - '302ai': '302ai', - cephalon: 'cephalon', - lanyun: 'lanyun', - ph8: 'ph8', - sophnet: 'sophnet', - openrouter: 'openrouter', - ollama: 'ollama', - ovms: 'ovms', - 'new-api': 'new-api', - lmstudio: 'lmstudio', - anthropic: 'anthropic', - openai: 'openai', - 'azure-openai': 'azure-openai', - gemini: 'gemini', - vertexai: 'vertexai', - github: 'github', - copilot: 'copilot', - zhipu: 'zhipu', - yi: 'yi', - moonshot: 'moonshot', - baichuan: 'baichuan', - dashscope: 'dashscope', - stepfun: 'stepfun', - doubao: 'doubao', - infini: 'infini', - minimax: 'minimax', - groq: 'groq', - together: 'together', - fireworks: 'fireworks', - nvidia: 'nvidia', - grok: 'grok', - hyperbolic: 'hyperbolic', - mistral: 'mistral', - jina: 'jina', - perplexity: 'perplexity', - modelscope: 'modelscope', - xirang: 'xirang', - hunyuan: 'hunyuan', - 'tencent-cloud-ti': 'tencent-cloud-ti', - 'baidu-cloud': 'baidu-cloud', - gpustack: 'gpustack', - voyageai: 'voyageai', - 'aws-bedrock': 'aws-bedrock', - poe: 'poe', - aionly: 'aionly', - longcat: 'longcat', - huggingface: 'huggingface', - 'ai-gateway': 'ai-gateway', - cerebras: 'cerebras' -} as const satisfies Record - -type SystemProviderIdTypeMap = typeof SystemProviderIds - export type SystemProvider = Provider & { id: SystemProviderId isSystem: true diff --git a/src/renderer/src/utils/api.ts b/src/renderer/src/utils/api.ts index 845187eb80..ab411e6f15 100644 --- a/src/renderer/src/utils/api.ts +++ b/src/renderer/src/utils/api.ts @@ -1,6 +1,15 @@ -import store from '@renderer/store' -import type { VertexProvider } from '@renderer/types' -import { trim } from 'lodash' +export { + formatApiHost, + formatAzureOpenAIApiHost, + formatVertexApiHost, + getAiSdkBaseUrl, + hasAPIVersion, + routeToEndpoint, + SUPPORTED_ENDPOINT_LIST, + SUPPORTED_IMAGE_ENDPOINT_LIST, + validateApiHost, + withoutTrailingSlash +} from '@shared/api' /** * 格式化 API key 字符串。 @@ -12,169 +21,6 @@ export function formatApiKeys(value: string): string { return value.replaceAll(',', ',').replaceAll('\n', ',') } -/** - * 判断 host 的 
path 中是否包含形如版本的字符串(例如 /v1、/v2beta 等), - * - * @param host - 要检查的 host 或 path 字符串 - * @returns 如果 path 中包含版本字符串则返回 true,否则 false - */ -export function hasAPIVersion(host?: string): boolean { - if (!host) return false - - // 匹配路径中以 `/v` 开头并可选跟随 `alpha` 或 `beta` 的版本段, - // 该段后面可以跟 `/` 或字符串结束(用于匹配诸如 `/v3alpha/resources` 的情况)。 - const versionRegex = /\/v\d+(?:alpha|beta)?(?=\/|$)/i - - try { - const url = new URL(host) - return versionRegex.test(url.pathname) - } catch { - // 若无法作为完整 URL 解析,则当作路径直接检测 - return versionRegex.test(host) - } -} - -/** - * Removes the trailing slash from a URL string if it exists. - * - * @template T - The string type to preserve type safety - * @param {T} url - The URL string to process - * @returns {T} The URL string without a trailing slash - * - * @example - * ```ts - * withoutTrailingSlash('https://example.com/') // 'https://example.com' - * withoutTrailingSlash('https://example.com') // 'https://example.com' - * ``` - */ -export function withoutTrailingSlash(url: T): T { - return url.replace(/\/$/, '') as T -} - -/** - * Formats an API host URL by normalizing it and optionally appending an API version. - * - * @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed. - * @param isSupportedAPIVerion - Whether the API version is supported. Defaults to `true`. - * @param apiVersion - The API version to append if needed. Defaults to `'v1'`. - * - * @returns The formatted API host URL. If the host is empty after normalization, returns an empty string. - * If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is. - * Otherwise, returns the host with the API version appended. - * - * @example - * formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1' - * formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#' - * formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2' - */ -export function formatApiHost(host?: string, isSupportedAPIVerion: boolean = true, apiVersion: string = 'v1'): string { - const normalizedHost = withoutTrailingSlash(trim(host)) - if (!normalizedHost) { - return '' - } - - if (normalizedHost.endsWith('#') || !isSupportedAPIVerion || hasAPIVersion(normalizedHost)) { - return normalizedHost - } - return `${normalizedHost}/${apiVersion}` -} - -/** - * 格式化 Azure OpenAI 的 API 主机地址。 - */ -export function formatAzureOpenAIApiHost(host: string): string { - const normalizedHost = withoutTrailingSlash(host) - ?.replace(/\/v1$/, '') - .replace(/\/openai$/, '') - // NOTE: AISDK会添加上`v1` - return formatApiHost(normalizedHost + '/openai', false) -} - -export function formatVertexApiHost(provider: VertexProvider): string { - const { apiHost } = provider - const { projectId: project, location } = store.getState().llm.settings.vertexai - const trimmedHost = withoutTrailingSlash(trim(apiHost)) - if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) { - const host = - location == 'global' ? 
'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com` - return `${formatApiHost(host)}/projects/${project}/locations/${location}` - } - return formatApiHost(trimmedHost) -} - -// 目前对话界面只支持这些端点 -export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const -export const SUPPORTED_ENDPOINT_LIST = [ - 'chat/completions', - 'responses', - 'messages', - 'generateContent', - 'streamGenerateContent', - ...SUPPORTED_IMAGE_ENDPOINT_LIST -] as const - -/** - * Converts an API host URL into separate base URL and endpoint components. - * - * @param apiHost - The API host string to parse. Expected to be a trimmed URL that may end with '#' followed by an endpoint identifier. - * @returns An object containing: - * - `baseURL`: The base URL without the endpoint suffix - * - `endpoint`: The matched endpoint identifier, or empty string if no match found - * - * @description - * This function extracts endpoint information from a composite API host string. - * If the host ends with '#', it attempts to match the preceding part against the supported endpoint list. - * The '#' delimiter is removed before processing. - * - * @example - * routeToEndpoint('https://api.example.com/openai/chat/completions#') - * // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' } - * - * @example - * routeToEndpoint('https://api.example.com/v1') - * // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' } - */ -export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } { - const trimmedHost = trim(apiHost) - // 前面已经确保apiHost合法 - if (!trimmedHost.endsWith('#')) { - return { baseURL: trimmedHost, endpoint: '' } - } - // 去掉结尾的 # - const host = trimmedHost.slice(0, -1) - const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint)) - if (!endpointMatch) { - const baseURL = withoutTrailingSlash(host) - return { baseURL, endpoint: '' } - } - const baseSegment = host.slice(0, host.length - endpointMatch.length) - const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // 去掉结尾可能存在的冒号(gemini的特殊情况) - return { baseURL, endpoint: endpointMatch } -} - -/** - * 验证 API 主机地址是否合法。 - * - * @param {string} apiHost - 需要验证的 API 主机地址。 - * @returns {boolean} 如果是合法的 URL 则返回 true,否则返回 false。 - */ -export function validateApiHost(apiHost: string): boolean { - // 允许apiHost为空 - if (!apiHost || !trim(apiHost)) { - return true - } - try { - const url = new URL(trim(apiHost)) - // 验证协议是否为 http 或 https - if (url.protocol !== 'http:' && url.protocol !== 'https:') { - return false - } - return true - } catch { - return false - } -} - /** * API key 脱敏函数。仅保留部分前后字符,中间用星号代替。 * diff --git a/src/renderer/src/utils/naming.ts b/src/renderer/src/utils/naming.ts index bc24bc7db8..d258cee81a 100644 --- a/src/renderer/src/utils/naming.ts +++ b/src/renderer/src/utils/naming.ts @@ -2,6 +2,8 @@ import { getProviderLabel } from '@renderer/i18n/label' import type { Provider } from '@renderer/types' import { isSystemProvider } from '@renderer/types' +export { getBaseModelName, getLowerBaseModelName } from '@shared/utils/naming' + /** * 从模型 ID 中提取默认组名。 * 规则如下: @@ -50,38 +52,6 @@ export const getDefaultGroupName = (id: string, provider?: string): string => { return str } -/** - * 从模型 ID 中提取基础名称。 - * 例如: - * - 'deepseek/deepseek-r1' => 'deepseek-r1' - * - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1' - * @param {string} id 模型 ID - * @param {string} [delimiter='/'] 分隔符,默认为 '/' - * @returns {string} 
基础名称 - */ -export const getBaseModelName = (id: string, delimiter: string = '/'): string => { - const parts = id.split(delimiter) - return parts[parts.length - 1] -} - -/** - * 从模型 ID 中提取基础名称并转换为小写。 - * 例如: - * - 'deepseek/DeepSeek-R1' => 'deepseek-r1' - * - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1' - * @param {string} id 模型 ID - * @param {string} [delimiter='/'] 分隔符,默认为 '/' - * @returns {string} 小写的基础名称 - */ -export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => { - const baseModelName = getBaseModelName(id, delimiter).toLowerCase() - // for openrouter - if (baseModelName.endsWith(':free')) { - return baseModelName.replace(':free', '') - } - return baseModelName -} - /** * 获取模型服务商名称,根据是否内置服务商来决定要不要翻译 * @param provider 服务商 diff --git a/src/renderer/src/utils/provider.ts b/src/renderer/src/utils/provider.ts index e8fc1b5cc7..2c79d36352 100644 --- a/src/renderer/src/utils/provider.ts +++ b/src/renderer/src/utils/provider.ts @@ -1,10 +1,20 @@ import { CLAUDE_SUPPORTED_PROVIDERS } from '@renderer/pages/code' -import type { AzureOpenAIProvider, ProviderType, VertexProvider } from '@renderer/types' +import type { ProviderType } from '@renderer/types' import { isSystemProvider, type Provider, type SystemProviderId, SystemProviderIds } from '@renderer/types' - -export const isAzureResponsesEndpoint = (provider: AzureOpenAIProvider) => { - return provider.apiVersion === 'preview' || provider.apiVersion === 'v1' -} +export { + isAIGatewayProvider, + isAnthropicProvider, + isAwsBedrockProvider, + isAzureOpenAIProvider, + isAzureResponsesEndpoint, + isCherryAIProvider, + isGeminiProvider, + isNewApiProvider, + isOpenAICompatibleProvider, + isOpenAIProvider, + isPerplexityProvider, + isVertexProvider +} from '@shared/provider' export const getClaudeSupportedProviders = (providers: Provider[]) => { return providers.filter( @@ -119,55 +129,6 @@ export const isGeminiWebSearchProvider = (provider: Provider) => { return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id) } -export const isNewApiProvider = (provider: Provider) => { - return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api' -} - -export function isCherryAIProvider(provider: Provider): boolean { - return provider.id === 'cherryai' -} - -export function isPerplexityProvider(provider: Provider): boolean { - return provider.id === 'perplexity' -} - -/** - * 判断是否为 OpenAI 兼容的提供商 - * @param {Provider} provider 提供商对象 - * @returns {boolean} 是否为 OpenAI 兼容提供商 - */ -export function isOpenAICompatibleProvider(provider: Provider): boolean { - return ['openai', 'new-api', 'mistral'].includes(provider.type) -} - -export function isAzureOpenAIProvider(provider: Provider): provider is AzureOpenAIProvider { - return provider.type === 'azure-openai' -} - -export function isOpenAIProvider(provider: Provider): boolean { - return provider.type === 'openai-response' -} - -export function isVertexProvider(provider: Provider): provider is VertexProvider { - return provider.type === 'vertexai' -} - -export function isAwsBedrockProvider(provider: Provider): boolean { - return provider.type === 'aws-bedrock' -} - -export function isAnthropicProvider(provider: Provider): boolean { - return provider.type === 'anthropic' -} - -export function isGeminiProvider(provider: Provider): boolean { - return provider.type === 'gemini' -} - -export function isAIGatewayProvider(provider: Provider): boolean { - return provider.type === 'ai-gateway' -} - const NOT_SUPPORT_API_VERSION_PROVIDERS = 
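/* providers whose API hosts must not have a version segment appended */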
['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[]

 export const isSupportAPIVersionProvider = (provider: Provider) => {

From 192357a32e150fa7952b4a103260a1e13572db31 Mon Sep 17 00:00:00 2001
From: suyao
Date: Thu, 27 Nov 2025 19:19:04 +0800
Subject: [PATCH 02/53] feat: Enhance thinking block management and tool conversion in unified messages

---
 .../shared/adapters/AiSdkToAnthropicSSE.ts    | 108 +++++++++++++-----
 .../apiServer/services/unified-messages.ts    |  89 +++++++++++++--
 .../agents/services/claudecode/transform.ts   |  30 ++++-
 3 files changed, 188 insertions(+), 39 deletions(-)

diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
index 38fab703ac..1674609236 100644
--- a/packages/shared/adapters/AiSdkToAnthropicSSE.ts
+++ b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
@@ -59,7 +59,9 @@ interface AdapterState {
   currentBlockIndex: number
   blocks: Map<number, ContentBlockState>
   textBlockIndex: number | null
-  thinkingBlockIndex: number | null
+  // Track multiple thinking blocks by their reasoning ID
+  thinkingBlocks: Map<string, number> // reasoningId -> blockIndex
+  currentThinkingId: string | null // Currently active thinking block ID
   toolBlocks: Map<string, number> // toolCallId -> blockIndex
   stopReason: StopReason | null
   hasEmittedMessageStart: boolean
@@ -95,7 +97,8 @@ export class AiSdkToAnthropicSSE {
      currentBlockIndex: 0,
      blocks: new Map(),
      textBlockIndex: null,
-      thinkingBlockIndex: null,
+      thinkingBlocks: new Map(),
+      currentThinkingId: null,
      toolBlocks: new Map(),
      stopReason: null,
      hasEmittedMessageStart: false
@@ -133,7 +136,7 @@ export class AiSdkToAnthropicSSE {
    * Process a single AI SDK chunk and emit corresponding Anthropic events
    */
   private processChunk(chunk: TextStreamPart<ToolSet>): void {
-    logger.silly('AiSdkToAnthropicSSE - Processing chunk:', chunk)
+    logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
     switch (chunk.type) {
       // === Text Events ===
       case 'text-start':
@@ -149,17 +152,23 @@ export class AiSdkToAnthropicSSE {
         break

       // === Reasoning/Thinking Events ===
-      case 'reasoning-start':
-        this.startThinkingBlock()
+      case 'reasoning-start': {
+        const reasoningId = (chunk as { id?: string }).id || `reasoning_${Date.now()}`
+        this.startThinkingBlock(reasoningId)
         break
+      }

-      case 'reasoning-delta':
-        this.emitThinkingDelta(chunk.text || '')
+      case 'reasoning-delta': {
+        const reasoningId = (chunk as { id?: string }).id
+        this.emitThinkingDelta(chunk.text || '', reasoningId)
         break
+      }

-      case 'reasoning-end':
-        this.stopThinkingBlock()
+      case 'reasoning-end': {
+        const reasoningId = (chunk as { id?: string }).id
+        this.stopThinkingBlock(reasoningId)
         break
+      }

       // === Tool Events ===
       case 'tool-call':
@@ -190,9 +199,7 @@ export class AiSdkToAnthropicSSE {

       // === Error Events ===
       case 'error':
-        // Anthropic doesn't have a standard error event in the stream
-        // Errors are typically sent as separate HTTP responses
-        // For now, we'll just log and continue
+        this.handleError(chunk.error)
         break

       // Ignore other event types
@@ -303,11 +310,13 @@ export class AiSdkToAnthropicSSE {
     this.state.textBlockIndex = null
   }

-  private startThinkingBlock(): void {
-    if (this.state.thinkingBlockIndex !== null) return
+  private startThinkingBlock(reasoningId: string): void {
+    // Check if this thinking block already exists
+    if (this.state.thinkingBlocks.has(reasoningId)) return

     const index = this.state.currentBlockIndex++
-    this.state.thinkingBlockIndex = index
+    this.state.thinkingBlocks.set(reasoningId, index)
+ 
this.state.currentThinkingId = reasoningId this.state.blocks.set(index, { type: 'thinking', index, @@ -330,15 +339,25 @@ export class AiSdkToAnthropicSSE { this.onEvent(event) } - private emitThinkingDelta(text: string): void { + private emitThinkingDelta(text: string, reasoningId?: string): void { if (!text) return - // Auto-start thinking block if not started - if (this.state.thinkingBlockIndex === null) { - this.startThinkingBlock() + // Determine which thinking block to use + const targetId = reasoningId || this.state.currentThinkingId + if (!targetId) { + // Auto-start thinking block if not started + const newId = `reasoning_${Date.now()}` + this.startThinkingBlock(newId) + return this.emitThinkingDelta(text, newId) + } + + const index = this.state.thinkingBlocks.get(targetId) + if (index === undefined) { + // If the block doesn't exist, create it + this.startThinkingBlock(targetId) + return this.emitThinkingDelta(text, targetId) } - const index = this.state.thinkingBlockIndex! const block = this.state.blocks.get(index) if (block) { block.content += text @@ -358,10 +377,12 @@ export class AiSdkToAnthropicSSE { this.onEvent(event) } - private stopThinkingBlock(): void { - if (this.state.thinkingBlockIndex === null) return + private stopThinkingBlock(reasoningId?: string): void { + const targetId = reasoningId || this.state.currentThinkingId + if (!targetId) return - const index = this.state.thinkingBlockIndex + const index = this.state.thinkingBlocks.get(targetId) + if (index === undefined) return const event: RawContentBlockStopEvent = { type: 'content_block_stop', @@ -369,7 +390,14 @@ export class AiSdkToAnthropicSSE { } this.onEvent(event) - this.state.thinkingBlockIndex = null + this.state.thinkingBlocks.delete(targetId) + + // Update currentThinkingId if we just closed the current one + if (this.state.currentThinkingId === targetId) { + // Set to the most recent remaining thinking block, or null if none + const remaining = Array.from(this.state.thinkingBlocks.keys()) + this.state.currentThinkingId = remaining.length > 0 ? 
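/* fall back to the most recently opened remaining block */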
remaining[remaining.length - 1] : null + } } private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void { @@ -471,13 +499,41 @@ export class AiSdkToAnthropicSSE { } } + private handleError(error: unknown): void { + // Log the error for debugging + logger.warn('AiSdkToAnthropicSSE - Provider error received:', { error }) + + // Extract error message + let errorMessage = 'Unknown error from provider' + if (error && typeof error === 'object') { + const err = error as { message?: string; metadata?: { raw?: string } } + if (err.metadata?.raw) { + errorMessage = `Provider error: ${err.metadata.raw}` + } else if (err.message) { + errorMessage = err.message + } + } else if (typeof error === 'string') { + errorMessage = error + } + + // Emit error as a text block so the user can see it + // First close any open thinking blocks to maintain proper event order + for (const reasoningId of Array.from(this.state.thinkingBlocks.keys())) { + this.stopThinkingBlock(reasoningId) + } + + // Emit the error as text + this.emitTextDelta(`\n\n[Error: ${errorMessage}]\n`) + } + private finalize(): void { // Close any open blocks if (this.state.textBlockIndex !== null) { this.stopTextBlock() } - if (this.state.thinkingBlockIndex !== null) { - this.stopThinkingBlock() + // Close all open thinking blocks + for (const reasoningId of this.state.thinkingBlocks.keys()) { + this.stopThinkingBlock(reasoningId) } // Emit message_delta with final stop reason and usage diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts index d0acd70231..0f71cfcfae 100644 --- a/src/main/apiServer/services/unified-messages.ts +++ b/src/main/apiServer/services/unified-messages.ts @@ -1,6 +1,11 @@ import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider' import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils' -import type { ImageBlockParam, MessageCreateParams, TextBlockParam } from '@anthropic-ai/sdk/resources/messages' +import type { + ImageBlockParam, + MessageCreateParams, + TextBlockParam, + Tool as AnthropicTool +} from '@anthropic-ai/sdk/resources/messages' import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' import { reduxService } from '@main/services/ReduxService' @@ -16,8 +21,8 @@ import { } from '@shared/provider' import { defaultAppHeaders } from '@shared/utils' import type { Provider } from '@types' -import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart } from 'ai' -import { stepCountIs, streamText } from 'ai' +import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai' +import { jsonSchema, stepCountIs, streamText, tool } from 'ai' import { net } from 'electron' import type { Response } from 'express' @@ -190,6 +195,39 @@ IANA media type. 
} }

+/**
+ * Convert Anthropic tools format to AI SDK tools format
+ */
+function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, Tool> | undefined {
+  if (!tools || tools.length === 0) {
+    return undefined
+  }
+
+  const aiSdkTools: Record<string, Tool> = {}
+
+  for (const anthropicTool of tools) {
+    // Handle different tool types
+    if (anthropicTool.type === 'bash_20250124') {
+      // Skip Anthropic-specific built-in tools such as bash - they have no AI SDK equivalent
+      continue
+    }
+
+    // Regular tool (type === 'custom' or no type)
+    const toolDef = anthropicTool as AnthropicTool
+    const parameters = toolDef.input_schema as Parameters<typeof jsonSchema>[0]
+
+    aiSdkTools[toolDef.name] = tool({
+      description: toolDef.description || '',
+      inputSchema: jsonSchema(parameters),
+      execute: async (input: Record<string, unknown>) => {
+        return input
+      }
+    })
+  }
+
+  return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
+}
+
 /**
  * Convert Anthropic MessageCreateParams to AI SDK message format
  */
@@ -271,6 +309,13 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
       }
     }

+    if (toolResultParts.length > 0) {
+      messages.push({
+        role: 'tool',
+        content: [...toolResultParts]
+      })
+    }
+
     // Build the message based on role
     if (msg.role === 'user') {
       messages.push({
         role: 'user',
         content: [...textParts, ...imageParts]
       })
     } else {
-      // Assistant messages can only have text
-      if (textParts.length > 0) {
-        messages.push({
-          role: 'assistant',
-          content: [...reasoningParts, ...textParts, ...toolCallParts, ...toolResultParts]
-        })
-      }
+      // Assistant messages contain tool calls, not tool results
+      messages.push({
+        role: 'assistant',
+        content: [...reasoningParts, ...textParts, ...toolCallParts]
+      })
     }
   }
 }
@@ -315,10 +358,29 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis

   const coreMessages = convertAnthropicToAiMessages(params)

+  // Convert tools if present
+  const tools = convertAnthropicToolsToAiSdk(params.tools)
+
   logger.debug('Converted messages', {
     originalCount: params.messages.length,
     convertedCount: coreMessages.length,
-    hasSystem: !!params.system
+    hasSystem: !!params.system,
+    hasTools: !!tools,
+    toolCount: tools ? Object.keys(tools).length : 0,
+    toolNames: tools ? Object.keys(tools).slice(0, 10) : [],
+    paramsToolCount: params.tools?.length || 0
+  })
+
+  // Debug: Log message structure to understand tool_result handling
+  logger.silly('Message structure for debugging', {
+    messages: coreMessages.map((m) => ({
+      role: m.role,
+      contentTypes: Array.isArray(m.content)
+        ? m.content.map((c: { type: string }) => c.type)
+        : typeof m.content === 'string'
+          ? 
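/* plain string content */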
['string'] + : ['unknown'] + })) }) // Create the adapter @@ -340,6 +402,7 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis stopSequences: params.stop_sequences, stopWhen: stepCountIs(100), headers: defaultAppHeaders(), + tools, providerOptions: {} }) @@ -404,8 +467,9 @@ export async function generateUnifiedMessage( // Create language model (async - uses @cherrystudio/ai-core) const model = await createLanguageModel(provider, modelId) - // Convert messages + // Convert messages and tools const coreMessages = convertAnthropicToAiMessages(params) + const tools = convertAnthropicToolsToAiSdk(params.tools) // Create adapter to collect the response let finalResponse: ReturnType | null = null @@ -425,6 +489,7 @@ export async function generateUnifiedMessage( topP: params.top_p, stopSequences: params.stop_sequences, headers: defaultAppHeaders(), + tools, stopWhen: stepCountIs(100) }) diff --git a/src/main/services/agents/services/claudecode/transform.ts b/src/main/services/agents/services/claudecode/transform.ts index 00be683ba8..fa0c615648 100644 --- a/src/main/services/agents/services/claudecode/transform.ts +++ b/src/main/services/agents/services/claudecode/transform.ts @@ -193,6 +193,30 @@ function handleAssistantMessage( } break } + case 'thinking': + case 'redacted_thinking': { + const thinkingText = block.type === 'thinking' ? block.thinking : block.data + if (thinkingText) { + const id = generateMessageId() + chunks.push({ + type: 'reasoning-start', + id, + providerMetadata + }) + chunks.push({ + type: 'reasoning-delta', + id, + text: thinkingText, + providerMetadata + }) + chunks.push({ + type: 'reasoning-end', + id, + providerMetadata + }) + } + break + } case 'tool_use': handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks) break @@ -445,7 +469,11 @@ function handleStreamEvent( case 'content_block_stop': { const block = state.closeBlock(event.index) if (!block) { - logger.warn('Received content_block_stop for unknown index', { index: event.index }) + // Some providers (e.g., Gemini) send content via assistant message before stream events, + // so the block may not exist in state. This is expected behavior, not an error. + logger.debug('Received content_block_stop for unknown index (may be from non-streaming content)', { + index: event.index + }) break } From ccfb9423e0df019984259814ad509e23ed40d697 Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 19:31:59 +0800 Subject: [PATCH 03/53] chore: format --- packages/shared/api/index.ts | 8 ++++++-- packages/shared/provider/config/helper.ts | 10 +++++----- packages/shared/provider/resolve.ts | 3 +-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/packages/shared/api/index.ts b/packages/shared/api/index.ts index 0cf652b427..5ee19611d8 100644 --- a/packages/shared/api/index.ts +++ b/packages/shared/api/index.ts @@ -53,12 +53,16 @@ export function formatAzureOpenAIApiHost(host: string): string { return formatApiHost(normalizedHost + '/openai', false) } -export function formatVertexApiHost(provider: MinimalProvider, project: string, location: string): string { +export function formatVertexApiHost( + provider: MinimalProvider, + project: string = 'test-project', + location: string = 'us-central1' +): string { const { apiHost } = provider const trimmedHost = withoutTrailingSlash(trim(apiHost)) if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) { const host = - location == 'global' ? 
'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
+      location === 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
     return `${formatApiHost(host)}/projects/${project}/locations/${location}`
   }
   return formatApiHost(trimmedHost)
diff --git a/packages/shared/provider/config/helper.ts b/packages/shared/provider/config/helper.ts
index 4e821a6c8f..95f53f885a 100644
--- a/packages/shared/provider/config/helper.ts
+++ b/packages/shared/provider/config/helper.ts
@@ -18,11 +18,11 @@ export const endpointIs =
  * @param provider the original provider object
  * @returns the resolved provider object
  */
-export function provider2Provider<
-  M extends MinimalModel,
-  R extends MinimalProvider,
-  P extends R = R
->(ruleSet: RuleSet, model: M, provider: P): P {
+export function provider2Provider<M extends MinimalModel, R extends MinimalProvider, P extends R = R>(
+  ruleSet: RuleSet,
+  model: M,
+  provider: P
+): P {
   for (const rule of ruleSet.rules) {
     if (rule.match(model)) {
       return rule.provider(provider) as P
diff --git a/packages/shared/provider/resolve.ts b/packages/shared/provider/resolve.ts
index 9055a36c6e..385da6a586 100644
--- a/packages/shared/provider/resolve.ts
+++ b/packages/shared/provider/resolve.ts
@@ -25,8 +25,7 @@ export function resolveActualProvider
Date: Thu, 27 Nov 2025 20:43:47 +0800
Subject: [PATCH 04/53] feat: Implement direct processing for Anthropic SDK and refactor message handling

---
 packages/shared/anthropic/index.ts            |  28 ++++-
 src/main/apiServer/routes/messages.ts         | 100 ++++++++++++++++--
 src/main/apiServer/services/messages.ts       |  24 ++++-
 .../apiServer/services/unified-messages.ts    |   5 +-
 4 files changed, 141 insertions(+), 16 deletions(-)

diff --git a/packages/shared/anthropic/index.ts b/packages/shared/anthropic/index.ts
index bff143d118..2444ad6113 100644
--- a/packages/shared/anthropic/index.ts
+++ b/packages/shared/anthropic/index.ts
@@ -16,6 +16,20 @@ import type { ModelMessage } from 'ai'

 const logger = loggerService.withContext('anthropic-sdk')

+/**
+ * Context for Anthropic SDK client creation.
+ * This allows the shared module to be used in different environments
+ * by providing environment-specific implementations.
+ */
+export interface AnthropicSdkContext {
+  /**
+   * Custom fetch function to use for HTTP requests.
+   * In the Electron main process, this should be `net.fetch`.
+   * In other environments, the default fetch or a custom implementation can be used.
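+   *
+   * @example
+   * // Minimal sketch, assuming an Electron main-process caller; the fetch
+   * // wrapper added to MessagesService below is preferred in practice.
+   * const client = getSdkClient(provider, null, undefined, { fetch: net.fetch })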
+ */ + fetch?: typeof globalThis.fetch +} + const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.` const defaultClaudeCodeSystem: Array = [ @@ -58,8 +72,11 @@ const defaultClaudeCodeSystem: Array = [ export function getSdkClient( provider: Provider, oauthToken?: string | null, - extraHeaders?: Record + extraHeaders?: Record, + context?: AnthropicSdkContext ): Anthropic { + const customFetch = context?.fetch + if (provider.authType === 'oauth') { if (!oauthToken) { throw new Error('OAuth token is not available') @@ -85,7 +102,8 @@ export function getSdkClient( 'x-stainless-runtime': 'node', 'x-stainless-runtime-version': 'v22.18.0', ...extraHeaders - } + }, + fetch: customFetch }) } let baseURL = @@ -110,7 +128,8 @@ export function getSdkClient( 'APP-Code': 'MLTG2087', ...provider.extra_headers, ...extraHeaders - } + }, + fetch: customFetch }) } @@ -122,7 +141,8 @@ export function getSdkClient( defaultHeaders: { 'anthropic-beta': 'output-128k-2025-02-19', ...provider.extra_headers - } + }, + fetch: customFetch }) } diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts index 1ce42c46ea..f2590cf1d5 100644 --- a/src/main/apiServer/routes/messages.ts +++ b/src/main/apiServer/routes/messages.ts @@ -8,6 +8,17 @@ import { messagesService } from '../services/messages' import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages' import { getProviderById, validateModelId } from '../utils' +/** + * Check if provider should use direct Anthropic SDK + * + * A provider is considered "Anthropic-compatible" if: + * 1. It's a native Anthropic provider (type === 'anthropic'), OR + * 2. It has anthropicApiHost configured (aggregated providers routing to Anthropic-compatible endpoints) + */ +function shouldUseDirectAnthropic(provider: Provider): boolean { + return provider.type === 'anthropic' || !!(provider.anthropicApiHost && provider.anthropicApiHost.trim()) +} + const logger = loggerService.withContext('ApiServerMessagesRoutes') const router = express.Router() @@ -41,12 +52,70 @@ interface HandleMessageProcessingOptions { } /** - * Handle message processing using unified AI SDK - * All providers (including Anthropic) are handled through AI SDK: - * - Anthropic providers use @ai-sdk/anthropic which outputs native Anthropic SSE - * - Other providers use their respective AI SDK adapters, with output converted to Anthropic SSE + * Handle message processing using direct Anthropic SDK + * Used for providers with anthropicApiHost or native Anthropic providers + * This bypasses AI SDK conversion and uses native Anthropic protocol */ -async function handleMessageProcessing({ +async function handleDirectAnthropicProcessing({ + res, + provider, + request, + modelId, + extraHeaders +}: HandleMessageProcessingOptions & { extraHeaders?: Record }): Promise { + const actualModelId = modelId || request.model + + logger.info('Processing message via direct Anthropic SDK', { + providerId: provider.id, + providerType: provider.type, + modelId: actualModelId, + stream: !!request.stream, + anthropicApiHost: provider.anthropicApiHost + }) + + try { + // Validate request + const validation = messagesService.validateRequest(request) + if (!validation.isValid) { + res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: validation.errors.join('; ') + } + }) + return + } + + // Process message using messagesService (native Anthropic SDK) + const { client, anthropicRequest } = await 
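/* direct path: build the native Anthropic client and request, no AI SDK conversion */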
messagesService.processMessage({ + provider, + request, + extraHeaders, + modelId: actualModelId + }) + + if (request.stream) { + // Use native Anthropic streaming + await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider) + } else { + // Use native Anthropic non-streaming + const response = await client.messages.create(anthropicRequest) + res.json(response) + } + } catch (error: any) { + logger.error('Direct Anthropic processing error', { error }) + const { statusCode, errorResponse } = messagesService.transformError(error) + res.status(statusCode).json(errorResponse) + } +} + +/** + * Handle message processing using unified AI SDK + * Used for non-Anthropic providers that need format conversion + * - Uses AI SDK adapters with output converted to Anthropic SSE format + */ +async function handleUnifiedProcessing({ res, provider, request, @@ -93,12 +162,31 @@ async function handleMessageProcessing({ res.json(response) } } catch (error: any) { - logger.error('Message processing error', { error }) + logger.error('Unified processing error', { error }) const { statusCode, errorResponse } = messagesService.transformError(error) res.status(statusCode).json(errorResponse) } } +/** + * Handle message processing - routes to appropriate handler based on provider + * + * Routing logic: + * - Providers with anthropicApiHost OR type 'anthropic': Direct Anthropic SDK (no conversion) + * - Other providers: Unified AI SDK with Anthropic SSE conversion + */ +async function handleMessageProcessing({ + res, + provider, + request, + modelId +}: HandleMessageProcessingOptions): Promise { + if (shouldUseDirectAnthropic(provider)) { + return handleDirectAnthropicProcessing({ res, provider, request, modelId }) + } + return handleUnifiedProcessing({ res, provider, request, modelId }) +} + /** * @swagger * /v1/messages: diff --git a/src/main/apiServer/services/messages.ts b/src/main/apiServer/services/messages.ts index 8b46deaa8f..e3fbd069a7 100644 --- a/src/main/apiServer/services/messages.ts +++ b/src/main/apiServer/services/messages.ts @@ -4,6 +4,7 @@ import { loggerService } from '@logger' import anthropicService from '@main/services/AnthropicService' import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic' import type { Provider } from '@types' +import { net } from 'electron' import type { Response } from 'express' const logger = loggerService.withContext('MessagesService') @@ -98,11 +99,30 @@ export class MessagesService { async getClient(provider: Provider, extraHeaders?: Record): Promise { // Create Anthropic client for the provider + // Wrap net.fetch to handle compatibility issues: + // 1. net.fetch expects string URLs, not Request objects + // 2. net.fetch doesn't support 'agent' option from Node.js http module + const electronFetch: typeof globalThis.fetch = async (input: URL | RequestInfo, init?: RequestInit) => { + const url = typeof input === 'string' ? input : input instanceof URL ? 
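/* URL instance: serialize to string */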
input.toString() : input.url
+      // Remove unsupported options for Electron's net.fetch
+      if (init) {
+        const initWithAgent = init as RequestInit & { agent?: unknown }
+        delete initWithAgent.agent
+        const headers = new Headers(initWithAgent.headers)
+        if (headers.has('content-length')) {
+          headers.delete('content-length')
+        }
+        initWithAgent.headers = headers
+        return net.fetch(url, initWithAgent)
+      }
+      return net.fetch(url)
+    }
+    const context = { fetch: electronFetch }
     if (provider.authType === 'oauth') {
       const oauthToken = await anthropicService.getValidAccessToken()
-      return getSdkClient(provider, oauthToken, extraHeaders)
+      return getSdkClient(provider, oauthToken, extraHeaders, context)
     }
-    return getSdkClient(provider, null, extraHeaders)
+    return getSdkClient(provider, null, extraHeaders, context)
   }

   prepareHeaders(headers: Record<string, string>): Record<string, string> {
diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index 0f71cfcfae..5aadfbf534 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -42,10 +42,6 @@ export interface UnifiedStreamConfig {
   onComplete?: () => void
 }

-// ============================================================================
-// Provider Factory
-// ============================================================================
-
 /**
  * Main process format context for formatProviderApiHost
  * Unlike renderer, main process doesn't have direct access to store getters, so use reduxService cache
  */
@@ -338,6 +334,7 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
 /**
  * Stream a message request using AI SDK and convert to Anthropic SSE format
  */
+// TODO: integrate middleware and transform streams here via the ai-core executor
 export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
   const { response, provider, modelId, params, onError, onComplete } = config

From 36ed062b847279c2e2eda844e78b925e8250cdc2 Mon Sep 17 00:00:00 2001
From: suyao
Date: Thu, 27 Nov 2025 20:55:36 +0800
Subject: [PATCH 05/53] fix: test

---
 packages/shared/config/providers.ts           |  5 ---
 src/main/apiServer/routes/messages.ts         | 39 ++++++++++++++-----
 src/main/apiServer/utils/index.ts             | 29 ++++++++++++++
 .../provider/__tests__/providerConfig.test.ts | 30 ++++++++++++--
 src/renderer/src/utils/__tests__/api.test.ts  | 13 +------
 5 files changed, 86 insertions(+), 30 deletions(-)

diff --git a/packages/shared/config/providers.ts b/packages/shared/config/providers.ts
index f7744150e2..e03661bf0e 100644
--- a/packages/shared/config/providers.ts
+++ b/packages/shared/config/providers.ts
@@ -41,8 +41,3 @@ const SILICON_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(SILICON_ANTHROPIC_COMPATI
 export function isSiliconAnthropicCompatibleModel(modelId: string): boolean {
   return SILICON_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId)
 }
-
-/**
- * Silicon provider's Anthropic API host URL. 
- */ -export const SILICON_ANTHROPIC_API_HOST = 'https://api.siliconflow.cn' diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts index f2590cf1d5..f0eaac8e4e 100644 --- a/src/main/apiServer/routes/messages.ts +++ b/src/main/apiServer/routes/messages.ts @@ -6,17 +6,34 @@ import express from 'express' import { messagesService } from '../services/messages' import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages' -import { getProviderById, validateModelId } from '../utils' +import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils' /** - * Check if provider should use direct Anthropic SDK + * Check if a specific model on a provider should use direct Anthropic SDK * - * A provider is considered "Anthropic-compatible" if: + * A provider+model combination is considered "Anthropic-compatible" if: * 1. It's a native Anthropic provider (type === 'anthropic'), OR - * 2. It has anthropicApiHost configured (aggregated providers routing to Anthropic-compatible endpoints) + * 2. It has anthropicApiHost configured AND the specific model supports Anthropic API + * (for aggregated providers like Silicon, only certain models support Anthropic endpoint) + * + * @param provider - The provider to check + * @param modelId - The model ID to check (without provider prefix) + * @returns true if should use direct Anthropic SDK, false for unified SDK */ -function shouldUseDirectAnthropic(provider: Provider): boolean { - return provider.type === 'anthropic' || !!(provider.anthropicApiHost && provider.anthropicApiHost.trim()) +function shouldUseDirectAnthropic(provider: Provider, modelId: string): boolean { + // Native Anthropic provider - always use direct SDK + if (provider.type === 'anthropic') { + return true + } + + // No anthropicApiHost configured - use unified SDK + if (!provider.anthropicApiHost?.trim()) { + return false + } + + // Has anthropicApiHost - check model-level compatibility + // For aggregated providers, only specific models support Anthropic API + return isModelAnthropicCompatible(provider, modelId) } const logger = loggerService.withContext('ApiServerMessagesRoutes') @@ -169,11 +186,12 @@ async function handleUnifiedProcessing({ } /** - * Handle message processing - routes to appropriate handler based on provider + * Handle message processing - routes to appropriate handler based on provider and model * * Routing logic: - * - Providers with anthropicApiHost OR type 'anthropic': Direct Anthropic SDK (no conversion) - * - Other providers: Unified AI SDK with Anthropic SSE conversion + * - Native Anthropic providers (type === 'anthropic'): Direct Anthropic SDK + * - Providers with anthropicApiHost AND model supports Anthropic API: Direct Anthropic SDK + * - Other providers/models: Unified AI SDK with Anthropic SSE conversion */ async function handleMessageProcessing({ res, @@ -181,7 +199,8 @@ async function handleMessageProcessing({ request, modelId }: HandleMessageProcessingOptions): Promise { - if (shouldUseDirectAnthropic(provider)) { + const actualModelId = modelId || request.model + if (shouldUseDirectAnthropic(provider, actualModelId)) { return handleDirectAnthropicProcessing({ res, provider, request, modelId }) } return handleUnifiedProcessing({ res, provider, request, modelId }) diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index e25b49e750..471e734c18 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -295,3 
+295,32 @@ export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model return () => true } } + +/** + * Check if a specific model is compatible with Anthropic API for a given provider. + * + * This is used for fine-grained routing decisions at the model level. + * For aggregated providers (like Silicon), only certain models support the Anthropic API endpoint. + * + * @param provider - The provider to check + * @param modelId - The model ID to check (without provider prefix) + * @returns true if the model supports Anthropic API endpoint + */ +export function isModelAnthropicCompatible(provider: Provider, modelId: string): boolean { + const checker = getProviderAnthropicModelChecker(provider.id) + + const model = provider.models?.find((m) => m.id === modelId) + + if (model) { + return checker(model) + } + + const minimalModel: Model = { + id: modelId, + name: modelId, + provider: provider.id, + group: '' + } + + return checker(minimalModel) +} diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts index 430ff52869..22ef654da8 100644 --- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts +++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts @@ -24,7 +24,17 @@ vi.mock('@renderer/services/AssistantService', () => ({ vi.mock('@renderer/store', () => ({ default: { - getState: () => ({ copilot: { defaultHeaders: {} } }) + getState: () => ({ + copilot: { defaultHeaders: {} }, + llm: { + settings: { + vertexai: { + projectId: 'test-project', + location: 'us-central1' + } + } + } + }) } })) @@ -33,7 +43,7 @@ vi.mock('@renderer/utils/api', () => ({ if (isSupportedAPIVersion === false) { return host // Return host as-is when isSupportedAPIVersion is false } - return `${host}/v1` // Default behavior when isSupportedAPIVersion is true + return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true }), routeToEndpoint: vi.fn((host) => ({ baseURL: host, @@ -41,6 +51,20 @@ vi.mock('@renderer/utils/api', () => ({ })) })) +// Also mock @shared/api since formatProviderApiHost uses it directly +vi.mock('@shared/api', async (importOriginal) => { + const actual = (await importOriginal()) as any + return { + ...actual, + formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => { + if (isSupportedAPIVersion === false) { + return host || '' // Return host as-is when isSupportedAPIVersion is false + } + return host ? 
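/* mirror the real formatApiHost: append v1 only for non-empty hosts */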
`${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true + }) + } +}) + vi.mock('@renderer/utils/provider', async (importOriginal) => { const actual = (await importOriginal()) as any return { @@ -73,8 +97,8 @@ vi.mock('@renderer/services/AssistantService', () => ({ import { getProviderByModel } from '@renderer/services/AssistantService' import type { Model, Provider } from '@renderer/types' -import { formatApiHost } from '@renderer/utils/api' import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' +import { formatApiHost } from '@shared/api' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' import { getActualProvider, providerToAiSdkConfig } from '../providerConfig' diff --git a/src/renderer/src/utils/__tests__/api.test.ts b/src/renderer/src/utils/__tests__/api.test.ts index e854445fc5..f56fb53d00 100644 --- a/src/renderer/src/utils/__tests__/api.test.ts +++ b/src/renderer/src/utils/__tests__/api.test.ts @@ -300,18 +300,7 @@ describe('api', () => { }) it('uses global endpoint when location equals global', () => { - getStateMock.mockReturnValueOnce({ - llm: { - settings: { - vertexai: { - projectId: 'global-project', - location: 'global' - } - } - } - }) - - expect(formatVertexApiHost(createVertexProvider(''))).toBe( + expect(formatVertexApiHost(createVertexProvider(''), 'global-project', 'global')).toBe( 'https://aiplatform.googleapis.com/v1/projects/global-project/locations/global' ) }) From 2a1adfe3224d23874fbd6875c9d998cae54a6bd3 Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 21:00:52 +0800 Subject: [PATCH 06/53] feat: add ppio --- packages/shared/config/providers.ts | 34 +++++++++++++++++++ src/main/apiServer/utils/index.ts | 4 ++- src/renderer/src/pages/code/CodeToolsPage.tsx | 6 ++-- .../ProviderSettings/ProviderSetting.tsx | 3 +- src/renderer/src/store/index.ts | 2 +- src/renderer/src/store/migrate.ts | 14 ++++++++ 6 files changed, 58 insertions(+), 5 deletions(-) diff --git a/packages/shared/config/providers.ts b/packages/shared/config/providers.ts index e03661bf0e..6490c61cc8 100644 --- a/packages/shared/config/providers.ts +++ b/packages/shared/config/providers.ts @@ -41,3 +41,37 @@ const SILICON_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(SILICON_ANTHROPIC_COMPATI export function isSiliconAnthropicCompatibleModel(modelId: string): boolean { return SILICON_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId) } + +/** + * PPIO provider models that support Anthropic API endpoint. + * These models can be used with Claude Code via the Anthropic-compatible API. + * + * @see https://ppio.com/docs/model/llm-anthropic-compatibility + */ +export const PPIO_ANTHROPIC_COMPATIBLE_MODELS: readonly string[] = [ + 'moonshotai/kimi-k2-thinking', + 'minimax/minimax-m2', + 'deepseek/deepseek-v3.2-exp', + 'deepseek/deepseek-v3.1-terminus', + 'zai-org/glm-4.6', + 'moonshotai/kimi-k2-0905', + 'deepseek/deepseek-v3.1', + 'moonshotai/kimi-k2-instruct', + 'qwen/qwen3-next-80b-a3b-instruct', + 'qwen/qwen3-next-80b-a3b-thinking' +] + +/** + * Creates a Set for efficient lookup of PPIO Anthropic-compatible model IDs. + */ +const PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(PPIO_ANTHROPIC_COMPATIBLE_MODELS) + +/** + * Checks if a model ID is compatible with Anthropic API on PPIO provider. 
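+ *
+ * @example
+ * isPpioAnthropicCompatibleModel('zai-org/glm-4.6') // true, listed above
+ * isPpioAnthropicCompatibleModel('unlisted/model') // false (illustrative ID)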
+ * + * @param modelId - The model ID to check + * @returns true if the model supports Anthropic API endpoint + */ +export function isPpioAnthropicCompatibleModel(modelId: string): boolean { + return PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId) +} diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index 471e734c18..fde1ff3475 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -1,7 +1,7 @@ import { CacheService } from '@main/services/CacheService' import { loggerService } from '@main/services/LoggerService' import { reduxService } from '@main/services/ReduxService' -import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers' +import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers' import type { ApiModel, Model, Provider } from '@types' const logger = loggerService.withContext('ApiServerUtils') @@ -290,6 +290,8 @@ export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model return (m: Model) => m.id.includes('claude') case 'silicon': return (m: Model) => isSiliconAnthropicCompatibleModel(m.id) + case 'ppio': + return (m: Model) => isPpioAnthropicCompatibleModel(m.id) default: // allow all models when checker not configured return () => true diff --git a/src/renderer/src/pages/code/CodeToolsPage.tsx b/src/renderer/src/pages/code/CodeToolsPage.tsx index fcb2dbf482..a4314dfef9 100644 --- a/src/renderer/src/pages/code/CodeToolsPage.tsx +++ b/src/renderer/src/pages/code/CodeToolsPage.tsx @@ -17,7 +17,7 @@ import type { EndpointType, Model } from '@renderer/types' import { getClaudeSupportedProviders } from '@renderer/utils/provider' import type { TerminalConfig } from '@shared/config/constant' import { codeTools, terminalApps } from '@shared/config/constant' -import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers' +import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers' import { Alert, Avatar, Button, Checkbox, Input, Popover, Select, Space, Tooltip } from 'antd' import { ArrowUpRight, Download, FolderOpen, HelpCircle, Terminal, X } from 'lucide-react' import type { FC } from 'react' @@ -82,10 +82,12 @@ const CodeToolsPage: FC = () => { if (m.supported_endpoint_types) { return m.supported_endpoint_types.includes('anthropic') } - // Special handling for silicon provider: only specific models support Anthropic API if (m.provider === 'silicon') { return isSiliconAnthropicCompatibleModel(m.id) } + if (m.provider === 'ppio') { + return isPpioAnthropicCompatibleModel(m.id) + } return m.id.includes('claude') || CLAUDE_OFFICIAL_SUPPORTED_PROVIDERS.includes(m.provider) } diff --git a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx index f341ac9229..b85690c3fb 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx @@ -85,7 +85,8 @@ const ANTHROPIC_COMPATIBLE_PROVIDER_IDS = [ SystemProviderIds.minimax, SystemProviderIds.silicon, SystemProviderIds.qiniu, - SystemProviderIds.dmxapi + SystemProviderIds.dmxapi, + SystemProviderIds.ppio ] as const type AnthropicCompatibleProviderId = (typeof ANTHROPIC_COMPATIBLE_PROVIDER_IDS)[number] diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index 5c562885bb..94b51474b9 100644 --- a/src/renderer/src/store/index.ts +++ 
b/src/renderer/src/store/index.ts @@ -67,7 +67,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 179, + version: 180, blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs', 'toolPermissions'], migrate }, diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 4b2e4cef89..1049da5964 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -2906,6 +2906,20 @@ const migrateConfig = { logger.error('migrate 179 error', error as Error) return state } + }, + '180': (state: RootState) => { + try { + state.llm.providers.forEach((provider) => { + if (provider.id === SystemProviderIds.ppio) { + provider.anthropicApiHost = 'https://api.ppinfra.com/anthropic' + } + }) + logger.info('migrate 180 success') + return state + } catch (error) { + logger.error('migrate 180 error', error as Error) + return state + } } } From 4c4102da20b14bd75615f75bcd8690ee66b1a580 Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 21:10:26 +0800 Subject: [PATCH 07/53] feat: update agentModelFilter to exclude generate image models --- src/renderer/src/config/models/utils.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/renderer/src/config/models/utils.ts b/src/renderer/src/config/models/utils.ts index 1d5c9a6443..8a52d6e4ff 100644 --- a/src/renderer/src/config/models/utils.ts +++ b/src/renderer/src/config/models/utils.ts @@ -150,7 +150,7 @@ export const isGeminiModel = (model: Model) => { export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const export const agentModelFilter = (model: Model): boolean => { - return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model) + return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model) && !isGenerateImageModel(model) } export const isMaxTemperatureOneModel = (model: Model): boolean => { From f02c0fe9629b5a3b623e50ff26214405f68696aa Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 21:19:01 +0800 Subject: [PATCH 08/53] fix: typecheck --- tsconfig.node.json | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tsconfig.node.json b/tsconfig.node.json index 83c3f2b461..4e60782e11 100644 --- a/tsconfig.node.json +++ b/tsconfig.node.json @@ -26,7 +26,11 @@ "@types": ["./src/renderer/src/types/index.ts"], "@shared/*": ["./packages/shared/*"], "@mcp-trace/*": ["./packages/mcp-trace/*"], - "@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"] + "@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"], + "@cherrystudio/ai-core/provider": ["./packages/aiCore/src/core/providers/index.ts"], + "@cherrystudio/ai-core/built-in/plugins": ["./packages/aiCore/src/core/plugins/built-in/index.ts"], + "@cherrystudio/ai-core/*": ["./packages/aiCore/src/*"], + "@cherrystudio/ai-core": ["./packages/aiCore/src/index.ts"] }, "experimentalDecorators": true, "emitDecoratorMetadata": true, From dad9cc95ad038d518b520705035cd9336cd715d7 Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 21:23:28 +0800 Subject: [PATCH 09/53] fix: typecheck --- tsconfig.node.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tsconfig.node.json b/tsconfig.node.json index 4e60782e11..9871e604f2 100644 --- a/tsconfig.node.json +++ b/tsconfig.node.json @@ -30,7 +30,8 @@ "@cherrystudio/ai-core/provider": ["./packages/aiCore/src/core/providers/index.ts"], "@cherrystudio/ai-core/built-in/plugins": 
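/* source-level alias so main-process typechecking resolves the package subpath */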
["./packages/aiCore/src/core/plugins/built-in/index.ts"], "@cherrystudio/ai-core/*": ["./packages/aiCore/src/*"], - "@cherrystudio/ai-core": ["./packages/aiCore/src/index.ts"] + "@cherrystudio/ai-core": ["./packages/aiCore/src/index.ts"], + "@cherrystudio/ai-sdk-provider": ["./packages/ai-sdk-provider/src/index.ts"] }, "experimentalDecorators": true, "emitDecoratorMetadata": true, From 15c0a3881c4896acee0a672284acbf5593c5aba0 Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 21:28:24 +0800 Subject: [PATCH 10/53] fix: test --- src/renderer/src/config/models/__tests__/utils.test.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/renderer/src/config/models/__tests__/utils.test.ts b/src/renderer/src/config/models/__tests__/utils.test.ts index f3f4d402af..97dbf755d1 100644 --- a/src/renderer/src/config/models/__tests__/utils.test.ts +++ b/src/renderer/src/config/models/__tests__/utils.test.ts @@ -120,7 +120,7 @@ describe('model utils', () => { rerankMock.mockReturnValue(false) visionMock.mockReturnValue(true) textToImageMock.mockReturnValue(false) - generateImageMock.mockReturnValue(true) + generateImageMock.mockReturnValue(false) reasoningMock.mockReturnValue(false) openAIWebSearchOnlyMock.mockReturnValue(false) }) @@ -274,6 +274,7 @@ describe('model utils', () => { visionMock.mockReturnValueOnce(true).mockReturnValueOnce(false) expect(isVisionModels(models)).toBe(false) + generateImageMock.mockReturnValue(true) expect(isGenerateImageModels(models)).toBe(true) generateImageMock.mockReturnValueOnce(true).mockReturnValueOnce(false) expect(isGenerateImageModels(models)).toBe(false) @@ -292,6 +293,10 @@ describe('model utils', () => { rerankMock.mockReturnValue(false) textToImageMock.mockReturnValueOnce(true) expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false) + + textToImageMock.mockReturnValue(false) + generateImageMock.mockReturnValueOnce(true) + expect(agentModelFilter(createModel({ id: 'dall-e-3' }))).toBe(false) }) it('identifies models with maximum temperature of 1.0', () => { From 0f6ec3e0614c59e3d411463c637e8ea92b312de0 Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 22:02:50 +0800 Subject: [PATCH 11/53] feat: add new aliases for ai-core provider and core --- electron.vite.config.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/electron.vite.config.ts b/electron.vite.config.ts index 172d48ca9a..761ecfbf15 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -25,7 +25,9 @@ export default defineConfig({ '@shared': resolve('packages/shared'), '@logger': resolve('src/main/services/LoggerService'), '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'), - '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node') + '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'), + '@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'), + '@cherrystudio/ai-core': resolve('packages/aiCore/src') } }, build: { From f163c4d3ee6b336e30b1516511cdc745c44fec05 Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 22:28:44 +0800 Subject: [PATCH 12/53] fix: resolve PR review issues for Proxy API Server MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix tool result content bug: return `values` array instead of empty array - Fix empty message bug: skip pushing user/assistant messages when content is empty - Expand provider support: remove type restrictions to support all AI SDK providers - Add missing 
alias for @cherrystudio/ai-sdk-provider in main process config 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- electron.vite.config.ts | 3 ++- .../apiServer/services/unified-messages.ts | 25 ++++++++++++------- src/main/apiServer/utils/index.ts | 19 +++++--------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/electron.vite.config.ts b/electron.vite.config.ts index 761ecfbf15..da471c9fc9 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -27,7 +27,8 @@ export default defineConfig({ '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'), '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'), '@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'), - '@cherrystudio/ai-core': resolve('packages/aiCore/src') + '@cherrystudio/ai-core': resolve('packages/aiCore/src'), + '@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src') } }, build: { diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts index 5aadfbf534..ddb6d59b37 100644 --- a/src/main/apiServer/services/unified-messages.ts +++ b/src/main/apiServer/services/unified-messages.ts @@ -186,7 +186,7 @@ IANA media type. } return { type: 'content', - value: [] + value: values } } } @@ -313,17 +313,24 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage } // Build the message based on role + // Only push user/assistant message if there's actual content (avoid empty messages) if (msg.role === 'user') { - messages.push({ - role: 'user', - content: [...textParts, ...imageParts] - }) + const userContent = [...textParts, ...imageParts] + if (userContent.length > 0) { + messages.push({ + role: 'user', + content: userContent + }) + } } else { // Assistant messages contain tool calls, not tool results - messages.push({ - role: 'assistant', - content: [...reasoningParts, ...textParts, ...toolCallParts] - }) + const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts] + if (assistantContent.length > 0) { + messages.push({ + role: 'assistant', + content: assistantContent + }) + } } } } diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index fde1ff3475..17d3f9f088 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -28,10 +28,9 @@ export async function getAvailableProviders(): Promise { return [] } - // Support OpenAI and Anthropic type providers for API server - const supportedProviders = providers.filter( - (p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic') - ) + // Support all provider types that AI SDK can handle + // The unified-messages service uses AI SDK which supports many providers + const supportedProviders = providers.filter((p: Provider) => p.enabled) // Cache the filtered results CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL) @@ -160,7 +159,7 @@ export async function validateModelId(model: string): Promise<{ valid: false, error: { type: 'provider_not_found', - message: `Provider '${providerId}' not found, not enabled, or not supported. 
Only OpenAI providers are currently supported.`, + message: `Provider '${providerId}' not found or not enabled.`, code: 'provider_not_found' } } @@ -262,14 +261,8 @@ export function validateProvider(provider: Provider): boolean { return false } - // Support OpenAI and Anthropic type providers - if (provider.type !== 'openai' && provider.type !== 'anthropic') { - logger.debug('Provider type not supported', { - providerId: provider.id, - providerType: provider.type - }) - return false - } + // AI SDK supports many provider types, no longer need to filter by type + // The unified-messages service handles all supported types return true } catch (error: any) { From 77c1b77113f4280454fd22869bd98f6516864f30 Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 27 Nov 2025 22:41:18 +0800 Subject: [PATCH 13/53] refactor: extract shared token counting logic in messages routes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract duplicated token estimation code from both count_tokens endpoints into a shared `estimateTokenCount` function to improve maintainability. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/main/apiServer/routes/messages.ts | 111 +++++++++++--------------- 1 file changed, 47 insertions(+), 64 deletions(-) diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts index f0eaac8e4e..907b498273 100644 --- a/src/main/apiServer/routes/messages.ts +++ b/src/main/apiServer/routes/messages.ts @@ -41,6 +41,51 @@ const logger = loggerService.withContext('ApiServerMessagesRoutes') const router = express.Router() const providerRouter = express.Router({ mergeParams: true }) +/** + * Estimate token count from messages + * Simple approximation: ~4 characters per token for English text + */ +interface CountTokensInput { + messages: Array<{ role: string; content: string | Array<{ type: string; text?: string }> }> + system?: string | Array<{ type: string; text?: string }> +} + +function estimateTokenCount(input: CountTokensInput): number { + const { messages, system } = input + let totalChars = 0 + + // Count system message tokens + if (system) { + if (typeof system === 'string') { + totalChars += system.length + } else if (Array.isArray(system)) { + for (const block of system) { + if (block.type === 'text' && block.text) { + totalChars += block.text.length + } + } + } + } + + // Count message tokens + for (const msg of messages) { + if (typeof msg.content === 'string') { + totalChars += msg.content.length + } else if (Array.isArray(msg.content)) { + for (const block of msg.content) { + if (block.type === 'text' && block.text) { + totalChars += block.text.length + } + } + } + // Add overhead for role + totalChars += 10 + } + + // Estimate tokens (~4 chars per token, with some overhead) + return Math.ceil(totalChars / 4) + messages.length * 3 +} + // Helper function for basic request validation async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> { const request: MessageCreateParams = req.body @@ -589,45 +634,11 @@ router.post('/count_tokens', async (req: Request, res: Response) => { }) } - // Simple token estimation based on character count - // This is a rough approximation: ~4 characters per token for English text - let totalChars = 0 - - // Count system message tokens - if (system) { - if (typeof system === 'string') { - totalChars += system.length - } else if (Array.isArray(system)) { - for (const block of system) { - if (block.type === 
'text' && block.text) { - totalChars += block.text.length - } - } - } - } - - // Count message tokens - for (const msg of messages) { - if (typeof msg.content === 'string') { - totalChars += msg.content.length - } else if (Array.isArray(msg.content)) { - for (const block of msg.content) { - if (block.type === 'text' && block.text) { - totalChars += block.text.length - } - } - } - // Add overhead for role - totalChars += 10 - } - - // Estimate tokens (~4 chars per token, with some overhead) - const estimatedTokens = Math.ceil(totalChars / 4) + messages.length * 3 + const estimatedTokens = estimateTokenCount({ messages, system }) logger.debug('Token count estimated', { model, messageCount: messages.length, - totalChars, estimatedTokens }) @@ -663,35 +674,7 @@ providerRouter.post('/count_tokens', async (req: Request, res: Response) => { }) } - // Simple token estimation - let totalChars = 0 - - if (system) { - if (typeof system === 'string') { - totalChars += system.length - } else if (Array.isArray(system)) { - for (const block of system) { - if (block.type === 'text' && block.text) { - totalChars += block.text.length - } - } - } - } - - for (const msg of messages) { - if (typeof msg.content === 'string') { - totalChars += msg.content.length - } else if (Array.isArray(msg.content)) { - for (const block of msg.content) { - if (block.type === 'text' && block.text) { - totalChars += block.text.length - } - } - } - totalChars += 10 - } - - const estimatedTokens = Math.ceil(totalChars / 4) + messages.length * 3 + const estimatedTokens = estimateTokenCount({ messages, system }) logger.debug('Token count estimated (provider route)', { providerId: req.params.provider, From ce2500159041bc31ac9dee8e5d27faf4a80d578e Mon Sep 17 00:00:00 2001 From: suyao Date: Fri, 28 Nov 2025 01:27:20 +0800 Subject: [PATCH 14/53] feat: add shared AI SDK middlewares and refactor middleware handling --- packages/shared/middleware/index.ts | 15 + packages/shared/middleware/middlewares.ts | 205 ++++++++ .../apiServer/services/unified-messages.ts | 461 +++++++----------- .../middleware/AiSdkMiddlewareBuilder.ts | 3 +- .../openrouterReasoningMiddleware.ts | 50 -- .../skipGeminiThoughtSignatureMiddleware.ts | 36 -- tsconfig.node.json | 4 +- 7 files changed, 401 insertions(+), 373 deletions(-) create mode 100644 packages/shared/middleware/index.ts create mode 100644 packages/shared/middleware/middlewares.ts delete mode 100644 src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts delete mode 100644 src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts diff --git a/packages/shared/middleware/index.ts b/packages/shared/middleware/index.ts new file mode 100644 index 0000000000..a4db5ad2dd --- /dev/null +++ b/packages/shared/middleware/index.ts @@ -0,0 +1,15 @@ +/** + * Shared AI SDK Middlewares + * + * Environment-agnostic middlewares that can be used in both + * renderer process and main process (API server). 
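+ *
+ * Minimal usage sketch (mirrors the example in middlewares.ts; only part of
+ * the option set is shown):
+ *
+ * ```typescript
+ * import { buildSharedMiddlewares } from '@shared/middleware'
+ *
+ * const middlewares = buildSharedMiddlewares({
+ *   enableReasoning: true,
+ *   modelId: 'gemini-2.5-pro'
+ * })
+ * ```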
+ */ + +export { + buildSharedMiddlewares, + getReasoningTagName, + isGemini3ModelId, + openrouterReasoningMiddleware, + type SharedMiddlewareConfig, + skipGeminiThoughtSignatureMiddleware +} from './middlewares' diff --git a/packages/shared/middleware/middlewares.ts b/packages/shared/middleware/middlewares.ts new file mode 100644 index 0000000000..d9725101c2 --- /dev/null +++ b/packages/shared/middleware/middlewares.ts @@ -0,0 +1,205 @@ +/** + * Shared AI SDK Middlewares + * + * These middlewares are environment-agnostic and can be used in both + * renderer process and main process (API server). + */ +import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider' +import { extractReasoningMiddleware } from 'ai' + +/** + * Configuration for building shared middlewares + */ +export interface SharedMiddlewareConfig { + /** + * Whether to enable reasoning extraction + */ + enableReasoning?: boolean + + /** + * Tag name for reasoning extraction + * Defaults based on model ID + */ + reasoningTagName?: string + + /** + * Model ID - used to determine default reasoning tag and model detection + */ + modelId?: string + + /** + * Provider ID (Cherry Studio provider ID) + * Used for provider-specific middlewares like OpenRouter + */ + providerId?: string + + /** + * AI SDK Provider ID + * Used for Gemini thought signature middleware + * e.g., 'google', 'google-vertex' + */ + aiSdkProviderId?: string +} + +/** + * Check if model ID represents a Gemini 3 (2.5) model + * that requires thought signature handling + * + * @param modelId - The model ID string (not Model object) + */ +export function isGemini3ModelId(modelId?: string): boolean { + if (!modelId) return false + const lowerModelId = modelId.toLowerCase() + return lowerModelId.includes('gemini-2.5') || lowerModelId.includes('gemini-exp') || lowerModelId.includes('gemini-3') +} + +/** + * Get the default reasoning tag name based on model ID + * + * Different models use different tags for reasoning content: + * - Most models: 'think' + * - GPT-OSS models: 'reasoning' + * - Gemini models: 'thought' + * - Seed models: 'seed:think' + */ +export function getReasoningTagName(modelId?: string): string { + if (!modelId) return 'think' + const lowerModelId = modelId.toLowerCase() + if (lowerModelId.includes('gpt-oss')) return 'reasoning' + if (lowerModelId.includes('gemini')) return 'thought' + if (lowerModelId.includes('seed-oss-36b')) return 'seed:think' + return 'think' +} + +/** + * Skip Gemini Thought Signature Middleware + * + * Due to the complexity of multi-model client requests (which can switch + * to other models mid-process), this middleware skips all Gemini 3 + * thinking signatures validation. 
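+ *
+ * Rough sketch of the effect on a single prompt part (values are illustrative,
+ * assuming aiSdkId is 'google'):
+ *
+ * ```typescript
+ * // before: part.providerOptions.google.thoughtSignature === '<provider signature>'
+ * // after:  part.providerOptions.google.thoughtSignature === 'skip_thought_signature_validator'
+ * ```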
+ * + * @param aiSdkId - AI SDK Provider ID (e.g., 'google', 'google-vertex') + * @returns LanguageModelV2Middleware + */ +export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware { + const MAGIC_STRING = 'skip_thought_signature_validator' + return { + middlewareVersion: 'v2', + + transformParams: async ({ params }) => { + const transformedParams = { ...params } + // Process messages in prompt + if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) { + transformedParams.prompt = transformedParams.prompt.map((message) => { + if (typeof message.content !== 'string') { + for (const part of message.content) { + const googleOptions = part?.providerOptions?.[aiSdkId] + if (googleOptions?.thoughtSignature) { + googleOptions.thoughtSignature = MAGIC_STRING + } + } + } + return message + }) + } + + return transformedParams + } + } +} + +/** + * OpenRouter Reasoning Middleware + * + * Filters out [REDACTED] blocks from OpenRouter reasoning responses. + * OpenRouter may include [REDACTED] markers in reasoning content that + * should be removed for cleaner output. + * + * @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens + * @returns LanguageModelV2Middleware + */ +export function openrouterReasoningMiddleware(): LanguageModelV2Middleware { + const REDACTED_BLOCK = '[REDACTED]' + return { + middlewareVersion: 'v2', + wrapGenerate: async ({ doGenerate }) => { + const { content, ...rest } = await doGenerate() + const modifiedContent = content.map((part) => { + if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) { + return { + ...part, + text: part.text.replace(REDACTED_BLOCK, '') + } + } + return part + }) + return { content: modifiedContent, ...rest } + }, + wrapStream: async ({ doStream }) => { + const { stream, ...rest } = await doStream() + return { + stream: stream.pipeThrough( + new TransformStream({ + transform( + chunk: LanguageModelV2StreamPart, + controller: TransformStreamDefaultController + ) { + if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) { + controller.enqueue({ + ...chunk, + delta: chunk.delta.replace(REDACTED_BLOCK, '') + }) + } else { + controller.enqueue(chunk) + } + } + }) + ), + ...rest + } + } + } +} + +/** + * Build shared middlewares based on configuration + * + * This function builds a set of middlewares that are commonly needed + * across different environments (renderer, API server). + * + * @param config - Configuration for middleware building + * @returns Array of AI SDK middlewares + * + * @example + * ```typescript + * import { buildSharedMiddlewares } from '@shared/middleware' + * + * const middlewares = buildSharedMiddlewares({ + * enableReasoning: true, + * modelId: 'gemini-2.5-pro', + * providerId: 'openrouter', + * aiSdkProviderId: 'google' + * }) + * ``` + */ +export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] { + const middlewares: LanguageModelV2Middleware[] = [] + + // 1. Reasoning extraction middleware + if (config.enableReasoning) { + const tagName = config.reasoningTagName || getReasoningTagName(config.modelId) + middlewares.push(extractReasoningMiddleware({ tagName })) + } + + // 2. OpenRouter-specific: filter [REDACTED] blocks + if (config.providerId === 'openrouter' && config.enableReasoning) { + middlewares.push(openrouterReasoningMiddleware()) + } + + // 3. 
Gemini 3 (2.5) specific: skip thought signature validation + if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) { + middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId)) + } + + return middlewares +} diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts index ddb6d59b37..be8b05aeac 100644 --- a/src/main/apiServer/services/unified-messages.ts +++ b/src/main/apiServer/services/unified-messages.ts @@ -1,4 +1,4 @@ -import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider' +import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider' import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils' import type { ImageBlockParam, @@ -6,7 +6,7 @@ import type { TextBlockParam, Tool as AnthropicTool } from '@anthropic-ai/sdk/resources/messages' -import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' +import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import { reduxService } from '@main/services/ReduxService' import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters' @@ -21,8 +21,8 @@ import { } from '@shared/provider' import { defaultAppHeaders } from '@shared/utils' import type { Provider } from '@types' -import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai' -import { jsonSchema, stepCountIs, streamText, tool } from 'ai' +import type { ImagePart, ModelMessage, TextPart, Tool } from 'ai' +import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool } from 'ai' import { net } from 'electron' import type { Response } from 'express' @@ -33,6 +33,9 @@ initializeSharedProviders({ error: (message, error) => logger.error(message, error) }) +/** + * Configuration for unified message streaming + */ export interface UnifiedStreamConfig { response: Response provider: Provider @@ -40,12 +43,31 @@ export interface UnifiedStreamConfig { params: MessageCreateParams onError?: (error: unknown) => void onComplete?: () => void + /** + * Optional AI SDK middlewares to apply + */ + middlewares?: LanguageModelV2Middleware[] + /** + * Optional AI Core plugins to use with the executor + */ + plugins?: AiPlugin[] } /** - * Main process format context for formatProviderApiHost - * Unlike renderer, main process doesn't have direct access to store getters, so use reduxService cache + * Configuration for non-streaming message generation */ +export interface GenerateUnifiedMessageConfig { + provider: Provider + modelId: string + params: MessageCreateParams + middlewares?: LanguageModelV2Middleware[] + plugins?: AiPlugin[] +} + +// ============================================================================ +// Internal Utilities +// ============================================================================ + function getMainProcessFormatContext(): ProviderFormatContext { const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai') return { @@ -56,12 +78,7 @@ function getMainProcessFormatContext(): ProviderFormatContext { } } -/** - * Main process context for providerToAiSdkConfig - * Main process doesn't have access to browser APIs like window.keyv - */ const mainProcessSdkContext: AiSdkConfigContext = { - // Simple key rotation - just return first key (no persistent rotation in main process) getRotatedApiKey: (provider) => { 
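+    // No persistent rotation in the main process: always take the first key of a comma-separated list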
const keys = provider.apiKey.split(',').map((k) => k.trim()) return keys[0] || provider.apiKey @@ -69,199 +86,82 @@ const mainProcessSdkContext: AiSdkConfigContext = { fetch: net.fetch as typeof globalThis.fetch } -/** - * Get actual provider configuration for a model - * - * For aggregated providers (new-api, aihubmix, vertexai, azure-openai), - * this resolves the actual provider type based on the model's characteristics. - */ function getActualProvider(provider: Provider, modelId: string): Provider { - // Find the model in provider's models list const model = provider.models?.find((m) => m.id === modelId) - if (!model) { - // If model not found, return provider as-is - return provider - } - - // Resolve actual provider based on model + if (!model) return provider return resolveActualProvider(provider, model) } -/** - * Convert Cherry Studio Provider to AI SDK config - * Uses shared implementation with main process context - */ function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig { - // First resolve actual provider for aggregated providers const actualProvider = getActualProvider(provider, modelId) - - // Format the provider's apiHost for AI SDK const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext()) - - // Use shared implementation return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext) } -/** - * Create an AI SDK provider from Cherry Studio provider configuration - */ -async function createAiSdkProvider(config: AiSdkConfig): Promise { - try { - const provider = await createProviderCore(config.providerId, config.options) - logger.debug('AI SDK provider created', { - providerId: config.providerId, - hasOptions: !!config.options - }) - return provider - } catch (error) { - logger.error('Failed to create AI SDK provider', error as Error, { - providerId: config.providerId - }) - throw error - } -} - -/** - * Create an AI SDK language model from a Cherry Studio provider configuration - * Uses shared provider utilities for consistent behavior with renderer - */ -async function createLanguageModel(provider: Provider, modelId: string): Promise { - logger.debug('Creating language model', { - providerId: provider.id, - providerType: provider.type, - modelId, - apiHost: provider.apiHost - }) - - // Convert provider config to AI SDK config - const config = providerToAiSdkConfig(provider, modelId) - - // Create the AI SDK provider - const aiSdkProvider = await createAiSdkProvider(config) - if (!aiSdkProvider) { - throw new Error(`Failed to create AI SDK provider for ${provider.id}`) - } - - // Get the language model - return aiSdkProvider.languageModel(modelId) -} - function convertAnthropicToolResultToAiSdk( content: string | Array ): LanguageModelV2ToolResultOutput { if (typeof content === 'string') { - return { - type: 'text', - value: content - } - } else { - const values: Array< - | { type: 'text'; text: string } - | { - type: 'media' - /** -Base-64 encoded media data. -*/ - data: string - /** -IANA media type. -@see https://www.iana.org/assignments/media-types/media-types.xhtml -*/ - mediaType: string - } - > = [] - for (const block of content) { - if (block.type === 'text') { - values.push({ - type: 'text', - text: block.text - }) - } else if (block.type === 'image') { - values.push({ - type: 'media', - data: block.source.type === 'base64' ? block.source.data : block.source.url, - mediaType: block.source.type === 'base64' ? 
block.source.media_type : 'image/png' - }) - } - } - return { - type: 'content', - value: values + return { type: 'text', value: content } + } + const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = [] + for (const block of content) { + if (block.type === 'text') { + values.push({ type: 'text', text: block.text }) + } else if (block.type === 'image') { + values.push({ + type: 'media', + data: block.source.type === 'base64' ? block.source.data : block.source.url, + mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png' + }) } } + return { type: 'content', value: values } } -/** - * Convert Anthropic tools format to AI SDK tools format - */ function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record | undefined { - if (!tools || tools.length === 0) { - return undefined - } + if (!tools || tools.length === 0) return undefined const aiSdkTools: Record = {} - for (const anthropicTool of tools) { - // Handle different tool types - if (anthropicTool.type === 'bash_20250124') { - // Skip computer use and bash tools - these are Anthropic-specific - continue - } - - // Regular tool (type === 'custom' or no type) + if (anthropicTool.type === 'bash_20250124') continue const toolDef = anthropicTool as AnthropicTool const parameters = toolDef.input_schema as Parameters[0] - aiSdkTools[toolDef.name] = tool({ description: toolDef.description || '', inputSchema: jsonSchema(parameters), - execute: async (input: Record) => { - return input - } + execute: async (input: Record) => input }) } - return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined } -/** - * Convert Anthropic MessageCreateParams to AI SDK message format - */ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] { const messages: ModelMessage[] = [] - // Add system message if present + // System message if (params.system) { if (typeof params.system === 'string') { - messages.push({ - role: 'system', - content: params.system - }) + messages.push({ role: 'system', content: params.system }) } else if (Array.isArray(params.system)) { - // Handle TextBlockParam array const systemText = params.system .filter((block) => block.type === 'text') .map((block) => block.text) .join('\n') if (systemText) { - messages.push({ - role: 'system', - content: systemText - }) + messages.push({ role: 'system', content: systemText }) } } } - // Convert user/assistant messages + // User/assistant messages for (const msg of params.messages) { if (typeof msg.content === 'string') { - if (msg.role === 'user') { - messages.push({ role: 'user', content: msg.content }) - } else { - messages.push({ role: 'assistant', content: msg.content }) - } + messages.push({ + role: msg.role === 'user' ? 
'user' : 'assistant', + content: msg.content + }) } else if (Array.isArray(msg.content)) { - // Handle content blocks const textParts: TextPart[] = [] const imageParts: ImagePart[] = [] const reasoningParts: ReasoningPart[] = [] @@ -278,15 +178,9 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage } else if (block.type === 'image') { const source = block.source if (source.type === 'base64') { - imageParts.push({ - type: 'image', - image: `data:${source.media_type};base64,${source.data}` - }) + imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` }) } else if (source.type === 'url') { - imageParts.push({ - type: 'image', - image: source.url - }) + imageParts.push({ type: 'image', image: source.url }) } } else if (block.type === 'tool_use') { toolCallParts.push({ @@ -306,30 +200,18 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage } if (toolResultParts.length > 0) { - messages.push({ - role: 'tool', - content: [...toolResultParts] - }) + messages.push({ role: 'tool', content: [...toolResultParts] }) } - // Build the message based on role - // Only push user/assistant message if there's actual content (avoid empty messages) if (msg.role === 'user') { const userContent = [...textParts, ...imageParts] if (userContent.length > 0) { - messages.push({ - role: 'user', - content: userContent - }) + messages.push({ role: 'user', content: userContent }) } } else { - // Assistant messages contain tool calls, not tool results const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts] if (assistantContent.length > 0) { - messages.push({ - role: 'assistant', - content: assistantContent - }) + messages.push({ role: 'assistant', content: assistantContent }) } } } @@ -338,67 +220,54 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage return messages } -/** - * Stream a message request using AI SDK and convert to Anthropic SSE format - */ -// TODO: 使用ai-core executor集成中间件和transformstream进来 -export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise { - const { response, provider, modelId, params, onError, onComplete } = config +interface ExecuteStreamConfig { + provider: Provider + modelId: string + params: MessageCreateParams + middlewares?: LanguageModelV2Middleware[] + plugins?: AiPlugin[] + onEvent?: (event: Parameters[0]) => void +} - logger.info('Starting unified message stream', { - providerId: provider.id, - providerType: provider.type, - modelId, - stream: params.stream +/** + * Core stream execution function - single source of truth for AI SDK calls + */ +async function executeStream(config: ExecuteStreamConfig): Promise { + const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config + + // Convert provider config to AI SDK config + const sdkConfig = providerToAiSdkConfig(provider, modelId) + + logger.debug('Created AI SDK config', { + providerId: sdkConfig.providerId, + hasOptions: !!sdkConfig.options }) - try { - response.setHeader('Content-Type', 'text/event-stream') - response.setHeader('Cache-Control', 'no-cache') - response.setHeader('Connection', 'keep-alive') - response.setHeader('X-Accel-Buffering', 'no') + // Create executor with plugins + const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins) - const model = await createLanguageModel(provider, modelId) + // Convert messages and tools + const coreMessages = convertAnthropicToAiMessages(params) + const tools = 
convertAnthropicToolsToAiSdk(params.tools) - const coreMessages = convertAnthropicToAiMessages(params) + logger.debug('Converted messages', { + originalCount: params.messages.length, + convertedCount: coreMessages.length, + hasSystem: !!params.system, + hasTools: !!tools, + toolCount: tools ? Object.keys(tools).length : 0 + }) - // Convert tools if present - const tools = convertAnthropicToolsToAiSdk(params.tools) + // Create the adapter + const adapter = new AiSdkToAnthropicSSE({ + model: `${provider.id}:${modelId}`, + onEvent: onEvent || (() => {}) + }) - logger.debug('Converted messages', { - originalCount: params.messages.length, - convertedCount: coreMessages.length, - hasSystem: !!params.system, - hasTools: !!tools, - toolCount: tools ? Object.keys(tools).length : 0, - toolNames: tools ? Object.keys(tools).slice(0, 10) : [], - paramsToolCount: params.tools?.length || 0 - }) - - // Debug: Log message structure to understand tool_result handling - logger.silly('Message structure for debugging', { - messages: coreMessages.map((m) => ({ - role: m.role, - contentTypes: Array.isArray(m.content) - ? m.content.map((c: { type: string }) => c.type) - : typeof m.content === 'string' - ? ['string'] - : ['unknown'] - })) - }) - - // Create the adapter - const adapter = new AiSdkToAnthropicSSE({ - model: `${provider.id}:${modelId}`, - onEvent: (event) => { - const sseData = formatSSEEvent(event) - response.write(sseData) - } - }) - - // Start streaming - const result = streamText({ - model, + // Execute stream + const result = await executor.streamText( + { + model: modelId, messages: coreMessages, maxOutputTokens: params.max_tokens, temperature: params.temperature, @@ -408,38 +277,65 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis headers: defaultAppHeaders(), tools, providerOptions: {} - }) + }, + { middlewares } + ) - // Process the stream through the adapter - await adapter.processStream(result.fullStream) + // Process the stream through the adapter + await adapter.processStream(result.fullStream) + + return adapter +} + +/** + * Stream a message request using AI SDK executor and convert to Anthropic SSE format + */ +export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise { + const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config + + logger.info('Starting unified message stream', { + providerId: provider.id, + providerType: provider.type, + modelId, + stream: params.stream, + middlewareCount: middlewares.length, + pluginCount: plugins.length + }) + + try { + response.setHeader('Content-Type', 'text/event-stream') + response.setHeader('Cache-Control', 'no-cache') + response.setHeader('Connection', 'keep-alive') + response.setHeader('X-Accel-Buffering', 'no') + + await executeStream({ + provider, + modelId, + params, + middlewares, + plugins, + onEvent: (event) => { + const sseData = formatSSEEvent(event) + response.write(sseData) + } + }) // Send done marker response.write(formatSSEDone()) response.end() - logger.info('Unified message stream completed', { - providerId: provider.id, - modelId - }) - + logger.info('Unified message stream completed', { providerId: provider.id, modelId }) onComplete?.() } catch (error) { - logger.error('Error in unified message stream', error as Error, { - providerId: provider.id, - modelId - }) + logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId }) - // Try to send error event if response is still 
writable if (!response.writableEnded) { try { const errorMessage = error instanceof Error ? error.message : 'Unknown error' response.write( `event: error\ndata: ${JSON.stringify({ type: 'error', - error: { - type: 'api_error', - message: errorMessage - } + error: { type: 'api_error', message: errorMessage } })}\n\n` ) response.end() @@ -455,64 +351,61 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis /** * Generate a non-streaming message response + * + * Uses simulateStreamingMiddleware to reuse the same streaming logic, + * similar to renderer's ModernAiProvider pattern. */ export async function generateUnifiedMessage( - provider: Provider, - modelId: string, - params: MessageCreateParams + providerOrConfig: Provider | GenerateUnifiedMessageConfig, + modelId?: string, + params?: MessageCreateParams ): Promise> { + // Support both old signature and new config-based signature + let config: GenerateUnifiedMessageConfig + if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) { + config = providerOrConfig + } else { + config = { + provider: providerOrConfig as Provider, + modelId: modelId!, + params: params! + } + } + + const { provider, middlewares = [], plugins = [] } = config + logger.info('Starting unified message generation', { providerId: provider.id, providerType: provider.type, - modelId + modelId: config.modelId, + middlewareCount: middlewares.length, + pluginCount: plugins.length }) try { - // Create language model (async - uses @cherrystudio/ai-core) - const model = await createLanguageModel(provider, modelId) + // Add simulateStreamingMiddleware to reuse streaming logic for non-streaming + const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares] - // Convert messages and tools - const coreMessages = convertAnthropicToAiMessages(params) - const tools = convertAnthropicToolsToAiSdk(params.tools) - - // Create adapter to collect the response - let finalResponse: ReturnType | null = null - const adapter = new AiSdkToAnthropicSSE({ - model: `${provider.id}:${modelId}`, - onEvent: () => { - // We don't need to emit events for non-streaming - } + const adapter = await executeStream({ + provider, + modelId: config.modelId, + params: config.params, + middlewares: allMiddlewares, + plugins }) - // Generate text - const result = streamText({ - model, - messages: coreMessages, - maxOutputTokens: params.max_tokens, - temperature: params.temperature, - topP: params.top_p, - stopSequences: params.stop_sequences, - headers: defaultAppHeaders(), - tools, - stopWhen: stepCountIs(100) - }) - - // Process the stream to build the response - await adapter.processStream(result.fullStream) - - // Get the final response - finalResponse = adapter.buildNonStreamingResponse() + const finalResponse = adapter.buildNonStreamingResponse() logger.info('Unified message generation completed', { providerId: provider.id, - modelId + modelId: config.modelId }) return finalResponse } catch (error) { logger.error('Error in unified message generation', error as Error, { providerId: provider.id, - modelId + modelId: config.modelId }) throw error } diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts index b314ddd737..82e1c32465 100644 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -5,6 +5,7 @@ import type { MCPTool } from '@renderer/types' import { type 
Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types' import type { Chunk } from '@renderer/types/chunk' import { isSupportEnableThinkingProvider } from '@renderer/utils/provider' +import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware' import type { LanguageModelMiddleware } from 'ai' import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' import { isEmpty } from 'lodash' @@ -13,9 +14,7 @@ import { getAiSdkProviderId } from '../provider/factory' import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' import { noThinkMiddleware } from './noThinkMiddleware' import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware' -import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware' import { qwenThinkingMiddleware } from './qwenThinkingMiddleware' -import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware' import { toolChoiceMiddleware } from './toolChoiceMiddleware' const logger = loggerService.withContext('AiSdkMiddlewareBuilder') diff --git a/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts b/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts deleted file mode 100644 index 9ef3df61e9..0000000000 --- a/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts +++ /dev/null @@ -1,50 +0,0 @@ -import type { LanguageModelV2StreamPart } from '@ai-sdk/provider' -import type { LanguageModelMiddleware } from 'ai' - -/** - * https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude - * - * @returns LanguageModelMiddleware - a middleware filter redacted block - */ -export function openrouterReasoningMiddleware(): LanguageModelMiddleware { - const REDACTED_BLOCK = '[REDACTED]' - return { - middlewareVersion: 'v2', - wrapGenerate: async ({ doGenerate }) => { - const { content, ...rest } = await doGenerate() - const modifiedContent = content.map((part) => { - if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) { - return { - ...part, - text: part.text.replace(REDACTED_BLOCK, '') - } - } - return part - }) - return { content: modifiedContent, ...rest } - }, - wrapStream: async ({ doStream }) => { - const { stream, ...rest } = await doStream() - return { - stream: stream.pipeThrough( - new TransformStream({ - transform( - chunk: LanguageModelV2StreamPart, - controller: TransformStreamDefaultController - ) { - if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) { - controller.enqueue({ - ...chunk, - delta: chunk.delta.replace(REDACTED_BLOCK, '') - }) - } else { - controller.enqueue(chunk) - } - } - }) - ), - ...rest - } - } - } -} diff --git a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts b/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts deleted file mode 100644 index da318ea60d..0000000000 --- a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts +++ /dev/null @@ -1,36 +0,0 @@ -import type { LanguageModelMiddleware } from 'ai' - -/** - * skip Gemini Thought Signature Middleware - * 由于多模型客户端请求的复杂性(可以中途切换其他模型),这里选择通过中间件方式添加跳过所有 Gemini3 思考签名 - * Due to the complexity of multi-model client requests (which can switch to other models mid-process), - * it was decided to add a skip for all Gemini3 thinking signatures via middleware. 
- * @param aiSdkId AI SDK Provider ID - * @returns LanguageModelMiddleware - */ -export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware { - const MAGIC_STRING = 'skip_thought_signature_validator' - return { - middlewareVersion: 'v2', - - transformParams: async ({ params }) => { - const transformedParams = { ...params } - // Process messages in prompt - if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) { - transformedParams.prompt = transformedParams.prompt.map((message) => { - if (typeof message.content !== 'string') { - for (const part of message.content) { - const googleOptions = part?.providerOptions?.[aiSdkId] - if (googleOptions?.thoughtSignature) { - googleOptions.thoughtSignature = MAGIC_STRING - } - } - } - return message - }) - } - - return transformedParams - } - } -} diff --git a/tsconfig.node.json b/tsconfig.node.json index 9871e604f2..4f9e797146 100644 --- a/tsconfig.node.json +++ b/tsconfig.node.json @@ -7,9 +7,11 @@ "src/main/env.d.ts", "src/renderer/src/types/*", "packages/shared/**/*", + "packages/aiCore/src/**/*", "scripts", "packages/mcp-trace/**/*", - "src/renderer/src/services/traceApi.ts" + "src/renderer/src/services/traceApi.ts", + "packages/ai-sdk-provider/**/*" ], "compilerOptions": { "composite": true, From 356e82842299df02701d55cc59f989c6f3ca095f Mon Sep 17 00:00:00 2001 From: suyao Date: Fri, 28 Nov 2025 04:12:18 +0800 Subject: [PATCH 15/53] feat: enhance AI SDK integration with middleware support and improve message handling --- .../shared/adapters/AiSdkToAnthropicSSE.ts | 45 +++---- packages/shared/middleware/middlewares.ts | 2 +- src/main/apiServer/routes/messages.ts | 23 +++- .../apiServer/services/unified-messages.ts | 120 +++++++++++++----- 4 files changed, 130 insertions(+), 60 deletions(-) diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/packages/shared/adapters/AiSdkToAnthropicSSE.ts index 1674609236..f1d6b0c022 100644 --- a/packages/shared/adapters/AiSdkToAnthropicSSE.ts +++ b/packages/shared/adapters/AiSdkToAnthropicSSE.ts @@ -36,7 +36,7 @@ import type { Usage } from '@anthropic-ai/sdk/resources/messages' import { loggerService } from '@logger' -import type { TextStreamPart, ToolSet } from 'ai' +import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai' const logger = loggerService.withContext('AiSdkToAnthropicSSE') @@ -56,6 +56,7 @@ interface AdapterState { model: string inputTokens: number outputTokens: number + cacheInputTokens: number currentBlockIndex: number blocks: Map textBlockIndex: number | null @@ -67,10 +68,6 @@ interface AdapterState { hasEmittedMessageStart: boolean } -// ============================================================================ -// Adapter Class -// ============================================================================ - export type SSEEventCallback = (event: RawMessageStreamEvent) => void export interface AiSdkToAnthropicSSEOptions { @@ -94,6 +91,7 @@ export class AiSdkToAnthropicSSE { model: options.model, inputTokens: options.inputTokens || 0, outputTokens: 0, + cacheInputTokens: 0, currentBlockIndex: 0, blocks: new Map(), textBlockIndex: null, @@ -153,19 +151,19 @@ export class AiSdkToAnthropicSSE { // === Reasoning/Thinking Events === case 'reasoning-start': { - const reasoningId = (chunk as { id?: string }).id || `reasoning_${Date.now()}` + const reasoningId = chunk.id this.startThinkingBlock(reasoningId) break } case 'reasoning-delta': { - const reasoningId = (chunk as { id?: string }).id + const 
reasoningId = chunk.id this.emitThinkingDelta(chunk.text || '', reasoningId) break } case 'reasoning-end': { - const reasoningId = (chunk as { id?: string }).id + const reasoningId = chunk.id this.stopThinkingBlock(reasoningId) break } @@ -176,14 +174,18 @@ export class AiSdkToAnthropicSSE { type: 'tool-call', toolCallId: chunk.toolCallId, toolName: chunk.toolName, - // AI SDK uses 'args' in some versions and 'input' in others - args: 'args' in chunk ? chunk.args : (chunk as any).input + args: chunk.input }) break case 'tool-result': - // Tool results are handled separately in Anthropic API - // They come from user messages, not assistant stream + // this.handleToolResult({ + // type: 'tool-result', + // toolCallId: chunk.toolCallId, + // toolName: chunk.toolName, + // args: chunk.input, + // result: chunk.output + // }) break // === Completion Events === @@ -465,34 +467,29 @@ export class AiSdkToAnthropicSSE { this.state.stopReason = 'tool_use' } - private handleFinish(chunk: { - type: 'finish' - finishReason?: string - totalUsage?: { - inputTokens?: number - outputTokens?: number - } - }): void { + private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void { // Update usage if (chunk.totalUsage) { this.state.inputTokens = chunk.totalUsage.inputTokens || 0 this.state.outputTokens = chunk.totalUsage.outputTokens || 0 + this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0 } // Determine finish reason if (!this.state.stopReason) { switch (chunk.finishReason) { case 'stop': - case 'end_turn': this.state.stopReason = 'end_turn' break case 'length': - case 'max_tokens': this.state.stopReason = 'max_tokens' break case 'tool-calls': this.state.stopReason = 'tool_use' break + case 'content-filter': + this.state.stopReason = 'refusal' + break default: this.state.stopReason = 'end_turn' } @@ -539,8 +536,8 @@ export class AiSdkToAnthropicSSE { // Emit message_delta with final stop reason and usage const usage: MessageDeltaUsage = { output_tokens: this.state.outputTokens, - input_tokens: null, - cache_creation_input_tokens: null, + input_tokens: this.state.inputTokens, + cache_creation_input_tokens: this.state.cacheInputTokens, cache_read_input_tokens: null, server_tool_use: null } diff --git a/packages/shared/middleware/middlewares.ts b/packages/shared/middleware/middlewares.ts index d9725101c2..de857699f7 100644 --- a/packages/shared/middleware/middlewares.ts +++ b/packages/shared/middleware/middlewares.ts @@ -50,7 +50,7 @@ export interface SharedMiddlewareConfig { export function isGemini3ModelId(modelId?: string): boolean { if (!modelId) return false const lowerModelId = modelId.toLowerCase() - return lowerModelId.includes('gemini-2.5') || lowerModelId.includes('gemini-exp') || lowerModelId.includes('gemini-3') + return lowerModelId.includes('gemini-3') } /** diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts index 907b498273..018e7d60ad 100644 --- a/src/main/apiServer/routes/messages.ts +++ b/src/main/apiServer/routes/messages.ts @@ -1,5 +1,7 @@ import type { MessageCreateParams } from '@anthropic-ai/sdk/resources' import { loggerService } from '@logger' +import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/middleware' +import { getAiSdkProviderId } from '@shared/provider' import type { Provider } from '@types' import type { Request, Response } from 'express' import express from 'express' @@ -206,12 +208,26 @@ async function handleUnifiedProcessing({ return } + 
const middlewareConfig: SharedMiddlewareConfig = { + modelId: actualModelId, + providerId: provider.id, + aiSdkProviderId: getAiSdkProviderId(provider) + } + const middlewares = buildSharedMiddlewares(middlewareConfig) + + logger.debug('Built middlewares for unified processing', { + middlewareCount: middlewares.length, + modelId: actualModelId, + providerId: provider.id + }) + if (request.stream) { await streamUnifiedMessages({ response: res, provider, modelId: actualModelId, params: request, + middlewares, onError: (error) => { logger.error('Stream error', error as Error) }, @@ -220,7 +236,12 @@ async function handleUnifiedProcessing({ } }) } else { - const response = await generateUnifiedMessage(provider, actualModelId, request) + const response = await generateUnifiedMessage({ + provider, + modelId: actualModelId, + params: request, + middlewares + }) res.json(response) } } catch (error: any) { diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts index be8b05aeac..4370a429d0 100644 --- a/src/main/apiServer/services/unified-messages.ts +++ b/src/main/apiServer/services/unified-messages.ts @@ -1,5 +1,5 @@ import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider' -import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils' +import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils' import type { ImageBlockParam, MessageCreateParams, @@ -7,9 +7,11 @@ import type { Tool as AnthropicTool } from '@anthropic-ai/sdk/resources/messages' import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core' +import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' import { reduxService } from '@main/services/ReduxService' import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters' +import { isGemini3ModelId } from '@shared/middleware' import { type AiSdkConfig, type AiSdkConfigContext, @@ -21,13 +23,15 @@ import { } from '@shared/provider' import { defaultAppHeaders } from '@shared/utils' import type { Provider } from '@types' -import type { ImagePart, ModelMessage, TextPart, Tool } from 'ai' -import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool } from 'ai' +import type { ImagePart, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai' +import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai' import { net } from 'electron' import type { Response } from 'express' const logger = loggerService.withContext('UnifiedMessagesService') +const MAGIC_STRING = 'skip_thought_signature_validator' + initializeSharedProviders({ warn: (message) => logger.warn(message), error: (message, error) => logger.error(message, error) @@ -64,10 +68,6 @@ export interface GenerateUnifiedMessageConfig { plugins?: AiPlugin[] } -// ============================================================================ -// Internal Utilities -// ============================================================================ - function getMainProcessFormatContext(): ProviderFormatContext { const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai') return { @@ -154,6 +154,19 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage } } + // Build a map of tool_use_id -> toolName from all messages first + // This is needed 
because tool_result references tool_use from previous assistant messages + const toolCallIdToName = new Map() + for (const msg of params.messages) { + if (Array.isArray(msg.content)) { + for (const block of msg.content) { + if (block.type === 'tool_use') { + toolCallIdToName.set(block.id, block.name) + } + } + } + } + // User/assistant messages for (const msg of params.messages) { if (typeof msg.content === 'string') { @@ -190,10 +203,12 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage input: block.input }) } else if (block.type === 'tool_result') { + // Look up toolName from the pre-built map (covers cross-message references) + const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown' toolResultParts.push({ type: 'tool-result', toolCallId: block.tool_use_id, - toolName: toolCallParts.find((t) => t.toolCallId === block.tool_use_id)?.toolName || 'unknown', + toolName, output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' } }) } @@ -211,7 +226,18 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage } else { const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts] if (assistantContent.length > 0) { - messages.push({ role: 'assistant', content: assistantContent }) + let providerOptions: ProviderOptions | undefined = undefined + if (isGemini3ModelId(params.model)) { + providerOptions = { + google: { + thoughtSignature: MAGIC_STRING + }, + openrouter: { + reasoning_details: [] + } + } + } + messages.push({ role: 'assistant', content: assistantContent, providerOptions }) } } } @@ -229,6 +255,32 @@ interface ExecuteStreamConfig { onEvent?: (event: Parameters[0]) => void } +/** + * Create AI SDK provider instance from config + * Similar to renderer's createAiSdkProvider + */ +async function createAiSdkProvider(config: AiSdkConfig): Promise { + let providerId = config.providerId + + // Handle special provider modes (same as renderer) + if (providerId === 'openai' && config.options?.mode === 'chat') { + providerId = 'openai-chat' + } else if (providerId === 'azure' && config.options?.mode === 'responses') { + providerId = 'azure-responses' + } else if (providerId === 'cherryin' && config.options?.mode === 'chat') { + providerId = 'cherryin-chat' + } + + const provider = await createProviderCore(providerId, config.options) + + logger.debug('AI SDK provider created', { + providerId, + hasOptions: !!config.options + }) + + return provider +} + /** * Core stream execution function - single source of truth for AI SDK calls */ @@ -240,9 +292,20 @@ async function executeStream(config: ExecuteStreamConfig): Promise 0 && typeof baseModel === 'object' + ? 
(wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel) + : baseModel + // Create executor with plugins const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins) @@ -250,36 +313,25 @@ async function executeStream(config: ExecuteStreamConfig): Promise {}) }) - // Execute stream - const result = await executor.streamText( - { - model: modelId, - messages: coreMessages, - maxOutputTokens: params.max_tokens, - temperature: params.temperature, - topP: params.top_p, - stopSequences: params.stop_sequences, - stopWhen: stepCountIs(100), - headers: defaultAppHeaders(), - tools, - providerOptions: {} - }, - { middlewares } - ) + // Execute stream - pass model object instead of string + const result = await executor.streamText({ + model, // Now passing LanguageModel object, not string + messages: coreMessages, + maxOutputTokens: params.max_tokens, + temperature: params.temperature, + topP: params.top_p, + stopSequences: params.stop_sequences, + stopWhen: stepCountIs(100), + headers: defaultAppHeaders(), + tools, + providerOptions: {} + }) // Process the stream through the adapter await adapter.processStream(result.fullStream) From d367040fd4e5798e81768128b675bf7da5b45adc Mon Sep 17 00:00:00 2001 From: suyao Date: Fri, 28 Nov 2025 13:11:13 +0800 Subject: [PATCH 16/53] feat: implement reasoning cache for improved performance and error handling in AI SDK integration --- .../shared/adapters/AiSdkToAnthropicSSE.ts | 44 ++----- src/main/apiServer/routes/messages.ts | 1 - src/main/apiServer/services/cache.ts | 116 ++++++++++++++++++ src/main/apiServer/services/messages.ts | 38 +++++- .../apiServer/services/unified-messages.ts | 34 ++--- .../claudecode/claude-stream-state.ts | 11 ++ .../agents/services/claudecode/index.ts | 13 ++ .../agents/services/claudecode/transform.ts | 45 ++++--- 8 files changed, 225 insertions(+), 77 deletions(-) create mode 100644 src/main/apiServer/services/cache.ts diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/packages/shared/adapters/AiSdkToAnthropicSSE.ts index f1d6b0c022..9b23638f48 100644 --- a/packages/shared/adapters/AiSdkToAnthropicSSE.ts +++ b/packages/shared/adapters/AiSdkToAnthropicSSE.ts @@ -36,7 +36,8 @@ import type { Usage } from '@anthropic-ai/sdk/resources/messages' import { loggerService } from '@logger' -import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai' +import { reasoningCache } from '@main/apiServer/services/cache' +import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai' const logger = loggerService.withContext('AiSdkToAnthropicSSE') @@ -125,6 +126,9 @@ export class AiSdkToAnthropicSSE { // Ensure all blocks are closed and emit final events this.finalize() + } catch (error) { + await reader.cancel() + throw error } finally { reader.releaseLock() } @@ -188,8 +192,13 @@ export class AiSdkToAnthropicSSE { // }) break - // === Completion Events === case 'finish-step': + if ( + chunk.providerMetadata?.openrouter?.reasoning_details && + Array.isArray(chunk.providerMetadata.openrouter.reasoning_details) + ) { + reasoningCache.set('openrouter', chunk.providerMetadata?.openrouter?.reasoning_details) + } if (chunk.finishReason === 'tool-calls') { this.state.stopReason = 'tool_use' } @@ -199,10 +208,8 @@ export class AiSdkToAnthropicSSE { this.handleFinish(chunk) break - // === Error Events === case 'error': - this.handleError(chunk.error) - break + throw chunk.error // Ignore other event types default: @@ -496,33 +503,6 @@ 
export class AiSdkToAnthropicSSE {
   }
 
-  private handleError(error: unknown): void {
-    // Log the error for debugging
-    logger.warn('AiSdkToAnthropicSSE - Provider error received:', { error })
-
-    // Extract error message
-    let errorMessage = 'Unknown error from provider'
-    if (error && typeof error === 'object') {
-      const err = error as { message?: string; metadata?: { raw?: string } }
-      if (err.metadata?.raw) {
-        errorMessage = `Provider error: ${err.metadata.raw}`
-      } else if (err.message) {
-        errorMessage = err.message
-      }
-    } else if (typeof error === 'string') {
-      errorMessage = error
-    }
-
-    // Emit error as a text block so the user can see it
-    // First close any open thinking blocks to maintain proper event order
-    for (const reasoningId of Array.from(this.state.thinkingBlocks.keys())) {
-      this.stopThinkingBlock(reasoningId)
-    }
-
-    // Emit the error as text
-    this.emitTextDelta(`\n\n[Error: ${errorMessage}]\n`)
-  }
-
   private finalize(): void {
     // Close any open blocks
     if (this.state.textBlockIndex !== null) {
diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts
index 018e7d60ad..1e18c86118 100644
--- a/src/main/apiServer/routes/messages.ts
+++ b/src/main/apiServer/routes/messages.ts
@@ -245,7 +245,6 @@ async function handleUnifiedProcessing({
       res.json(response)
     }
   } catch (error: any) {
-    logger.error('Unified processing error', { error })
     const { statusCode, errorResponse } = messagesService.transformError(error)
     res.status(statusCode).json(errorResponse)
   }
diff --git a/src/main/apiServer/services/cache.ts b/src/main/apiServer/services/cache.ts
new file mode 100644
index 0000000000..765ab1e1b9
--- /dev/null
+++ b/src/main/apiServer/services/cache.ts
@@ -0,0 +1,116 @@
+import { loggerService } from '@logger'
+
+const logger = loggerService.withContext('Cache')
+/**
+ * Cache entry with TTL support
+ */
+interface CacheEntry<T> {
+  details: T[]
+  timestamp: number
+}
+
+/**
+ * In-memory cache for reasoning details
+ * Key: signature
+ * Value: reasoning array with timestamp
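+ *
+ * Usage sketch (illustrative; the unified-messages service keys entries by
+ * provider ID rather than a per-response signature, e.g. 'openrouter'):
+ *
+ * ```typescript
+ * reasoningCache.set('openrouter', reasoningDetails) // reasoningDetails: provider metadata array
+ * const cached = reasoningCache.get('openrouter')    // undefined once the TTL lapses
+ * ```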
cleanup
+   */
+  private startCleanup(): void {
+    // Cleanup every 5 minutes
+    this.cleanupInterval = setInterval(() => this.cleanup(), 5 * 60 * 1000)
+  }
+
+  /**
+   * Stop cleanup and clear cache
+   */
+  destroy(): void {
+    if (this.cleanupInterval) {
+      clearInterval(this.cleanupInterval)
+      this.cleanupInterval = null
+    }
+    this.cache.clear()
+  }
+
+  /**
+   * Get cache stats for debugging
+   */
+  stats(): { size: number; ttlMs: number } {
+    return {
+      size: this.cache.size,
+      ttlMs: this.ttlMs
+    }
+  }
+}
+
+// Singleton cache instance
+export const reasoningCache = new ReasoningCache()
diff --git a/src/main/apiServer/services/messages.ts b/src/main/apiServer/services/messages.ts
index e3fbd069a7..3277378266 100644
--- a/src/main/apiServer/services/messages.ts
+++ b/src/main/apiServer/services/messages.ts
@@ -4,6 +4,7 @@ import { loggerService } from '@logger'
 import anthropicService from '@main/services/AnthropicService'
 import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
 import type { Provider } from '@types'
+import { APICallError } from 'ai'
 import { net } from 'electron'
 import type { Response } from 'express'
 
@@ -253,9 +254,36 @@ export class MessagesService {
   }
 
   transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
-    let statusCode = 500
-    let errorType = 'api_error'
-    let errorMessage = 'Internal server error'
+    let statusCode: number | undefined = undefined
+    let errorType: string | undefined = undefined
+    let errorMessage: string | undefined = undefined
+
+    const errorMap: Record<number, string> = {
+      400: 'invalid_request_error',
+      401: 'authentication_error',
+      403: 'forbidden_error',
+      404: 'not_found_error',
+      429: 'rate_limit_error',
+      500: 'internal_server_error'
+    }
+
+    if (APICallError.isInstance(error)) {
+      statusCode = error.statusCode
+      errorMessage = error.message
+      if (statusCode) {
+        return {
+          statusCode,
+          errorResponse: {
+            type: 'error',
+            error: {
+              type: errorMap[statusCode] || 'api_error',
+              message: errorMessage,
+              requestId: error.name
+            }
+          }
+        }
+      }
+    }
 
     const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
     const anthropicError = error?.error
@@ -297,11 +325,11 @@ export class MessagesService {
       typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
 
     return {
-      statusCode,
+      statusCode: statusCode ??
500, errorResponse: { type: 'error', error: { - type: errorType, + type: errorType || 'api_error', message: safeErrorMessage, requestId: error?.request_id } diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts index 4370a429d0..5cd59377f6 100644 --- a/src/main/apiServer/services/unified-messages.ts +++ b/src/main/apiServer/services/unified-messages.ts @@ -23,11 +23,13 @@ import { } from '@shared/provider' import { defaultAppHeaders } from '@shared/utils' import type { Provider } from '@types' -import type { ImagePart, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai' +import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai' import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai' import { net } from 'electron' import type { Response } from 'express' +import { reasoningCache } from './cache' + const logger = loggerService.withContext('UnifiedMessagesService') const MAGIC_STRING = 'skip_thought_signature_validator' @@ -154,8 +156,6 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage } } - // Build a map of tool_use_id -> toolName from all messages first - // This is needed because tool_result references tool_use from previous assistant messages const toolCallIdToName = new Map() for (const msg of params.messages) { if (Array.isArray(msg.content)) { @@ -227,13 +227,16 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts] if (assistantContent.length > 0) { let providerOptions: ProviderOptions | undefined = undefined - if (isGemini3ModelId(params.model)) { + if (reasoningCache.get('openrouter')) { + providerOptions = { + openrouter: { + reasoning_details: (reasoningCache.get('openrouter') as JSONValue[]) || [] + } + } + } else if (isGemini3ModelId(params.model)) { providerOptions = { google: { thoughtSignature: MAGIC_STRING - }, - openrouter: { - reasoning_details: [] } } } @@ -367,6 +370,7 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis middlewares, plugins, onEvent: (event) => { + logger.silly('Streaming event', { eventType: event.type }) const sseData = formatSSEEvent(event) response.write(sseData) } @@ -380,22 +384,6 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis onComplete?.() } catch (error) { logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId }) - - if (!response.writableEnded) { - try { - const errorMessage = error instanceof Error ? 
error.message : 'Unknown error' - response.write( - `event: error\ndata: ${JSON.stringify({ - type: 'error', - error: { type: 'api_error', message: errorMessage } - })}\n\n` - ) - response.end() - } catch { - // Response already ended - } - } - onError?.(error) throw error } diff --git a/src/main/services/agents/services/claudecode/claude-stream-state.ts b/src/main/services/agents/services/claudecode/claude-stream-state.ts index 30b5790c82..5266fda995 100644 --- a/src/main/services/agents/services/claudecode/claude-stream-state.ts +++ b/src/main/services/agents/services/claudecode/claude-stream-state.ts @@ -87,6 +87,7 @@ export class ClaudeStreamState { private pendingUsage: PendingUsageState = {} private pendingToolCalls = new Map() private stepActive = false + private _streamFinished = false constructor(options: ClaudeStreamStateOptions) { this.logger = loggerService.withContext('ClaudeStreamState') @@ -289,6 +290,16 @@ export class ClaudeStreamState { getNamespacedToolCallId(rawToolCallId: string): string { return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId) } + + /** Marks the stream as finished (either completed or errored). */ + markFinished(): void { + this._streamFinished = true + } + + /** Returns true if the stream has already emitted a terminal event. */ + isFinished(): boolean { + return this._streamFinished + } } export type { PendingToolCall } diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index 261ff7c07e..00a395c751 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -529,6 +529,19 @@ class ClaudeCodeService implements AgentServiceInterface { return } + // Skip emitting error if stream already finished (error was handled via result message) + if (streamState.isFinished()) { + logger.debug('SDK process exited after stream finished, skipping duplicate error event', { + duration, + error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj) + }) + // Still emit complete to signal stream end + stream.emit('data', { + type: 'complete' + }) + return + } + errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj)) const errorMessage = errorChunks.join('\n\n') logger.error('SDK query failed', { diff --git a/src/main/services/agents/services/claudecode/transform.ts b/src/main/services/agents/services/claudecode/transform.ts index fa0c615648..094076e500 100644 --- a/src/main/services/agents/services/claudecode/transform.ts +++ b/src/main/services/agents/services/claudecode/transform.ts @@ -121,7 +121,7 @@ export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: case 'system': return handleSystemMessage(sdkMessage) case 'result': - return handleResultMessage(sdkMessage) + return handleResultMessage(sdkMessage, state) default: logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type }) return [] @@ -707,7 +707,13 @@ function handleSystemMessage(message: Extract): * Successful runs yield a `finish` frame with aggregated usage metrics, while * failures are surfaced as `error` frames. 
*/
-function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
+function handleResultMessage(
+  message: Extract<SDKMessage, { type: 'result' }>,
+  state: ClaudeStreamState
+): AgentStreamPart[] {
+  // Mark stream as finished to prevent duplicate error events when SDK process exits
+  state.markFinished()
+
   const chunks: AgentStreamPart[] = []
 
   let usage: LanguageModelUsage | undefined
@@ -719,26 +725,33 @@ function handleResultMessage(message: Extract):
     }
   }
 
-  if (message.subtype === 'success') {
-    chunks.push({
-      type: 'finish',
-      totalUsage: usage ?? emptyUsage,
-      finishReason: mapClaudeCodeFinishReason(message.subtype),
-      providerMetadata: {
-        ...sdkMessageToProviderMetadata(message),
-        usage: message.usage,
-        durationMs: message.duration_ms,
-        costUsd: message.total_cost_usd,
-        raw: message
-      }
-    } as AgentStreamPart)
-  } else {
+  chunks.push({
+    type: 'finish',
+    totalUsage: usage ?? emptyUsage,
+    finishReason: mapClaudeCodeFinishReason(message.subtype),
+    providerMetadata: {
+      ...sdkMessageToProviderMetadata(message),
+      usage: message.usage,
+      durationMs: message.duration_ms,
+      costUsd: message.total_cost_usd,
+      raw: message
+    }
+  } as AgentStreamPart)
+  if (message.subtype !== 'success') {
     chunks.push({
       type: 'error',
       error: {
         message: `${message.subtype}: Process failed after ${message.num_turns} turns`
       }
     } as AgentStreamPart)
+  } else {
+    if (message.is_error) {
+      const errorMatch = message.result.match(/\{.*\}/)
+      if (errorMatch) {
+        const errorDetail = JSON.parse(errorMatch[0])
+        chunks.push(errorDetail)
+      }
+    }
   }
 
   return chunks
 }

From 9d34098a5342b6f350f45695cf1a9363fad6f139 Mon Sep 17 00:00:00 2001
From: suyao
Date: Fri, 28 Nov 2025 13:36:29 +0800
Subject: [PATCH 17/53] feat: enhance provider configuration and error handling
 for AI SDK integration

---
 packages/shared/provider/constant.ts          | 26 +++++++
 packages/shared/provider/sdk-config.ts        | 12 +++-
 src/main/apiServer/services/messages.ts       | 37 +++++++++-
 .../apiServer/services/unified-messages.ts    | 70 ++++++++++++++++++-
 src/renderer/src/aiCore/provider/constants.ts | 26 +------
 5 files changed, 142 insertions(+), 29 deletions(-)
 create mode 100644 packages/shared/provider/constant.ts

diff --git a/packages/shared/provider/constant.ts b/packages/shared/provider/constant.ts
new file mode 100644
index 0000000000..fe47d6dcce
--- /dev/null
+++ b/packages/shared/provider/constant.ts
@@ -0,0 +1,26 @@
+import { getLowerBaseModelName } from '@shared/utils/naming'
+
+import type { MinimalModel } from './types'
+
+export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
+export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
+export const COPILOT_INTEGRATION_ID = 'vscode-chat'
+export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
+
+export const COPILOT_DEFAULT_HEADERS = {
+  'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
+  'User-Agent': COPILOT_USER_AGENT,
+  'Editor-Version': COPILOT_EDITOR_VERSION,
+  'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
+  'editor-version': COPILOT_EDITOR_VERSION,
+  'editor-plugin-version': COPILOT_PLUGIN_VERSION,
+  'copilot-vision-request': 'true'
+} as const
+
+// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
+const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex']
+
+export function isCopilotResponsesModel<M extends MinimalModel>(model: M): boolean {
+  const normalizedId = getLowerBaseModelName(model.id)
+  return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target)
+}
diff --git a/packages/shared/provider/sdk-config.ts b/packages/shared/provider/sdk-config.ts
index
a03b3b1417..e520cb6350 100644
--- a/packages/shared/provider/sdk-config.ts
+++ b/packages/shared/provider/sdk-config.ts
@@ -127,7 +127,7 @@ export function providerToAiSdkConfig(
   if (provider.id === SystemProviderIds.copilot) {
     const defaultHeaders = context.getCopilotDefaultHeaders?.() ?? {}
     const storedHeaders = context.getCopilotStoredHeaders?.() ?? {}
-    const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
+    const copilotExtraOptions: Record<string, unknown> = {
       headers: {
         ...defaultHeaders,
         ...storedHeaders,
@@ -135,7 +135,15 @@
       },
       name: provider.id,
       includeUsage: true
-    })
+    }
+    if (context.fetch) {
+      copilotExtraOptions.fetch = context.fetch
+    }
+    const options = ProviderConfigFactory.fromProvider(
+      'github-copilot-openai-compatible',
+      baseConfig,
+      copilotExtraOptions
+    )
 
     return {
       providerId: 'github-copilot-openai-compatible',
diff --git a/src/main/apiServer/services/messages.ts b/src/main/apiServer/services/messages.ts
index 3277378266..e2c9ad24e2 100644
--- a/src/main/apiServer/services/messages.ts
+++ b/src/main/apiServer/services/messages.ts
@@ -4,7 +4,7 @@ import { loggerService } from '@logger'
 import anthropicService from '@main/services/AnthropicService'
 import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
 import type { Provider } from '@types'
-import { APICallError } from 'ai'
+import { APICallError, RetryError } from 'ai'
 import { net } from 'electron'
 import type { Response } from 'express'
 
@@ -267,6 +267,41 @@ export class MessagesService {
       500: 'internal_server_error'
     }
 
+    // Handle AI SDK RetryError - extract the last error for better error messages
+    if (RetryError.isInstance(error)) {
+      const lastError = error.lastError
+      // If the last error is an APICallError, extract its details
+      if (APICallError.isInstance(lastError)) {
+        statusCode = lastError.statusCode || 502
+        errorMessage = lastError.message
+        return {
+          statusCode,
+          errorResponse: {
+            type: 'error',
+            error: {
+              type: errorMap[statusCode] || 'api_error',
+              message: `${error.reason}: ${errorMessage}`,
+              requestId: lastError.name
+            }
+          }
+        }
+      }
+      // Fallback for other retry errors
+      errorMessage = error.message
+      statusCode = 502
+      return {
+        statusCode,
+        errorResponse: {
+          type: 'error',
+          error: {
+            type: 'api_error',
+            message: errorMessage,
+            requestId: error.name
+          }
+        }
+      }
+    }
+
     if (APICallError.isInstance(error)) {
       statusCode = error.statusCode
       errorMessage = error.message
diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index 5cd59377f6..51751202dd 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -9,6 +9,8 @@ import type {
 import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
 import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
 import { loggerService } from '@logger'
+import anthropicService from '@main/services/AnthropicService'
+import copilotService from '@main/services/CopilotService'
 import { reduxService } from '@main/services/ReduxService'
@@ -21,6 +23,7 @@ import {
   providerToAiSdkConfig as sharedProviderToAiSdkConfig,
   resolveActualProvider
 } from '@shared/provider'
+import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
 import { defaultAppHeaders } from
'@shared/utils'
 import type { Provider } from '@types'
 import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
@@ -284,6 +287,68 @@ async function createAiSdkProvider(config: AiSdkConfig): Promise
   return provider
 }
 
+/**
+ * Prepare special provider configuration for providers that need dynamic tokens
+ * Similar to renderer's prepareSpecialProviderConfig
+ */
+async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
+  switch (provider.id) {
+    case 'copilot': {
+      const storedHeaders =
+        ((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
+      const headers: Record<string, string> = {
+        ...COPILOT_DEFAULT_HEADERS,
+        ...storedHeaders
+      }
+
+      try {
+        const { token } = await copilotService.getToken(null as any, headers)
+        config.options.apiKey = token
+        const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
+        config.options.headers = {
+          ...headers,
+          ...existingHeaders
+        }
+        logger.debug('Copilot token retrieved successfully')
+      } catch (error) {
+        logger.error('Failed to get Copilot token', error as Error)
+        throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
+      }
+      break
+    }
+    case 'anthropic': {
+      if (provider.authType === 'oauth') {
+        try {
+          const oauthToken = await anthropicService.getValidAccessToken()
+          if (!oauthToken) {
+            throw new Error('Anthropic OAuth token not available. Please re-authorize.')
+          }
+          config.options = {
+            ...config.options,
+            headers: {
+              ...(config.options.headers ? config.options.headers : {}),
+              'Content-Type': 'application/json',
+              'anthropic-version': '2023-06-01',
+              'anthropic-beta': 'oauth-2025-04-20',
+              Authorization: `Bearer ${oauthToken}`
+            },
+            baseURL: 'https://api.anthropic.com/v1',
+            apiKey: ''
+          }
+          logger.debug('Anthropic OAuth token retrieved successfully')
+        } catch (error) {
+          logger.error('Failed to get Anthropic OAuth token', error as Error)
+          throw new Error('Failed to get Anthropic OAuth token.
Please re-authorize.')
+        }
+      }
+      break
+    }
+    // Note: cherryai requires request-level signing which is not easily supported here
+    // It would need custom fetch implementation similar to renderer
+  }
+  return config
+}
+
 /**
  * Core stream execution function - single source of truth for AI SDK calls
  */
@@ -291,7 +356,10 @@ async function executeStream(config: ExecuteStreamConfig): Promise
      normalizedId === target || normalizedName === target)
-}
+export { COPILOT_DEFAULT_HEADERS, isCopilotResponsesModel } from '@shared/provider/constant'

From 534d27f37eff5c438534f2fb824e67843ee4da8e Mon Sep 17 00:00:00 2001
From: suyao
Date: Fri, 28 Nov 2025 13:37:43 +0800
Subject: [PATCH 18/53] feat: add additional model IDs for OpenAI Responses
 endpoint in Copilot

---
 packages/shared/provider/constant.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packages/shared/provider/constant.ts b/packages/shared/provider/constant.ts
index fe47d6dcce..c449c9f635 100644
--- a/packages/shared/provider/constant.ts
+++ b/packages/shared/provider/constant.ts
@@ -18,7 +18,7 @@ export const COPILOT_DEFAULT_HEADERS = {
 } as const
 
 // Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
-const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex']
+const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex', 'gpt-5.1-codex', 'gpt-5.1-codex-mini']
 
 export function isCopilotResponsesModel<M extends MinimalModel>(model: M): boolean {
   const normalizedId = getLowerBaseModelName(model.id)
   return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target)
 }

From 95c18d192a948297258d9e868140c44c3e157300 Mon Sep 17 00:00:00 2001
From: suyao
Date: Fri, 28 Nov 2025 13:42:33 +0800
Subject: [PATCH 19/53] feat: add reasoning cache support to AiSdkToAnthropicSSE
 and update unified-messages integration

---
 packages/shared/adapters/AiSdkToAnthropicSSE.ts | 17 +++++++++++++++--
 src/main/apiServer/services/unified-messages.ts |  3 ++-
 src/renderer/src/aiCore/provider/constants.ts   |  2 +-
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
index 9b23638f48..08d45a09d7 100644
--- a/packages/shared/adapters/AiSdkToAnthropicSSE.ts
+++ b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
@@ -36,7 +36,6 @@ import type {
   Usage
 } from '@anthropic-ai/sdk/resources/messages'
 import { loggerService } from '@logger'
-import { reasoningCache } from '@main/apiServer/services/cache'
 import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
 
 const logger = loggerService.withContext('AiSdkToAnthropicSSE')
@@ -71,11 +70,22 @@ interface AdapterState {
 
 export type SSEEventCallback = (event: RawMessageStreamEvent) => void
 
+/**
+ * Interface for a simple cache that stores reasoning details
+ */
+export interface ReasoningCacheInterface {
+  set(signature: string, details: unknown[]): void
+}
+
 export interface AiSdkToAnthropicSSEOptions {
   model: string
   messageId?: string
   inputTokens?: number
   onEvent: SSEEventCallback
+  /**
+   * Optional cache for storing reasoning details from providers like OpenRouter
+   */
+  reasoningCache?: ReasoningCacheInterface
 }
 
 /**
@@ -84,9 +94,11 @@ export class AiSdkToAnthropicSSE {
   private state: AdapterState
   private onEvent: SSEEventCallback
+  private reasoningCache?: ReasoningCacheInterface
 
   constructor(options: AiSdkToAnthropicSSEOptions) {
     this.onEvent = options.onEvent
+    this.reasoningCache = options.reasoningCache
     this.state = {
       messageId: options.messageId ||
`msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`, model: options.model, @@ -194,10 +206,11 @@ export class AiSdkToAnthropicSSE { case 'finish-step': if ( + this.reasoningCache && chunk.providerMetadata?.openrouter?.reasoning_details && Array.isArray(chunk.providerMetadata.openrouter.reasoning_details) ) { - reasoningCache.set('openrouter', chunk.providerMetadata?.openrouter?.reasoning_details) + this.reasoningCache.set('openrouter', chunk.providerMetadata.openrouter.reasoning_details) } if (chunk.finishReason === 'tool-calls') { this.state.stopReason = 'tool_use' diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts index 51751202dd..298131460f 100644 --- a/src/main/apiServer/services/unified-messages.ts +++ b/src/main/apiServer/services/unified-messages.ts @@ -387,7 +387,8 @@ async function executeStream(config: ExecuteStreamConfig): Promise {}) + onEvent: onEvent || (() => {}), + reasoningCache }) // Execute stream - pass model object instead of string diff --git a/src/renderer/src/aiCore/provider/constants.ts b/src/renderer/src/aiCore/provider/constants.ts index 67cde7894d..57dad9fbc0 100644 --- a/src/renderer/src/aiCore/provider/constants.ts +++ b/src/renderer/src/aiCore/provider/constants.ts @@ -1 +1 @@ -export { COPILOT_DEFAULT_HEADERS, isCopilotResponsesModel } from '@shared/provider/constant' +export { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '@shared/provider/constant' From ed769ac4f7a96c4507e6d1002ef2e6a2e3c69517 Mon Sep 17 00:00:00 2001 From: suyao Date: Fri, 28 Nov 2025 15:48:50 +0800 Subject: [PATCH 20/53] feat: add CherryAI signed fetch wrapper and enhance tool conversion to Zod schema --- packages/shared/provider/sdk-config.ts | 15 +- .../apiServer/services/unified-messages.ts | 144 ++++++++++++++++-- 2 files changed, 147 insertions(+), 12 deletions(-) diff --git a/packages/shared/provider/sdk-config.ts b/packages/shared/provider/sdk-config.ts index e520cb6350..91b3c8d54e 100644 --- a/packages/shared/provider/sdk-config.ts +++ b/packages/shared/provider/sdk-config.ts @@ -88,6 +88,12 @@ export interface AiSdkConfigContext { * Renderer process: use browser fetch (default) */ fetch?: typeof globalThis.fetch + + /** + * Get CherryAI signed fetch wrapper + * Returns a fetch function that adds signature headers to requests + */ + getCherryAISignedFetch?: () => typeof globalThis.fetch } /** @@ -220,8 +226,13 @@ export function providerToAiSdkConfig( } } - // Inject custom fetch if provided - if (context.fetch) { + // Handle cherryai signed fetch + if (provider.id === 'cherryai') { + const signedFetch = context.getCherryAISignedFetch?.() + if (signedFetch) { + extraOptions.fetch = signedFetch + } + } else if (context.fetch) { extraOptions.fetch = context.fetch } diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts index 298131460f..063885d72c 100644 --- a/src/main/apiServer/services/unified-messages.ts +++ b/src/main/apiServer/services/unified-messages.ts @@ -9,6 +9,7 @@ import type { import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core' import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' +import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai' import anthropicService from '@main/services/AnthropicService' import copilotService from '@main/services/CopilotService' import { reduxService } 
from '@main/services/ReduxService'
@@ -26,10 +27,11 @@ import {
 import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
 import { defaultAppHeaders } from '@shared/utils'
 import type { Provider } from '@types'
-import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
-import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai'
+import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai'
+import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai'
 import { net } from 'electron'
 import type { Response } from 'express'
+import * as z from 'zod'
 
 import { reasoningCache } from './cache'
 
@@ -124,19 +126,119 @@ function convertAnthropicToolResultToAiSdk(
   return { type: 'content', value: values }
 }
 
-function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, Tool> | undefined {
+// Type alias for JSON Schema (compatible with recursive calls)
+type JsonSchemaLike = AnthropicTool.InputSchema | Record<string, unknown>
+
+/**
+ * Convert JSON Schema to Zod schema
+ * This avoids non-standard fields like input_examples that Anthropic doesn't support
+ */
+function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
+  const s = schema as Record<string, unknown>
+  const schemaType = s.type as string | string[] | undefined
+  const enumValues = s.enum as unknown[] | undefined
+  const description = s.description as string | undefined
+
+  // Handle enum first
+  if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
+    if (enumValues.every((v) => typeof v === 'string')) {
+      const zodEnum = z.enum(enumValues as [string, ...string[]])
+      return description ? zodEnum.describe(description) : zodEnum
+    }
+    // For non-string enums, use union of literals
+    const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
+    if (literals.length === 1) {
+      return description ? literals[0].describe(description) : literals[0]
+    }
+    const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
+    return description ? zodUnion.describe(description) : zodUnion
+  }
+
+  // Handle union types (type: ["string", "null"])
+  if (Array.isArray(schemaType)) {
+    const schemas = schemaType.map((t) => jsonSchemaToZod({ ...s, type: t, enum: undefined }))
+    if (schemas.length === 1) {
+      return schemas[0]
+    }
+    return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
+  }
+
+  // Handle by type
+  switch (schemaType) {
+    case 'string': {
+      let zodString = z.string()
+      if (typeof s.minLength === 'number') zodString = zodString.min(s.minLength)
+      if (typeof s.maxLength === 'number') zodString = zodString.max(s.maxLength)
+      if (typeof s.pattern === 'string') zodString = zodString.regex(new RegExp(s.pattern))
+      return description ? zodString.describe(description) : zodString
+    }
+
+    case 'number':
+    case 'integer': {
+      let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
+      if (typeof s.minimum === 'number') zodNumber = zodNumber.min(s.minimum)
+      if (typeof s.maximum === 'number') zodNumber = zodNumber.max(s.maximum)
+      return description ? zodNumber.describe(description) : zodNumber
+    }
+
+    case 'boolean': {
+      const zodBoolean = z.boolean()
+      return description ? zodBoolean.describe(description) : zodBoolean
+    }
+
+    case 'null':
+      return z.null()
+
+    case 'array': {
+      const items = s.items as Record<string, unknown> | undefined
+      let zodArray = items ?
z.array(jsonSchemaToZod(items)) : z.array(z.unknown())
+      if (typeof s.minItems === 'number') zodArray = zodArray.min(s.minItems)
+      if (typeof s.maxItems === 'number') zodArray = zodArray.max(s.maxItems)
+      return description ? zodArray.describe(description) : zodArray
+    }
+
+    case 'object': {
+      const properties = s.properties as Record<string, Record<string, unknown>> | undefined
+      const required = (s.required as string[]) || []
+
+      // Always use z.object() to ensure "properties" field is present in output schema
+      // OpenAI requires explicit properties field even for empty objects
+      const shape: Record<string, z.ZodTypeAny> = {}
+      if (properties) {
+        for (const [key, propSchema] of Object.entries(properties)) {
+          const zodProp = jsonSchemaToZod(propSchema)
+          shape[key] = required.includes(key) ? zodProp : zodProp.optional()
+        }
+      }
+
+      const zodObject = z.object(shape)
+      return description ? zodObject.describe(description) : zodObject
+    }
+
+    default:
+      // Unknown type, use z.unknown()
+      return z.unknown()
+  }
+}
+
+function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, AiSdkTool> | undefined {
   if (!tools || tools.length === 0) return undefined
 
-  const aiSdkTools: Record<string, Tool> = {}
+  const aiSdkTools: Record<string, AiSdkTool> = {}
 
   for (const anthropicTool of tools) {
     if (anthropicTool.type === 'bash_20250124') continue
 
     const toolDef = anthropicTool as AnthropicTool
-    const parameters = toolDef.input_schema as Parameters<typeof jsonSchema>[0]
-    aiSdkTools[toolDef.name] = tool({
+    const rawSchema = toolDef.input_schema
+    const schema = jsonSchemaToZod(rawSchema)
+
+    // Use tool() with inputSchema (AI SDK v5 API)
+    const aiTool = tool({
       description: toolDef.description || '',
-      inputSchema: jsonSchema(parameters),
-      execute: async (input: Record<string, unknown>) => input
+      inputSchema: zodSchema(schema)
     })
+
+    logger.debug('Converted Anthropic tool to AI SDK tool', aiTool)
+    aiSdkTools[toolDef.name] = aiTool
   }
   return Object.keys(aiSdkTools).length > 0 ?
aiSdkTools : undefined
 }
@@ -343,8 +445,30 @@ async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkCon
       }
       break
     }
-    // Note: cherryai requires request-level signing which is not easily supported here
-    // It would need custom fetch implementation similar to renderer
+    case 'cherryai': {
+      // Create a signed fetch wrapper for cherryai
+      const baseFetch = net.fetch as typeof globalThis.fetch
+      config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
+        if (!options?.body) {
+          return baseFetch(url, options)
+        }
+        const signature = cherryaiGenerateSignature({
+          method: 'POST',
+          path: '/chat/completions',
+          query: '',
+          body: JSON.parse(options.body as string)
+        })
+        return baseFetch(url, {
+          ...options,
+          headers: {
+            ...(options.headers as Record<string, string>),
+            ...signature
+          }
+        })
+      }
+      logger.debug('CherryAI signed fetch configured')
+      break
+    }
   }
   return config
 }

From e8dccf51feab167102e9358e48d2067d2a722a22 Mon Sep 17 00:00:00 2001
From: suyao
Date: Fri, 28 Nov 2025 16:37:58 +0800
Subject: [PATCH 21/53] feat: enhance reasoning cache integration and update
 provider options in unified messages

---
 .../shared/adapters/AiSdkToAnthropicSSE.ts    | 26 +++++++++++--------
 src/main/apiServer/services/cache.ts          | 21 +++++-----------
 .../apiServer/services/unified-messages.ts    | 15 ++++++++++-
 3 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
index 08d45a09d7..a9f1508a6a 100644
--- a/packages/shared/adapters/AiSdkToAnthropicSSE.ts
+++ b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
@@ -36,6 +36,8 @@ import type {
   Usage
 } from '@anthropic-ai/sdk/resources/messages'
 import { loggerService } from '@logger'
+import { reasoningCache } from '@main/apiServer/services/cache'
+import type { JSONValue } from 'ai'
 import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
 
 const logger = loggerService.withContext('AiSdkToAnthropicSSE')
@@ -74,7 +76,7 @@ export type SSEEventCallback = (event: RawMessageStreamEvent) => void
  * Interface for a simple cache that stores reasoning details
  */
 export interface ReasoningCacheInterface {
-  set(signature: string, details: unknown[]): void
+  set(signature: string, details: JSONValue): void
 }
 
 export interface AiSdkToAnthropicSSEOptions {
@@ -82,9 +84,6 @@ export interface AiSdkToAnthropicSSEOptions {
   messageId?: string
   inputTokens?: number
   onEvent: SSEEventCallback
-  /**
-   * Optional cache for storing reasoning details from providers like OpenRouter
-   */
   reasoningCache?: ReasoningCacheInterface
 }
 
@@ -186,6 +185,17 @@ export class AiSdkToAnthropicSSE {
     // === Tool Events ===
 
     case 'tool-call':
+        if (this.reasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
+          this.reasoningCache.set('google', chunk.providerMetadata?.google?.thoughtSignature)
+        }
+        // FIXME: bind by tool call id
+        if (
+          this.reasoningCache &&
+          chunk.providerMetadata?.openrouter?.reasoning_details &&
+          Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
+        ) {
+          this.reasoningCache.set('openrouter', chunk.providerMetadata.openrouter.reasoning_details)
+        }
         this.handleToolCall({
           type: 'tool-call',
           toolCallId: chunk.toolCallId,
@@ -205,13 +215,6 @@ export class AiSdkToAnthropicSSE {
         break
 
       case 'finish-step':
-        if (
-          this.reasoningCache &&
-          chunk.providerMetadata?.openrouter?.reasoning_details &&
-          Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
-        ) {
-
this.reasoningCache.set('openrouter', chunk.providerMetadata.openrouter.reasoning_details)
-        }
       if (chunk.finishReason === 'tool-calls') {
         this.state.stopReason = 'tool_use'
       }
@@ -552,6 +555,7 @@ export class AiSdkToAnthropicSSE {
     }
 
     this.onEvent(messageStopEvent)
+    reasoningCache.destroy()
   }
 
   /**
diff --git a/src/main/apiServer/services/cache.ts b/src/main/apiServer/services/cache.ts
index 765ab1e1b9..39dc5b1544 100644
--- a/src/main/apiServer/services/cache.ts
+++ b/src/main/apiServer/services/cache.ts
@@ -1,11 +1,12 @@
 import { loggerService } from '@logger'
+import type { JSONValue } from 'ai'
 
 const logger = loggerService.withContext('Cache')
 /**
  * Cache entry with TTL support
  */
 interface CacheEntry<T> {
-  details: T[]
+  details: T
   timestamp: number
 }
 
@@ -28,24 +29,19 @@ export class ReasoningCache {
   /**
    * Store reasoning details by signature
    */
-  set(signature: string, details: T[]): void {
-    if (!signature || !details.length) return
+  set(signature: string, details: T): void {
+    if (!signature || !details) return
 
     this.cache.set(signature, {
       details,
       timestamp: Date.now()
     })
-
-    logger.debug('Cached reasoning details', {
-      signature: signature.substring(0, 20) + '...',
-      detailsCount: details.length
-    })
   }
 
   /**
    * Retrieve reasoning details by signature
    */
-  get(signature: string): T[] | undefined {
+  get(signature: string): T | undefined {
     const entry = this.cache.get(signature)
     if (!entry) return undefined
 
@@ -55,11 +51,6 @@ export class ReasoningCache {
       return undefined
     }
 
-    logger.debug('Retrieved reasoning details from cache', {
-      signature: signature.substring(0, 20) + '...',
-      detailsCount: entry.details.length
-    })
-
     return entry.details
   }
 
@@ -113,4 +104,4 @@ export class ReasoningCache {
 }
 
 // Singleton cache instance
-export const reasoningCache = new ReasoningCache()
+export const reasoningCache = new ReasoningCache<JSONValue>()
diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index 063885d72c..af97941f2b 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -301,11 +301,24 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
           imageParts.push({ type: 'image', image: source.url })
         }
       } else if (block.type === 'tool_use') {
+        const options: ProviderOptions = {}
+        if (isGemini3ModelId(params.model)) {
+          if (reasoningCache.get('google')) {
+            options.google = {
+              thoughtSignature: MAGIC_STRING
+            }
+          } else if (reasoningCache.get('openrouter')) {
+            options.openrouter = {
+              reasoning_details: (reasoningCache.get('openrouter') as JSONValue[]) || []
+            }
+          }
+        }
         toolCallParts.push({
           type: 'tool-call',
           toolName: block.name,
           toolCallId: block.id,
-          input: block.input
+          input: block.input,
+          providerOptions: options
         })
       } else if (block.type === 'tool_result') {
         // Look up toolName from the pre-built map (covers cross-message references)

From e255a992cccd1edc0ec11fce81c7272528ed9b07 Mon Sep 17 00:00:00 2001
From: suyao
Date: Sat, 29 Nov 2025 17:39:31 +0800
Subject: [PATCH 22/53] fix: type check

---
 packages/shared/adapters/AiSdkToAnthropicSSE.ts | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
index a9f1508a6a..c6e7555ea3 100644
--- a/packages/shared/adapters/AiSdkToAnthropicSSE.ts
+++ b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
@@ -36,7 +36,6 @@ import type {
   Usage
 } from '@anthropic-ai/sdk/resources/messages'
 import {
loggerService } from '@logger'
-import { reasoningCache } from '@main/apiServer/services/cache'
 import type { JSONValue } from 'ai'
 import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
 
@@ -77,6 +76,7 @@ export type SSEEventCallback = (event: RawMessageStreamEvent) => void
  */
 export interface ReasoningCacheInterface {
   set(signature: string, details: JSONValue): void
+  destroy?(): void
 }
 
 export interface AiSdkToAnthropicSSEOptions {
@@ -555,7 +555,7 @@ export class AiSdkToAnthropicSSE {
     }
 
     this.onEvent(messageStopEvent)
-    reasoningCache.destroy()
+    this.reasoningCache?.destroy?.()
   }
 
   /**

From 35cfc7c517ab9fa3de085bdb6a5f38282716919e Mon Sep 17 00:00:00 2001
From: suyao
Date: Sat, 29 Nov 2025 19:12:56 +0800
Subject: [PATCH 23/53] feat: add sanitizeToolsForAnthropic function to clean
 tool definitions for Anthropic API

---
 packages/shared/anthropic/index.ts      | 30 ++++++++++++++++++++++++-
 src/main/apiServer/services/messages.ts |  5 +++--
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/packages/shared/anthropic/index.ts b/packages/shared/anthropic/index.ts
index 2444ad6113..e2113eb749 100644
--- a/packages/shared/anthropic/index.ts
+++ b/packages/shared/anthropic/index.ts
@@ -9,7 +9,7 @@
  */
 
 import Anthropic from '@anthropic-ai/sdk'
-import type { TextBlockParam } from '@anthropic-ai/sdk/resources'
+import type { MessageCreateParams, TextBlockParam, Tool as AnthropicTool } from '@anthropic-ai/sdk/resources'
 import { loggerService } from '@logger'
 import type { Provider } from '@types'
 import type { ModelMessage } from 'ai'
@@ -193,3 +193,31 @@ export function buildClaudeCodeSystemModelMessage(system?: string | Array<TextBlockParam>
+
+/**
+ * Sanitize tool definitions for the Anthropic API by stripping non-standard
+ * fields (such as input_examples) that the API does not accept.
+ */
+export function sanitizeToolsForAnthropic(tools: MessageCreateParams['tools']): MessageCreateParams['tools'] {
+  if (!tools) return tools
+
+  return tools.map((tool) => {
+    if ('type' in tool && tool.type !== 'custom') return tool
+
+    // oxlint-disable-next-line no-unused-vars
+    const { input_examples, ...sanitizedTool } = tool as AnthropicTool & { input_examples?: unknown }
+
+    return sanitizedTool as typeof tool
+  })
+}
diff --git a/src/main/apiServer/services/messages.ts b/src/main/apiServer/services/messages.ts
index e2c9ad24e2..957c066520 100644
--- a/src/main/apiServer/services/messages.ts
+++ b/src/main/apiServer/services/messages.ts
@@ -2,7 +2,7 @@ import type Anthropic from '@anthropic-ai/sdk'
 import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources'
 import { loggerService } from '@logger'
 import anthropicService from '@main/services/AnthropicService'
-import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
+import { buildClaudeCodeSystemMessage, getSdkClient, sanitizeToolsForAnthropic } from '@shared/anthropic'
 import type { Provider } from '@types'
 import { APICallError, RetryError } from 'ai'
 import { net } from 'electron'
@@ -148,7 +148,8 @@ export class MessagesService {
   createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams {
     const anthropicRequest: MessageCreateParams = {
       ...request,
-      stream: !!request.stream
+      stream: !!request.stream,
+      tools: sanitizeToolsForAnthropic(request.tools)
     }
 
     // Override model if provided

From 3989229f611268a129d3c9e26cbb01b637a12abc Mon Sep 17 00:00:00 2001
From: suyao
Date: Sun, 30 Nov 2025 07:10:10 +0800
Subject: [PATCH 24/53] feat: enhance API version handling and cache
 functionality

- Updated reasoning cache to use tool-specific keys for better organization.
- Added methods to list cache keys and entries.
- Improved API version regex patterns for more accurate matching.
- Refactored API host formatting to handle leading/trailing whitespace and slashes.
- Added functions to extract and remove trailing API version segments from URLs.
---
 .../shared/adapters/AiSdkToAnthropicSSE.ts    |  3 +-
 packages/shared/api/index.ts                  | 86 +++++++++++++++++--
 src/main/apiServer/services/cache.ts          | 12 +++
 .../apiServer/services/unified-messages.ts    | 18 +---
 4 files changed, 92 insertions(+), 27 deletions(-)

diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
index c6e7555ea3..c4f2355c0a 100644
--- a/packages/shared/adapters/AiSdkToAnthropicSSE.ts
+++ b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
@@ -186,7 +186,7 @@ export class AiSdkToAnthropicSSE {
     // === Tool Events ===
     case 'tool-call':
       if (this.reasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
-        this.reasoningCache.set('google', chunk.providerMetadata?.google?.thoughtSignature)
+        this.reasoningCache.set(`google-${chunk.toolName}`, chunk.providerMetadata?.google?.thoughtSignature)
       }
       // FIXME: bind by tool call id
       if (
@@ -555,7 +555,6 @@ export class AiSdkToAnthropicSSE {
     }
 
     this.onEvent(messageStopEvent)
-    this.reasoningCache?.destroy?.()
   }
 
   /**
diff --git a/packages/shared/api/index.ts b/packages/shared/api/index.ts
index 5ee19611d8..2e85c11c36 100644
--- a/packages/shared/api/index.ts
+++ b/packages/shared/api/index.ts
@@ -27,18 +27,35 @@ export function withoutTrailingSlash(url: T): T {
 }
 
 /**
- * Checks if the host path contains a version string (e.g., /v1, /v2beta).
+ * Matches a version segment in a path that starts with `/v` and optionally
+ * continues with `alpha` or `beta`. The segment may be followed by `/` or the end
+ * of the string (useful for cases like `/v3alpha/resources`).
+ */
+const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
+
+/**
+ * Matches an API version at the end of a URL (with optional trailing slash).
+ * Used to detect and extract versions only from the trailing position.
+ */
+const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
+
+/**
+ * Checks whether the host's path contains a version-like segment (e.g. /v1, /v2beta).
+ *
+ * @param host - The host or path string to check
+ * @returns true if the path contains a version segment, false otherwise
 */
 export function hasAPIVersion(host?: string): boolean {
   if (!host) return false
-  const versionRegex = /\/v\d+(?:alpha|beta)?(?=\/|$)/i
+  const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
 
   try {
     const url = new URL(host)
-    return versionRegex.test(url.pathname)
+    return regex.test(url.pathname)
   } catch {
-    return versionRegex.test(host)
+    // If the host cannot be parsed as a full URL, test it directly as a path
+    return regex.test(host)
   }
 }
@@ -71,22 +88,26 @@ export function formatVertexApiHost(
 /**
  * Formats an API host URL by normalizing it and optionally appending an API version.
  *
- * @param host - The API host URL to format
- * @param isSupportedAPIVersion - Whether the API version is supported. Defaults to `true`.
+ * @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
+ * @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
  * @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
 *
+ * @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
+ * If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
+ * Otherwise, returns the host with the API version appended.
+ * * @example * formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1' * formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#' * formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2' */ -export function formatApiHost(host?: string, isSupportedAPIVersion: boolean = true, apiVersion: string = 'v1'): string { - const normalizedHost = withoutTrailingSlash((host || '').trim()) +export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string { + const normalizedHost = withoutTrailingSlash(trim(host)) if (!normalizedHost) { return '' } - if (normalizedHost.endsWith('#') || !isSupportedAPIVersion || hasAPIVersion(normalizedHost)) { + if (normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) { return normalizedHost } return `${normalizedHost}/${apiVersion}` @@ -175,3 +196,50 @@ export function validateApiHost(apiHost: string): boolean { return false } } + +/** + * Extracts the trailing API version segment from a URL path. + * + * This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL. + * Only versions at the end of the path are extracted, not versions in the middle. + * The returned version string does not include leading or trailing slashes. + * + * @param {string} url - The URL string to parse. + * @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found. + * + * @example + * getTrailingApiVersion('https://api.example.com/v1') // 'v1' + * getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta' + * getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end) + * getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta' + * getTrailingApiVersion('https://api.example.com') // undefined + */ +export function getTrailingApiVersion(url: string): string | undefined { + const match = url.match(TRAILING_VERSION_REGEX) + + if (match) { + // Extract version without leading slash and trailing slash + return match[0].replace(/^\//, '').replace(/\/$/, '') + } + + return undefined +} + +/** + * Removes the trailing API version segment from a URL path. + * + * This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL. + * Only versions at the end of the path are removed, not versions in the middle. + * + * @param {string} url - The URL string to process. + * @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found. 
+ *
+ * @example
+ * withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
+ * withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
+ * withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
+ * withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
+ */
+export function withoutTrailingApiVersion(url: string): string {
+  return url.replace(TRAILING_VERSION_REGEX, '')
+}
diff --git a/src/main/apiServer/services/cache.ts b/src/main/apiServer/services/cache.ts
index 39dc5b1544..9515778e16 100644
--- a/src/main/apiServer/services/cache.ts
+++ b/src/main/apiServer/services/cache.ts
@@ -54,6 +54,18 @@ export class ReasoningCache {
     return entry.details
   }
 
+  listKeys(): string[] {
+    return Array.from(this.cache.keys())
+  }
+
+  listEntries(): Array<{ key: string; entry: CacheEntry<T> }> {
+    const entries: Array<{ key: string; entry: CacheEntry<T> }> = []
+    for (const [key, entry] of this.cache.entries()) {
+      entries.push({ key, entry })
+    }
+    return entries
+  }
+
   /**
    * Clear expired entries
    */
diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index af97941f2b..815b1217f2 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -237,7 +237,6 @@ function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Reco
       inputSchema: zodSchema(schema)
     })
 
-    logger.debug('Converted Anthropic tool to AI SDK tool', aiTool)
     aiSdkTools[toolDef.name] = aiTool
   }
   return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
@@ -302,8 +301,9 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
         }
       } else if (block.type === 'tool_use') {
         const options: ProviderOptions = {}
+
         if (isGemini3ModelId(params.model)) {
-          if (reasoningCache.get('google')) {
+          if (reasoningCache.get(`google-${block.name}`)) {
             options.google = {
               thoughtSignature: MAGIC_STRING
             }
@@ -394,11 +394,6 @@ async function createAiSdkProvider(config: AiSdkConfig): Promise
   const provider = await createProviderCore(providerId, config.options)
 
-  logger.debug('AI SDK provider created', {
-    providerId,
-    hasOptions: !!config.options
-  })
-
   return provider
 }
@@ -424,7 +419,6 @@ async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkCon
           ...headers,
           ...existingHeaders
         }
-        logger.debug('Copilot token retrieved successfully')
       } catch (error) {
         logger.error('Failed to get Copilot token', error as Error)
        throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
@@ -450,7 +444,6 @@ async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkCon
             baseURL: 'https://api.anthropic.com/v1',
             apiKey: ''
           }
-          logger.debug('Anthropic OAuth token retrieved successfully')
         } catch (error) {
           logger.error('Failed to get Anthropic OAuth token', error as Error)
           throw new Error('Failed to get Anthropic OAuth token.
Please re-authorize.') @@ -479,7 +472,6 @@ async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkCon } }) } - logger.debug('CherryAI signed fetch configured') break } } @@ -498,12 +490,6 @@ async function executeStream(config: ExecuteStreamConfig): Promise Date: Sun, 30 Nov 2025 18:35:58 +0800 Subject: [PATCH 25/53] fix: update 'anthropic-beta' header and add authorization for longcat provider --- packages/shared/anthropic/index.ts | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/shared/anthropic/index.ts b/packages/shared/anthropic/index.ts index e2113eb749..78df7ff7af 100644 --- a/packages/shared/anthropic/index.ts +++ b/packages/shared/anthropic/index.ts @@ -11,7 +11,7 @@ import Anthropic from '@anthropic-ai/sdk' import type { MessageCreateParams, TextBlockParam, Tool as AnthropicTool } from '@anthropic-ai/sdk/resources' import { loggerService } from '@logger' -import type { Provider } from '@types' +import { type Provider, SystemProviderIds } from '@types' import type { ModelMessage } from 'ai' const logger = loggerService.withContext('anthropic-sdk') @@ -124,7 +124,7 @@ export function getSdkClient( baseURL, dangerouslyAllowBrowser: true, defaultHeaders: { - 'anthropic-beta': 'output-128k-2025-02-19', + 'anthropic-beta': 'interleaved-thinking-2025-05-14', 'APP-Code': 'MLTG2087', ...provider.extra_headers, ...extraHeaders @@ -139,7 +139,8 @@ export function getSdkClient( baseURL, dangerouslyAllowBrowser: true, defaultHeaders: { - 'anthropic-beta': 'output-128k-2025-02-19', + 'anthropic-beta': 'interleaved-thinking-2025-05-14', + Authorization: provider.id === SystemProviderIds.longcat ? `Bearer ${provider.apiKey}` : undefined, ...provider.extra_headers }, fetch: customFetch From 4a913fcef7fe2a559fb7c70a82e2dcf228c5961a Mon Sep 17 00:00:00 2001 From: suyao Date: Mon, 1 Dec 2025 00:36:55 +0800 Subject: [PATCH 26/53] Refactor: Remove old cache implementation and integrate new reasoning cache service - Deleted the old ReasoningCache class and its instance. - Introduced CacheService for managing reasoning caches. - Updated unified-messages service to utilize new googleReasoningCache and openRouterReasoningCache. - Added AiSdkToAnthropicSSE adapter to handle streaming events and integrate with new cache service. - Reorganized shared adapters to include the new AiSdkToAnthropicSSE adapter. - Created openrouter adapter with detailed reasoning schemas for better type safety and validation. 
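
A minimal usage sketch of the new cache wiring (the import path assumes the `@main` alias used elsewhere in this repo, and the key strings here are illustrative, not fixed API contracts):

    import { googleReasoningCache, openRouterReasoningCache } from '@main/services/CacheService'

    // Store the thought signature emitted for a tool call; CacheService applies a 30-minute TTL
    googleReasoningCache.set('google-get_weather', 'thought-signature-token')

    // On a follow-up request, read the cached values back when rebuilding provider options
    const signature = googleReasoningCache.get('google-get_weather')
    const reasoningDetails = openRouterReasoningCache.get('openrouter') ?? []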
---
 .../adapters/AiSdkToAnthropicSSE.ts           | 28 ++---
 .../main/apiServer}/adapters/index.ts         |  0
 src/main/apiServer/adapters/openrouter.ts     | 95 ++++++++++++++
 src/main/apiServer/services/cache.ts          | 119 ------------------
 .../apiServer/services/unified-messages.ts    | 21 ++--
 src/main/services/CacheService.ts             | 31 +++++
 6 files changed, 150 insertions(+), 144 deletions(-)
 rename {packages/shared => src/main/apiServer}/adapters/AiSdkToAnthropicSSE.ts (95%)
 rename {packages/shared => src/main/apiServer}/adapters/index.ts (100%)
 create mode 100644 src/main/apiServer/adapters/openrouter.ts
 delete mode 100644 src/main/apiServer/services/cache.ts

diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts
similarity index 95%
rename from packages/shared/adapters/AiSdkToAnthropicSSE.ts
rename to src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts
index c4f2355c0a..b5b52c4e03 100644
--- a/packages/shared/adapters/AiSdkToAnthropicSSE.ts
+++ b/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts
@@ -36,9 +36,10 @@ import type {
   Usage
 } from '@anthropic-ai/sdk/resources/messages'
 import { loggerService } from '@logger'
-import type { JSONValue } from 'ai'
 import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
 
+import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
+
 const logger = loggerService.withContext('AiSdkToAnthropicSSE')
 
 interface ContentBlockState {
@@ -71,20 +72,11 @@ interface AdapterState {
 
 export type SSEEventCallback = (event: RawMessageStreamEvent) => void
 
-/**
- * Interface for a simple cache that stores reasoning details
- */
-export interface ReasoningCacheInterface {
-  set(signature: string, details: JSONValue): void
-  destroy?(): void
-}
-
 export interface AiSdkToAnthropicSSEOptions {
   model: string
   messageId?: string
   inputTokens?: number
   onEvent: SSEEventCallback
-  reasoningCache?: ReasoningCacheInterface
 }
 
 /**
@@ -93,11 +85,9 @@ export interface AiSdkToAnthropicSSEOptions {
 export class AiSdkToAnthropicSSE {
   private state: AdapterState
   private onEvent: SSEEventCallback
-  private reasoningCache?: ReasoningCacheInterface
 
   constructor(options: AiSdkToAnthropicSSEOptions) {
     this.onEvent = options.onEvent
-    this.reasoningCache = options.reasoningCache
     this.state = {
       messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
       model: options.model,
@@ -185,16 +175,22 @@ export class AiSdkToAnthropicSSE {
     // === Tool Events ===
 
     case 'tool-call':
-        if (this.reasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
-          this.reasoningCache.set(`google-${chunk.toolName}`, chunk.providerMetadata?.google?.thoughtSignature)
+        if (googleReasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
+          googleReasoningCache.set(
+            `google-${chunk.toolName}`,
+            chunk.providerMetadata?.google?.thoughtSignature as string
+          )
         }
         // FIXME: bind by tool call id
         if (
-          this.reasoningCache &&
+          openRouterReasoningCache &&
           chunk.providerMetadata?.openrouter?.reasoning_details &&
           Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
         ) {
-          this.reasoningCache.set('openrouter', chunk.providerMetadata.openrouter.reasoning_details)
+          openRouterReasoningCache.set(
+            'openrouter',
+            JSON.parse(JSON.stringify(chunk.providerMetadata.openrouter.reasoning_details))
+          )
         }
         this.handleToolCall({
           type: 'tool-call',
diff --git a/packages/shared/adapters/index.ts b/src/main/apiServer/adapters/index.ts
similarity index 100%
rename from
packages/shared/adapters/index.ts
rename to src/main/apiServer/adapters/index.ts
diff --git a/src/main/apiServer/adapters/openrouter.ts b/src/main/apiServer/adapters/openrouter.ts
new file mode 100644
index 0000000000..3b63191781
--- /dev/null
+++ b/src/main/apiServer/adapters/openrouter.ts
@@ -0,0 +1,95 @@
+import * as z from 'zod/v4'
+
+enum ReasoningFormat {
+  Unknown = 'unknown',
+  OpenAIResponsesV1 = 'openai-responses-v1',
+  XAIResponsesV1 = 'xai-responses-v1',
+  AnthropicClaudeV1 = 'anthropic-claude-v1',
+  GoogleGeminiV1 = 'google-gemini-v1'
+}
+
+// Anthropic Claude was the first reasoning that we're
+// passing back and forth
+export const DEFAULT_REASONING_FORMAT = ReasoningFormat.AnthropicClaudeV1
+
+function isDefinedOrNotNull<T>(value: T | null | undefined): value is T {
+  return value !== null && value !== undefined
+}
+
+export enum ReasoningDetailType {
+  Summary = 'reasoning.summary',
+  Encrypted = 'reasoning.encrypted',
+  Text = 'reasoning.text'
+}
+
+export const CommonReasoningDetailSchema = z
+  .object({
+    id: z.string().nullish(),
+    format: z.enum(ReasoningFormat).nullish(),
+    index: z.number().optional()
+  })
+  .loose()
+
+export const ReasoningDetailSummarySchema = z
+  .object({
+    type: z.literal(ReasoningDetailType.Summary),
+    summary: z.string()
+  })
+  .extend(CommonReasoningDetailSchema.shape)
+export type ReasoningDetailSummary = z.infer<typeof ReasoningDetailSummarySchema>
+
+export const ReasoningDetailEncryptedSchema = z
+  .object({
+    type: z.literal(ReasoningDetailType.Encrypted),
+    data: z.string()
+  })
+  .extend(CommonReasoningDetailSchema.shape)
+
+export type ReasoningDetailEncrypted = z.infer<typeof ReasoningDetailEncryptedSchema>
+
+export const ReasoningDetailTextSchema = z
+  .object({
+    type: z.literal(ReasoningDetailType.Text),
+    text: z.string().nullish(),
+    signature: z.string().nullish()
+  })
+  .extend(CommonReasoningDetailSchema.shape)
+
+export type ReasoningDetailText = z.infer<typeof ReasoningDetailTextSchema>
+
+export const ReasoningDetailUnionSchema = z.union([
+  ReasoningDetailSummarySchema,
+  ReasoningDetailEncryptedSchema,
+  ReasoningDetailTextSchema
+])
+
+export type ReasoningDetailUnion = z.infer<typeof ReasoningDetailUnionSchema>
+
+const ReasoningDetailsWithUnknownSchema = z.union([ReasoningDetailUnionSchema, z.unknown().transform(() => null)])
+
+export const ReasoningDetailArraySchema = z
+  .array(ReasoningDetailsWithUnknownSchema)
+  .transform((d) => d.filter((d): d is ReasoningDetailUnion => !!d))
+
+export const OutputUnionToReasoningDetailsSchema = z.union([
+  z
+    .object({
+      delta: z.object({
+        reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
+      })
+    })
+    .transform((data) => data.delta.reasoning_details.filter(isDefinedOrNotNull)),
+  z
+    .object({
+      message: z.object({
+        reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
+      })
+    })
+    .transform((data) => data.message.reasoning_details.filter(isDefinedOrNotNull)),
+  z
+    .object({
+      text: z.string(),
+      reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
+    })
+    .transform((data) => data.reasoning_details.filter(isDefinedOrNotNull))
+])
diff --git a/src/main/apiServer/services/cache.ts b/src/main/apiServer/services/cache.ts
deleted file mode 100644
index 9515778e16..0000000000
--- a/src/main/apiServer/services/cache.ts
+++ /dev/null
@@ -1,119 +0,0 @@
-import { loggerService } from '@logger'
-import type { JSONValue } from 'ai'
-
-const logger = loggerService.withContext('Cache')
-/**
- * Cache entry with TTL support
- */
-interface CacheEntry<T> {
-  details: T
-  timestamp: number
-}
-
-/**
- * In-memory cache for reasoning details
- * Key: signature
- * Value: reasoning array with timestamp
- */
diff --git a/src/main/apiServer/services/cache.ts b/src/main/apiServer/services/cache.ts
deleted file mode 100644
index 9515778e16..0000000000
--- a/src/main/apiServer/services/cache.ts
+++ /dev/null
@@ -1,119 +0,0 @@
-import { loggerService } from '@logger'
-import type { JSONValue } from 'ai'
-
-const logger = loggerService.withContext('Cache')
-/**
- * Cache entry with TTL support
- */
-interface CacheEntry<T> {
-  details: T
-  timestamp: number
-}
-
-/**
- * In-memory cache for reasoning details
- * Key: signature
- * Value: reasoning array with timestamp
- */
-export class ReasoningCache<T = JSONValue> {
-  private cache = new Map<string, CacheEntry<T>>()
-  private readonly ttlMs: number
-  private cleanupInterval: ReturnType<typeof setInterval> | null = null
-
-  constructor(ttlMs: number = 30 * 60 * 1000) {
-    // Default 30 minutes TTL
-    this.ttlMs = ttlMs
-    this.startCleanup()
-  }
-
-  /**
-   * Store reasoning details by signature
-   */
-  set(signature: string, details: T): void {
-    if (!signature || !details) return
-
-    this.cache.set(signature, {
-      details,
-      timestamp: Date.now()
-    })
-  }
-
-  /**
-   * Retrieve reasoning details by signature
-   */
-  get(signature: string): T | undefined {
-    const entry = this.cache.get(signature)
-    if (!entry) return undefined
-
-    // Check TTL
-    if (Date.now() - entry.timestamp > this.ttlMs) {
-      this.cache.delete(signature)
-      return undefined
-    }
-
-    return entry.details
-  }
-
-  listKeys(): string[] {
-    return Array.from(this.cache.keys())
-  }
-
-  listEntries(): Array<{ key: string; entry: CacheEntry<T> }> {
-    const entries: Array<{ key: string; entry: CacheEntry<T> }> = []
-    for (const [key, entry] of this.cache.entries()) {
-      entries.push({ key, entry })
-    }
-    return entries
-  }
-
-  /**
-   * Clear expired entries
-   */
-  cleanup(): void {
-    const now = Date.now()
-    let cleaned = 0
-
-    for (const [key, entry] of this.cache) {
-      if (now - entry.timestamp > this.ttlMs) {
-        this.cache.delete(key)
-        cleaned++
-      }
-    }
-
-    if (cleaned > 0) {
-      logger.debug('Cleaned up expired reasoning cache entries', { cleaned, remaining: this.cache.size })
-    }
-  }
-
-  /**
-   * Start periodic cleanup
-   */
-  private startCleanup(): void {
-    // Cleanup every 5 minutes
-    this.cleanupInterval = setInterval(() => this.cleanup(), 5 * 60 * 1000)
-  }
-
-  /**
-   * Stop cleanup and clear cache
-   */
-  destroy(): void {
-    if (this.cleanupInterval) {
-      clearInterval(this.cleanupInterval)
-      this.cleanupInterval = null
-    }
-    this.cache.clear()
-  }
-
-  /**
-   * Get cache stats for debugging
-   */
-  stats(): { size: number; ttlMs: number } {
-    return {
-      size: this.cache.size,
-      ttlMs: this.ttlMs
-    }
-  }
-}
-
-// Singleton cache instance
-export const reasoningCache = new ReasoningCache()
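The deleted ReasoningCache owned its own Map plus a 5-minute sweep timer; the hunks below fold the same 30-minute TTL semantics into the app-wide CacheService behind namespaced keys. A rough before/after sketch (identifiers are taken from this patch, the key and payload are made up, and imports are omitted since the "before" class no longer exists):

// Before: dedicated cache instance with its own sweep timer
const cache = new ReasoningCache<string>()
cache.set('sig_abc', 'thought-signature-payload')
cache.get('sig_abc') // hit until the 30-minute TTL lapses

// After: a thin adapter over CacheService (see the CacheService.ts hunks below);
// the entry is namespaced by key prefix instead of living in a separate Map
googleReasoningCache.set('sig_abc', 'thought-signature-payload')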
diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index 815b1217f2..7c85e2d6cf 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -9,11 +9,11 @@ import type {
 import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
 import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
 import { loggerService } from '@logger'
+import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@main/apiServer/adapters'
 import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai'
 import anthropicService from '@main/services/AnthropicService'
 import copilotService from '@main/services/CopilotService'
 import { reduxService } from '@main/services/ReduxService'
-import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
 import { isGemini3ModelId } from '@shared/middleware'
 import {
   type AiSdkConfig,
@@ -33,12 +33,16 @@ import { net } from 'electron'
 import type { Response } from 'express'
 import * as z from 'zod'
 
-import { reasoningCache } from './cache'
+import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
 
 const logger = loggerService.withContext('UnifiedMessagesService')
 
 const MAGIC_STRING = 'skip_thought_signature_validator'
 
+function sanitizeJson(value: unknown): JSONValue {
+  return JSON.parse(JSON.stringify(value))
+}
+
 initializeSharedProviders({
   warn: (message) => logger.warn(message),
   error: (message, error) => logger.error(message, error)
@@ -303,13 +307,13 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
           const options: ProviderOptions = {}
 
           if (isGemini3ModelId(params.model)) {
-            if (reasoningCache.get(`google-${block.name}`)) {
+            if (googleReasoningCache.get(`google-${block.name}`)) {
               options.google = {
                 thoughtSignature: MAGIC_STRING
               }
-            } else if (reasoningCache.get('openrouter')) {
+            } else if (openRouterReasoningCache.get('openrouter')) {
               options.openrouter = {
-                reasoning_details: (reasoningCache.get('openrouter') as JSONValue[]) || []
+                reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
               }
             }
           }
@@ -345,10 +349,10 @@
       const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
       if (assistantContent.length > 0) {
         let providerOptions: ProviderOptions | undefined = undefined
-        if (reasoningCache.get('openrouter')) {
+        if (openRouterReasoningCache.get('openrouter')) {
           providerOptions = {
             openrouter: {
-              reasoning_details: (reasoningCache.get('openrouter') as JSONValue[]) || []
+              reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
             }
           }
         } else if (isGemini3ModelId(params.model)) {
@@ -510,8 +514,7 @@ async function executeStream(config: ExecuteStreamConfig): Promise {}),
-      reasoningCache
+      onEvent: onEvent || (() => {})
     })
 
     // Execute stream - pass model object instead of string
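The sanitizeJson helper added above leans on a JSON round-trip to turn whatever the cache returns into plain JSONValue data before it is attached to providerOptions. A standalone illustration of what that normalizes, and its one sharp edge (example values are made up):

const sanitizeJson = (value: unknown) => JSON.parse(JSON.stringify(value))

// undefined properties and functions are dropped, Dates collapse to ISO strings
sanitizeJson({ a: 1, b: undefined, at: new Date(0) })
// => { a: 1, at: '1970-01-01T00:00:00.000Z' }

// caveat: JSON.stringify throws on circular references, so cached values must stay tree-shaped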
diff --git a/src/main/services/CacheService.ts b/src/main/services/CacheService.ts
index d2984a9984..b9de349b7b 100644
--- a/src/main/services/CacheService.ts
+++ b/src/main/services/CacheService.ts
@@ -4,6 +4,26 @@ interface CacheItem<T> {
   duration: number
 }
 
+// Import the reasoning detail type from openrouter adapter
+type ReasoningDetailUnion = {
+  id?: string | null
+  format?: 'unknown' | 'openai-responses-v1' | 'xai-responses-v1' | 'anthropic-claude-v1' | 'google-gemini-v1' | null
+  index?: number
+  type: 'reasoning.summary' | 'reasoning.encrypted' | 'reasoning.text'
+  summary?: string
+  data?: string
+  text?: string | null
+  signature?: string | null
+}
+
+/**
+ * Interface for reasoning cache
+ */
+export interface IReasoningCache<T> {
+  set(key: string, value: T): void
+  get(key: string): T | undefined
+}
+
 export class CacheService {
   private static cache: Map<string, CacheItem<any>> = new Map()
 
@@ -72,3 +92,14 @@ export class CacheService {
     return true
   }
 }
+
+// Singleton cache instances using CacheService
+export const googleReasoningCache: IReasoningCache<string> = {
+  set: (key, value) => CacheService.set(`google-reasoning:${key}`, value, 30 * 60 * 1000),
+  get: (key) => CacheService.get(`google-reasoning:${key}`) || undefined
+}
+
+export const openRouterReasoningCache: IReasoningCache<ReasoningDetailUnion[]> = {
+  set: (key, value) => CacheService.set(`openrouter-reasoning:${key}`, value, 30 * 60 * 1000),
+  get: (key) => CacheService.get(`openrouter-reasoning:${key}`) || undefined
+}

From 874d69291fa7df53a6e473f5b2d716be8821613d Mon Sep 17 00:00:00 2001
From: suyao
Date: Mon, 1 Dec 2025 01:33:55 +0800
Subject: [PATCH 27/53] refactor: import

---
 src/main/services/CacheService.ts | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/src/main/services/CacheService.ts b/src/main/services/CacheService.ts
index b9de349b7b..84c6935d3d 100644
--- a/src/main/services/CacheService.ts
+++ b/src/main/services/CacheService.ts
@@ -1,21 +1,11 @@
+import type { ReasoningDetailUnion } from '@main/apiServer/adapters/openrouter'
+
 interface CacheItem<T> {
   data: T
   timestamp: number
   duration: number
 }
 
-// Import the reasoning detail type from openrouter adapter
-type ReasoningDetailUnion = {
-  id?: string | null
-  format?: 'unknown' | 'openai-responses-v1' | 'xai-responses-v1' | 'anthropic-claude-v1' | 'google-gemini-v1' | null
-  index?: number
-  type: 'reasoning.summary' | 'reasoning.encrypted' | 'reasoning.text'
-  summary?: string
-  data?: string
-  text?: string | null
-  signature?: string | null
-}
-
 /**
  * Interface for reasoning cache
 */
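With IReasoningCache generic over its value type, the two singletons document exactly what they store: thought signatures as strings, OpenRouter reasoning details as arrays. A usage sketch under the types introduced here (the key and payload are illustrative; at this point in the series the OpenRouter cache still uses a single shared key):

import { openRouterReasoningCache } from '@main/services/CacheService'

// stored under the namespaced CacheService key "openrouter-reasoning:openrouter"
openRouterReasoningCache.set('openrouter', [{ type: 'reasoning.text', text: 'step one' }])

// typed as ReasoningDetailUnion[] | undefined; expired entries surface as undefined
const details = openRouterReasoningCache.get('openrouter')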
From a23195296935dfcf1b28f46d3eee80512c184980 Mon Sep 17 00:00:00 2001
From: suyao
Date: Mon, 1 Dec 2025 01:59:41 +0800
Subject: [PATCH 28/53] feat: Add model exclusion logic for the Azure OpenAI
 provider and update the tool call model filter.

---
 src/renderer/src/config/models/tooluse.ts | 18 ++++++++++++++++++
 src/renderer/src/config/models/utils.ts   | 10 +++++++++-
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/src/renderer/src/config/models/tooluse.ts b/src/renderer/src/config/models/tooluse.ts
index 7f90df5f7b..db0a767346 100644
--- a/src/renderer/src/config/models/tooluse.ts
+++ b/src/renderer/src/config/models/tooluse.ts
@@ -1,6 +1,8 @@
+import { getProviderByModel } from '@renderer/services/AssistantService'
 import type { Model } from '@renderer/types'
 import { isSystemProviderId } from '@renderer/types'
 import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
+import { isAzureOpenAIProvider } from '@shared/provider'
 
 import { isEmbeddingModel, isRerankModel } from './embedding'
 import { isDeepSeekHybridInferenceModel } from './reasoning'
@@ -52,6 +54,13 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
   'i'
 )
 
+const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [
+  '(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+',
+  'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?',
+  'DeepSeek-(?:R1|V3)',
+  'Codestral-2501'
+]
+
 export function isFunctionCallingModel(model?: Model): boolean {
   if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
     return false
   }
@@ -67,6 +76,15 @@ export function isFunctionCallingModel(model?: Model): boolean {
     return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
   }
 
+  const provider = getProviderByModel(model)
+
+  if (isAzureOpenAIProvider(provider)) {
+    const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i')
+    if (azureExcludedRegex.test(modelId)) {
+      return false
+    }
+  }
+
   if (['deepseek', 'anthropic', 'kimi', 'moonshot'].includes(model.provider)) {
     return true
   }
diff --git a/src/renderer/src/config/models/utils.ts b/src/renderer/src/config/models/utils.ts
index bd45ed224f..129dc4abfd 100644
--- a/src/renderer/src/config/models/utils.ts
+++ b/src/renderer/src/config/models/utils.ts
@@ -13,6 +13,7 @@ import {
   isOpenAIReasoningModel
 } from './openai'
 import { isQwenMTModel } from './qwen'
+import { isFunctionCallingModel } from './tooluse'
 import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
 export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
 export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
@@ -181,8 +182,15 @@ export const isGeminiModel = (model: Model) => {
 // zhipu vision reasoning models mark their reasoning output with this pair of special tokens
 export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
 
+// TODO: support tool calling in prompt mode
 export const agentModelFilter = (model: Model): boolean => {
-  return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model) && !isGenerateImageModel(model)
+  return (
+    !isEmbeddingModel(model) &&
+    !isRerankModel(model) &&
+    !isTextToImageModel(model) &&
+    !isGenerateImageModel(model) &&
+    isFunctionCallingModel(model)
+  )
 }
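The Azure exclusion list above is compiled into one word-bounded, case-insensitive alternation at check time, so each entry only has to describe a model family. A quick sanity sketch of how the alternation behaves (model ids are examples):

const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [
  '(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+',
  'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?',
  'DeepSeek-(?:R1|V3)',
  'Codestral-2501'
]
const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i')

azureExcludedRegex.test('meta-llama-3.1-70b-instruct') // true: function calling disabled
azureExcludedRegex.test('deepseek-v3')                 // true
azureExcludedRegex.test('gpt-4o')                      // false: unaffected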
From fb9a8e7ce2ce86d10c8a95edfe7eecc9eb33b0ad Mon Sep 17 00:00:00 2001
From: suyao
Date: Mon, 1 Dec 2025 02:45:27 +0800
Subject: [PATCH 29/53] fix: params map

---
 .../apiServer/services/unified-messages.ts | 68 ++++++++++++++++++-
 1 file changed, 67 insertions(+), 1 deletion(-)

diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index 7c85e2d6cf..63bd461f5b 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -1,3 +1,6 @@
+import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
+import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
+import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
 import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
 import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
 import type {
@@ -20,6 +23,9 @@ import {
   type AiSdkConfigContext,
   formatProviderApiHost,
   initializeSharedProviders,
+  isAnthropicProvider,
+  isGeminiProvider,
+  isOpenAIProvider,
   type ProviderFormatContext,
   providerToAiSdkConfig as sharedProviderToAiSdkConfig,
   resolveActualProvider
@@ -482,6 +488,63 @@ async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkCon
   return config
 }
 
+function mapAnthropicThinkToAISdkProviderOptions(
+  provider: Provider,
+  config: MessageCreateParams['thinking']
+): ProviderOptions | undefined {
+  if (!config) return undefined
+  if (isAnthropicProvider(provider)) {
+    return {
+      anthropic: {
+        ...mapToAnthropicProviderOptions(config)
+      }
+    }
+  }
+  if (isGeminiProvider(provider)) {
+    return {
+      google: {
+        ...mapToGeminiProviderOptions(config)
+      }
+    }
+  }
+  if (isOpenAIProvider(provider)) {
+    return {
+      openai: {
+        ...mapToOpenAIProviderOptions(config)
+      }
+    }
+  }
+  return undefined
+}
+
+function mapToAnthropicProviderOptions(config: NonNullable<MessageCreateParams['thinking']>): AnthropicProviderOptions {
+  return {
+    thinking: {
+      type: config.type,
+      budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
+    }
+  }
+}
+
+function mapToGeminiProviderOptions(
+  config: NonNullable<MessageCreateParams['thinking']>
+): GoogleGenerativeAIProviderOptions {
+  return {
+    thinkingConfig: {
+      thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
+      includeThoughts: config.type === 'enabled'
+    }
+  }
+}
+
+function mapToOpenAIProviderOptions(
+  config: NonNullable<MessageCreateParams['thinking']>
+): OpenAIResponsesProviderOptions {
+  return {
+    reasoningEffort: config.type === 'enabled' ? 'high' : 'none'
+  }
+}
+
 /**
  * Core stream execution function - single source of truth for AI SDK calls
  */
@@ -521,14 +584,17 @@ async function executeStream(config: ExecuteStreamConfig): Promise
Date: Mon, 1 Dec 2025 13:45:52 +0800
Subject: [PATCH 30/53] gitignore

---
 .gitignore | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.gitignore b/.gitignore
index a8107fa93e..9322c8717e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -73,3 +73,5 @@ test-results
 YOUR_MEMORY_FILE_PATH
 
 .sessions/
+.next/
+*.tsbuildinfo

From 4d77202afdf017f02650aba3c30543e986e68e66 Mon Sep 17 00:00:00 2001
From: suyao
Date: Mon, 1 Dec 2025 13:50:01 +0800
Subject: [PATCH 31/53] filter: copilot

---
 src/renderer/src/config/models/__tests__/utils.test.ts | 8 ++++++++
 src/renderer/src/config/models/utils.ts                | 7 +++++++
 2 files changed, 15 insertions(+)

diff --git a/src/renderer/src/config/models/__tests__/utils.test.ts b/src/renderer/src/config/models/__tests__/utils.test.ts
index b27ed930cd..618a9e9dfe 100644
--- a/src/renderer/src/config/models/__tests__/utils.test.ts
+++ b/src/renderer/src/config/models/__tests__/utils.test.ts
@@ -15,6 +15,7 @@ import {
   isSupportVerbosityModel
 } from '../openai'
 import { isQwenMTModel } from '../qwen'
+import { isFunctionCallingModel } from '../tooluse'
 import {
   agentModelFilter,
   getModelSupportedVerbosity,
@@ -112,6 +113,7 @@ const textToImageMock = vi.mocked(isTextToImageModel)
 const generateImageMock = vi.mocked(isGenerateImageModel)
 const reasoningMock = vi.mocked(isOpenAIReasoningModel)
 const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
+const isFunctionCallingModelMock = vi.mocked(isFunctionCallingModel)
 
 describe('model utils', () => {
   beforeEach(() => {
@@ -457,6 +459,12 @@ describe('model utils', () => {
     expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
   })
 
+  it('filters out non-function-call models', () => {
+    rerankMock.mockReturnValue(false)
+    isFunctionCallingModelMock.mockReturnValueOnce(false)
+    expect(agentModelFilter(createModel({ id: 'DeepSeek R1' }))).toBe(false)
+  })
+
   it('filters out text-to-image models', () => {
     rerankMock.mockReturnValue(false)
     textToImageMock.mockReturnValueOnce(true)
diff --git a/src/renderer/src/config/models/utils.ts b/src/renderer/src/config/models/utils.ts
index 129dc4abfd..9a7d5fedfb 100644
--- a/src/renderer/src/config/models/utils.ts
+++ b/src/renderer/src/config/models/utils.ts
@@ -1,5 +1,6 @@
 import type OpenAI from '@cherrystudio/openai'
 import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
+import { getProviderByModel } from '@renderer/services/AssistantService'
 import { type Model, SystemProviderIds } from '@renderer/types'
 import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
 import { getLowerBaseModelName } from '@renderer/utils'
@@ -184,6 +185,12 @@ export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as con
 
 // TODO: support tool calling in prompt mode
 export const agentModelFilter = (model: Model): boolean => {
+  const provider = getProviderByModel(model)
+
+  // Needs adaptation, and it is easy to exceed the quota
+  if (provider.id === SystemProviderIds.copilot) {
+    return false
+  }
   return (
     !isEmbeddingModel(model) &&
     !isRerankModel(model) &&
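agentModelFilter now needs the owning provider because the new exclusion is provider-level rather than model-level: Copilot still needs adaptation and exceeds its quota easily under agent workloads. A condensed restatement of the predicate assembled by the two commits above (a sketch, not a literal excerpt; import paths follow utils.ts):

import { getProviderByModel } from '@renderer/services/AssistantService'
import { type Model, SystemProviderIds } from '@renderer/types'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import { isFunctionCallingModel } from './tooluse'
import { isGenerateImageModel, isTextToImageModel } from './vision'

export const agentModelFilterSketch = (model: Model): boolean => {
  // provider-level exclusion runs first, before any capability checks
  if (getProviderByModel(model).id === SystemProviderIds.copilot) return false
  return (
    !isEmbeddingModel(model) &&
    !isRerankModel(model) &&
    !isTextToImageModel(model) &&
    !isGenerateImageModel(model) &&
    isFunctionCallingModel(model) // prompt-mode tool calling is still a TODO
  )
}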
From fd0be32ab44e4c170be09cde65104b0c2573a6fa Mon Sep 17 00:00:00 2001
From: suyao
Date: Wed, 3 Dec 2025 11:58:46 +0800
Subject: [PATCH 32/53] fix: openrouter

---
 .../apiServer/adapters/AiSdkToAnthropicSSE.ts |  3 +-
 .../apiServer/services/unified-messages.ts    | 36 ++++++++++++++-----
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts b/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts
index b5b52c4e03..f24d8304a7 100644
--- a/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts
+++ b/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts
@@ -181,14 +181,13 @@ export class AiSdkToAnthropicSSE {
             chunk.providerMetadata?.google?.thoughtSignature as string
           )
         }
-        // FIXME: bind by tool call id
         if (
           openRouterReasoningCache &&
           chunk.providerMetadata?.openrouter?.reasoning_details &&
           Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
         ) {
           openRouterReasoningCache.set(
-            'openrouter',
+            `openrouter-${chunk.toolCallId}`,
             JSON.parse(JSON.stringify(chunk.providerMetadata.openrouter.reasoning_details))
           )
         }
diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index 63bd461f5b..c418f8c434 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -17,6 +17,7 @@ import { generateSignature as cherryaiGenerateSignature } from '@main/integratio
 import anthropicService from '@main/services/AnthropicService'
 import copilotService from '@main/services/CopilotService'
 import { reduxService } from '@main/services/ReduxService'
+import type { OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider'
 import { isGemini3ModelId } from '@shared/middleware'
 import {
   type AiSdkConfig,
@@ -28,7 +29,8 @@ import {
   isOpenAIProvider,
   type ProviderFormatContext,
   providerToAiSdkConfig as sharedProviderToAiSdkConfig,
-  resolveActualProvider
+  resolveActualProvider,
+  SystemProviderIds
 } from '@shared/provider'
 import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
 import { defaultAppHeaders } from '@shared/utils'
@@ -311,18 +313,19 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
             }
           }
         } else if (block.type === 'tool_use') {
           const options: ProviderOptions = {}
-
+          logger.debug('Processing tool call block', { block, msgRole: msg.role, model: params.model })
           if (isGemini3ModelId(params.model)) {
             if (googleReasoningCache.get(`google-${block.name}`)) {
               options.google = {
                 thoughtSignature: MAGIC_STRING
               }
-            } else if (openRouterReasoningCache.get('openrouter')) {
-              options.openrouter = {
-                reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
-              }
             }
           }
+          if (openRouterReasoningCache.get(`openrouter-${block.id}`)) {
+            options.openrouter = {
+              reasoning_details: (sanitizeJson(openRouterReasoningCache.get(`openrouter-${block.id}`)) as JSONValue[]) || []
+            }
+          }
           toolCallParts.push({
             type: 'tool-call',
             toolName: block.name,
@@ -514,6 +517,13 @@ function mapAnthropicThinkToAISdkProviderOptions(
       }
     }
   }
+  if (provider.id === SystemProviderIds.openrouter) {
+    return {
+      openrouter: {
+        ...mapToOpenRouterProviderOptions(config)
+      }
+    }
+  }
   return undefined
 }
 
@@ -545,6 +555,17 @@ function mapToOpenAIProviderOptions(
   }
 }
 
+function mapToOpenRouterProviderOptions(
+  config: NonNullable<MessageCreateParams['thinking']>
+): OpenRouterProviderOptions {
+  return {
+    reasoning: {
+      enabled: config.type === 'enabled',
+      effort: 'high'
+    }
+  }
+}
+
 /**
  * Core stream execution function - single source of truth for AI SDK calls
 */
@@ -580,9 +601,8 @@ async function executeStream(config: ExecuteStreamConfig): Promise {})
   })
 
-  // Execute stream - pass model object instead of string
   const result = await executor.streamText({
-    model, // Now passing LanguageModel object, not string
+    model,
     messages: coreMessages,
     // FIXME: The maxToken passed in by Claude Code can exceed some models' limits and needs special handling; this may be easier to fix in v2, and the maintenance cost is high for now
    // Known: Doubao
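Keying the cached reasoning_details by tool call id, instead of one shared 'openrouter' slot, lets interleaved tool calls round-trip their own reasoning. A sketch of the write/read pair this commit establishes (the id and payload are made up):

import { openRouterReasoningCache } from '@main/services/CacheService'

const toolCallId = 'call_abc123'

// write side (AiSdkToAnthropicSSE): one cache slot per streamed tool call
openRouterReasoningCache.set(`openrouter-${toolCallId}`, [{ type: 'reasoning.text', text: 'why this call' }])

// read side (convertAnthropicToAiMessages): rehydrated when the matching
// tool_use block is replayed in the next Anthropic-format request
const cached = openRouterReasoningCache.get(`openrouter-${toolCallId}`)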
From 08a537bfe4bfafcf1b7060557698f6c8001d191b Mon Sep 17 00:00:00 2001
From: suyao
Date: Wed, 3 Dec 2025 12:15:53 +0800
Subject: [PATCH 33/53] fix: test

---
 .../apiServer/services/unified-messages.ts    |  7 +++--
 .../config/models/__tests__/tooluse.test.ts   | 23 +++++++++++++++
 .../src/config/models/__tests__/utils.test.ts | 28 +++++++++++++++++++
 src/renderer/src/utils/api.ts                 |  1 +
 4 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index c418f8c434..cfcdb48393 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -322,9 +322,10 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
           }
         }
         if (openRouterReasoningCache.get(`openrouter-${block.id}`)) {
-            options.openrouter = {
-              reasoning_details: (sanitizeJson(openRouterReasoningCache.get(`openrouter-${block.id}`)) as JSONValue[]) || []
-            }
+          options.openrouter = {
+            reasoning_details:
+              (sanitizeJson(openRouterReasoningCache.get(`openrouter-${block.id}`)) as JSONValue[]) || []
+          }
         }
         toolCallParts.push({
           type: 'tool-call',
diff --git a/src/renderer/src/config/models/__tests__/tooluse.test.ts b/src/renderer/src/config/models/__tests__/tooluse.test.ts
index e147e87f2f..24653f9a2c 100644
--- a/src/renderer/src/config/models/__tests__/tooluse.test.ts
+++ b/src/renderer/src/config/models/__tests__/tooluse.test.ts
@@ -6,6 +6,29 @@ import { isDeepSeekHybridInferenceModel } from '../reasoning'
 import { isFunctionCallingModel } from '../tooluse'
 import { isPureGenerateImageModel, isTextToImageModel } from '../vision'
 
+vi.mock('@renderer/i18n', () => ({
+  __esModule: true,
+  default: {
+    t: vi.fn((key: string) => key)
+  }
+}))
+
+vi.mock('@renderer/services/AssistantService', () => ({
+  getProviderByModel: vi.fn().mockReturnValue({
+    id: 'openai',
+    type: 'openai',
+    name: 'OpenAI',
+    models: []
+  }),
+  getAssistantSettings: vi.fn(),
+  getDefaultAssistant: vi.fn().mockReturnValue({
+    id: 'default',
+    name: 'Default Assistant',
+    prompt: '',
+    settings: {}
+  })
+}))
+
 vi.mock('@renderer/hooks/useStore', () => ({
   getStoreProviders: vi.fn(() => [])
 }))
diff --git a/src/renderer/src/config/models/__tests__/utils.test.ts b/src/renderer/src/config/models/__tests__/utils.test.ts
index 618a9e9dfe..a9387cc3f7 100644
--- a/src/renderer/src/config/models/__tests__/utils.test.ts
+++ b/src/renderer/src/config/models/__tests__/utils.test.ts
@@ -68,6 +68,29 @@ vi.mock('@renderer/store/settings', () => {
   )
 })
 
+vi.mock('@renderer/i18n', () => ({
+  __esModule: true,
+  default: {
+    t: vi.fn((key: string) => key)
+  }
+}))
+
+vi.mock('@renderer/services/AssistantService', () => ({
+  getProviderByModel: vi.fn().mockReturnValue({
+    id: 'openai',
+    type: 'openai',
+    name: 'OpenAI',
+    models: []
+  }),
+  getAssistantSettings: vi.fn(),
+  getDefaultAssistant: vi.fn().mockReturnValue({
+    id: 'default',
+    name: 'Default Assistant',
+    prompt: '',
+    settings: {}
+  })
+}))
+
 vi.mock('@renderer/hooks/useSettings', () => ({
   useSettings: vi.fn(() => ({})),
   useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
@@ -98,6 +121,10 @@ vi.mock('../websearch', () => ({
   isOpenAIWebSearchChatCompletionOnlyModel: vi.fn()
 }))
 
+vi.mock('../tooluse', () => ({
+  isFunctionCallingModel: vi.fn()
+}))
+
 const createModel = (overrides: Partial<Model> = {}): Model => ({
   id: 'gpt-4o',
   name: 'gpt-4o',
   provider: 'openai',
   group: 'OpenAI'
 })
@@ -125,6 +152,7 @@ describe('model utils', () => {
     generateImageMock.mockReturnValue(false)
     reasoningMock.mockReturnValue(false)
     openAIWebSearchOnlyMock.mockReturnValue(false)
+    isFunctionCallingModelMock.mockReturnValue(true)
   })
 
   describe('OpenAI model detection', () => {
diff --git a/src/renderer/src/utils/api.ts b/src/renderer/src/utils/api.ts
index 0bdb4cc999..3c9b8b0465 100644
--- a/src/renderer/src/utils/api.ts
+++ b/src/renderer/src/utils/api.ts
@@ -10,6 +10,7 @@ export {
   SUPPORTED_IMAGE_ENDPOINT_LIST,
   validateApiHost,
   withoutTrailingApiVersion,
+  withoutTrailingSharp,
   withoutTrailingSlash
 } from '@shared/api'

From 0fc901108e2e86d10c8a95edfe7eecc9eb33b0ad Mon Sep 17 00:00:00 2001
From: suyao
Date: Thu, 4 Dec 2025 22:54:50 +0800
Subject: [PATCH 34/53] fix: type check

---
 src/renderer/src/utils/api.ts | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/renderer/src/utils/api.ts b/src/renderer/src/utils/api.ts
index 3c9b8b0465..665b49c7c4 100644
--- a/src/renderer/src/utils/api.ts
+++ b/src/renderer/src/utils/api.ts
@@ -1,6 +1,7 @@
 export {
   formatApiHost,
   formatAzureOpenAIApiHost,
+  formatOllamaApiHost,
   formatVertexApiHost,
   getAiSdkBaseUrl,
   getTrailingApiVersion,

From 39d1c71819aa124ca2dda201a37d16028a240aac Mon Sep 17 00:00:00 2001
From: suyao
Date: Fri, 5 Dec 2025 13:59:00 +0800
Subject: [PATCH 35/53] fix: type check

---
 src/main/apiServer/services/unified-messages.ts | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index cfcdb48393..27037f1a84 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -98,10 +98,6 @@ function getMainProcessFormatContext(): ProviderFormatContext {
 }
 
 const mainProcessSdkContext: AiSdkConfigContext = {
-  getRotatedApiKey: (provider) => {
-    const keys = provider.apiKey.split(',').map((k) => k.trim())
-    return keys[0] || provider.apiKey
-  },
   fetch: net.fetch as typeof globalThis.fetch
 }

From c0c7d2b0dff72c0772733e6cd58e54d782e0c321 Mon Sep 17 00:00:00 2001
From: suyao
Date: Thu, 11 Dec 2025 12:43:30 +0800
Subject: [PATCH 36/53] fix: test

---
 src/renderer/src/utils/api.ts | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/renderer/src/utils/api.ts b/src/renderer/src/utils/api.ts
index 665b49c7c4..f6e0c63c58 100644
--- a/src/renderer/src/utils/api.ts
+++ b/src/renderer/src/utils/api.ts
@@ -6,6 +6,7 @@ export {
   getAiSdkBaseUrl,
   getTrailingApiVersion,
   hasAPIVersion,
+  isWithTrailingSharp,
   routeToEndpoint,
   SUPPORTED_ENDPOINT_LIST,
   SUPPORTED_IMAGE_ENDPOINT_LIST,

From 03dbc52477e65acd98e992ca0c1202bde4600c91 Mon Sep 17 00:00:00 2001
From: suyao
Date: Thu, 11 Dec 2025 12:54:20 +0800
Subject: [PATCH 37/53] fix: test

---
 .../provider/__tests__/providerConfig.test.ts | 136 +++++-------------
 1 file changed, 36 insertions(+), 100 deletions(-)

diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts
index 5d508f95fa..bbeedd69c9 100644
--- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts
+++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts
@@ -110,6 +110,31 @@ const createWindowKeyv = () => {
   }
 }
 
+/**
+ * Create the default mock state with all required fields
+ */
+const createDefaultMockState = (overrides?: {
+  includeUsage?: boolean | undefined
+  copilotHeaders?: Record<string, string>
+}) => ({
+  copilot: { defaultHeaders: overrides?.copilotHeaders ??
{} }, + settings: { + openAI: { + streamOptions: { + includeUsage: overrides?.includeUsage + } + } + }, + llm: { + settings: { + vertexai: { + projectId: '', + location: '' + } + } + } +}) + const createCopilotProvider = (): Provider => ({ id: 'copilot', type: 'openai', @@ -153,16 +178,7 @@ describe('Copilot responses routing', () => { ...(globalThis as any).window, keyv: createWindowKeyv() } - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState()) }) it('detects official GPT-5 Codex identifiers case-insensitively', () => { @@ -198,16 +214,7 @@ describe('CherryAI provider configuration', () => { ...(globalThis as any).window, keyv: createWindowKeyv() } - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState()) vi.clearAllMocks() }) @@ -279,16 +286,7 @@ describe('Perplexity provider configuration', () => { ...(globalThis as any).window, keyv: createWindowKeyv() } - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState()) vi.clearAllMocks() }) @@ -363,6 +361,7 @@ describe('Stream options includeUsage configuration', () => { ...(globalThis as any).window, keyv: createWindowKeyv() } + mockGetState.mockReturnValue(createDefaultMockState()) vi.clearAllMocks() }) @@ -377,16 +376,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings when undefined', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: undefined })) const provider = createOpenAIProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) @@ -395,16 +385,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings when set to true', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: true - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: true })) const provider = createOpenAIProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) @@ -413,16 +394,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings when set to false', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: false - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: false })) const provider = createOpenAIProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) @@ -431,16 +403,7 @@ describe('Stream options includeUsage configuration', () => { }) it('respects includeUsage setting for non-supporting providers', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: true - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: 
true })) const testProvider: Provider = { id: 'test', @@ -462,16 +425,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings for Copilot provider when set to false', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: false - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: false })) const provider = createCopilotProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) @@ -481,16 +435,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings for Copilot provider when set to true', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: true - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: true })) const provider = createCopilotProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) @@ -500,16 +445,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings for Copilot provider when undefined', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: undefined })) const provider = createCopilotProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) From dd2faa2b6a29b6c6314478745b4e00fc05608b77 Mon Sep 17 00:00:00 2001 From: suyao Date: Wed, 17 Dec 2025 18:17:26 +0800 Subject: [PATCH 38/53] chore: format --- src/main/services/agents/services/claudecode/index.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index 0c36a6f61e..689c177ff5 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -107,7 +107,6 @@ class ClaudeCodeService implements AgentServiceInterface { const customGitBashPath = validateGitBashPath(configManager.get(ConfigKeys.GitBashPath) as string | undefined) - // Route through local API Server which handles format conversion via unified adapter // This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.) 
// The API Server converts AI SDK responses to Anthropic SSE format transparently From 905e29007118e61139fae3e8b80aa73e4ad6cec9 Mon Sep 17 00:00:00 2001 From: suyao Date: Wed, 17 Dec 2025 18:21:42 +0800 Subject: [PATCH 39/53] chore: lint --- packages/shared/provider/sdk-config.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/shared/provider/sdk-config.ts b/packages/shared/provider/sdk-config.ts index 006de36ba4..b87a31ed25 100644 --- a/packages/shared/provider/sdk-config.ts +++ b/packages/shared/provider/sdk-config.ts @@ -6,6 +6,7 @@ */ import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider' +import { defaultAppHeaders } from '@shared/utils' import { isEmpty } from 'lodash' import { routeToEndpoint } from '../api' @@ -13,7 +14,6 @@ import { isOllamaProvider } from './detection' import { getAiSdkProviderId } from './mapping' import type { MinimalProvider } from './types' import { SystemProviderIds } from './types' -import { defaultAppHeaders } from '@shared/utils' /** * AI SDK configuration result From b33e5959556ff634fe64f3b99c97c4d8dd6cc8de Mon Sep 17 00:00:00 2001 From: suyao Date: Thu, 18 Dec 2025 14:10:42 +0800 Subject: [PATCH 40/53] Merge remote-tracking branch 'origin/main' into feat/proxy-api-server --- build/nsis-installer.nsh | 65 +- electron-builder.yml | 60 +- package.json | 2 +- packages/shared/IpcChannel.ts | 1 + packages/shared/config/constant.ts | 8 + packages/shared/provider/types.ts | 6 +- src/main/ipc.ts | 25 +- src/main/mcpServers/factory.ts | 2 +- src/main/mcpServers/filesystem.ts | 652 ------------------ src/main/mcpServers/filesystem/index.ts | 2 + src/main/mcpServers/filesystem/server.ts | 118 ++++ .../mcpServers/filesystem/tools/delete.ts | 93 +++ src/main/mcpServers/filesystem/tools/edit.ts | 130 ++++ src/main/mcpServers/filesystem/tools/glob.ts | 149 ++++ src/main/mcpServers/filesystem/tools/grep.ts | 266 +++++++ src/main/mcpServers/filesystem/tools/index.ts | 8 + src/main/mcpServers/filesystem/tools/ls.ts | 150 ++++ src/main/mcpServers/filesystem/tools/read.ts | 101 +++ src/main/mcpServers/filesystem/tools/write.ts | 83 +++ src/main/mcpServers/filesystem/types.ts | 627 +++++++++++++++++ src/main/services/ConfigManager.ts | 3 +- src/main/services/MCPService.ts | 20 + .../agents/services/claudecode/index.ts | 7 +- src/main/utils/__tests__/process.test.ts | 294 +++++++- src/main/utils/process.ts | 78 ++- src/preload/index.ts | 3 +- .../legacy/clients/openai/OpenAIApiClient.ts | 6 +- .../middleware/AiSdkMiddlewareBuilder.ts | 11 - .../plugins/searchOrchestrationPlugin.ts | 46 +- .../__tests__/model-parameters.test.ts | 2 +- .../aiCore/utils/__tests__/reasoning.test.ts | 107 ++- src/renderer/src/aiCore/utils/reasoning.ts | 54 +- .../src/assets/images/models/mimo.svg | 17 + .../src/assets/images/providers/mimo.svg | 17 + src/renderer/src/components/Icons/SVGIcon.tsx | 12 + .../components/Popups/agent/AgentModal.tsx | 177 +++-- .../config/models/__tests__/reasoning.test.ts | 122 +++- src/renderer/src/config/models/default.ts | 16 +- src/renderer/src/config/models/logo.ts | 4 +- src/renderer/src/config/models/reasoning.ts | 95 ++- src/renderer/src/config/models/tooluse.ts | 5 +- src/renderer/src/config/models/vision.ts | 2 +- src/renderer/src/config/providers.ts | 26 +- src/renderer/src/i18n/label.ts | 22 +- src/renderer/src/i18n/locales/en-us.json | 19 +- src/renderer/src/i18n/locales/zh-cn.json | 19 +- src/renderer/src/i18n/locales/zh-tw.json | 19 +- 
src/renderer/src/i18n/translate/de-de.json | 19 +- src/renderer/src/i18n/translate/el-gr.json | 19 +- src/renderer/src/i18n/translate/es-es.json | 19 +- src/renderer/src/i18n/translate/fr-fr.json | 19 +- src/renderer/src/i18n/translate/ja-jp.json | 19 +- src/renderer/src/i18n/translate/pt-pt.json | 19 +- src/renderer/src/i18n/translate/ru-ru.json | 19 +- .../tools/components/ThinkingButton.tsx | 44 +- .../MCPSettings/BuiltinMCPServerList.tsx | 11 +- .../ProviderSettings/ProviderSetting.tsx | 1 + src/renderer/src/services/ApiService.ts | 57 ++ src/renderer/src/services/AssistantService.ts | 5 +- src/renderer/src/services/KnowledgeService.ts | 132 ++++ .../src/services/OrchestrateService.ts | 91 --- .../src/services/StreamProcessingService.ts | 4 + .../callbacks/citationCallbacks.ts | 7 +- src/renderer/src/store/index.ts | 2 +- src/renderer/src/store/mcp.ts | 10 + src/renderer/src/store/migrate.ts | 15 + src/renderer/src/store/thunk/messageThunk.ts | 6 +- src/renderer/src/types/index.ts | 19 +- yarn.lock | 20 +- 69 files changed, 3189 insertions(+), 1119 deletions(-) delete mode 100644 src/main/mcpServers/filesystem.ts create mode 100644 src/main/mcpServers/filesystem/index.ts create mode 100644 src/main/mcpServers/filesystem/server.ts create mode 100644 src/main/mcpServers/filesystem/tools/delete.ts create mode 100644 src/main/mcpServers/filesystem/tools/edit.ts create mode 100644 src/main/mcpServers/filesystem/tools/glob.ts create mode 100644 src/main/mcpServers/filesystem/tools/grep.ts create mode 100644 src/main/mcpServers/filesystem/tools/index.ts create mode 100644 src/main/mcpServers/filesystem/tools/ls.ts create mode 100644 src/main/mcpServers/filesystem/tools/read.ts create mode 100644 src/main/mcpServers/filesystem/tools/write.ts create mode 100644 src/main/mcpServers/filesystem/types.ts create mode 100644 src/renderer/src/assets/images/models/mimo.svg create mode 100644 src/renderer/src/assets/images/providers/mimo.svg delete mode 100644 src/renderer/src/services/OrchestrateService.ts diff --git a/build/nsis-installer.nsh b/build/nsis-installer.nsh index 769ccaaa19..e644e18f3d 100644 --- a/build/nsis-installer.nsh +++ b/build/nsis-installer.nsh @@ -12,8 +12,13 @@ ; https://github.com/electron-userland/electron-builder/issues/1122 !ifndef BUILD_UNINSTALLER + ; Check VC++ Redistributable based on architecture stored in $1 Function checkVCRedist - ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\x64" "Installed" + ${If} $1 == "arm64" + ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\ARM64" "Installed" + ${Else} + ReadRegDWORD $0 HKLM "SOFTWARE\Microsoft\VisualStudio\14.0\VC\Runtimes\x64" "Installed" + ${EndIf} FunctionEnd Function checkArchitectureCompatibility @@ -97,29 +102,47 @@ Call checkVCRedist ${If} $0 != "1" - MessageBox MB_YESNO "\ - NOTE: ${PRODUCT_NAME} requires $\r$\n\ - 'Microsoft Visual C++ Redistributable'$\r$\n\ - to function properly.$\r$\n$\r$\n\ - Download and install now?" /SD IDYES IDYES InstallVCRedist IDNO DontInstall - InstallVCRedist: - inetc::get /CAPTION " " /BANNER "Downloading Microsoft Visual C++ Redistributable..." 
"https://aka.ms/vs/17/release/vc_redist.x64.exe" "$TEMP\vc_redist.x64.exe" - ExecWait "$TEMP\vc_redist.x64.exe /install /norestart" - ;IfErrors InstallError ContinueInstall ; vc_redist exit code is unreliable :( - Call checkVCRedist - ${If} $0 == "1" - Goto ContinueInstall - ${EndIf} + ; VC++ is required - install automatically since declining would abort anyway + ; Select download URL based on system architecture (stored in $1) + ${If} $1 == "arm64" + StrCpy $2 "https://aka.ms/vs/17/release/vc_redist.arm64.exe" + StrCpy $3 "$TEMP\vc_redist.arm64.exe" + ${Else} + StrCpy $2 "https://aka.ms/vs/17/release/vc_redist.x64.exe" + StrCpy $3 "$TEMP\vc_redist.x64.exe" + ${EndIf} - ;InstallError: - MessageBox MB_ICONSTOP "\ - There was an unexpected error installing$\r$\n\ - Microsoft Visual C++ Redistributable.$\r$\n\ - The installation of ${PRODUCT_NAME} cannot continue." - DontInstall: + inetc::get /CAPTION " " /BANNER "Downloading Microsoft Visual C++ Redistributable..." \ + $2 $3 /END + Pop $0 ; Get download status from inetc::get + ${If} $0 != "OK" + MessageBox MB_ICONSTOP|MB_YESNO "\ + Failed to download Microsoft Visual C++ Redistributable.$\r$\n$\r$\n\ + Error: $0$\r$\n$\r$\n\ + Would you like to open the download page in your browser?$\r$\n\ + $2" IDYES openDownloadUrl IDNO skipDownloadUrl + openDownloadUrl: + ExecShell "open" $2 + skipDownloadUrl: Abort + ${EndIf} + + ExecWait "$3 /install /quiet /norestart" + ; Note: vc_redist exit code is unreliable, verify via registry check instead + + Call checkVCRedist + ${If} $0 != "1" + MessageBox MB_ICONSTOP|MB_YESNO "\ + Microsoft Visual C++ Redistributable installation failed.$\r$\n$\r$\n\ + Would you like to open the download page in your browser?$\r$\n\ + $2$\r$\n$\r$\n\ + The installation of ${PRODUCT_NAME} cannot continue." IDYES openInstallUrl IDNO skipInstallUrl + openInstallUrl: + ExecShell "open" $2 + skipInstallUrl: + Abort + ${EndIf} ${EndIf} - ContinueInstall: Pop $4 Pop $3 Pop $2 diff --git a/electron-builder.yml b/electron-builder.yml index db1184be87..e3ab493666 100644 --- a/electron-builder.yml +++ b/electron-builder.yml @@ -134,54 +134,38 @@ artifactBuildCompleted: scripts/artifact-build-completed.js releaseInfo: releaseNotes: | - Cherry Studio 1.7.4 - New Browser MCP & Model Updates + Cherry Studio 1.7.5 - Filesystem MCP Overhaul & Topic Management - This release adds a powerful browser automation MCP server, new web search provider, and model support updates. + This release features a completely rewritten filesystem MCP server, new batch topic management, and improved assistant management. 
✨ New Features - - [MCP] Add @cherry/browser CDP MCP server with session management for browser automation - - [Web Search] Add ExaMCP free web search provider (no API key required) - - [Model] Support GPT 5.2 series models - - [Model] Add capabilities support for Doubao Seed Code models (tool calling, reasoning, vision) - - 🔧 Improvements - - [Translate] Add reasoning effort option to translate service - - [i18n] Improve zh-TW Traditional Chinese locale - - [Settings] Update MCP Settings layout and styling + - [MCP] Rewrite filesystem MCP server with improved tool set (glob, ls, grep, read, write, edit, delete) + - [Topics] Add topic manage mode for batch delete and move operations with search functionality + - [Assistants] Merge import/subscribe popups and add export to assistant management + - [Knowledge] Use prompt injection for forced knowledge base search (faster response times) + - [Settings] Add tool use mode setting (prompt/function) to default assistant settings 🐛 Bug Fixes - - [Chat] Fix line numbers being wrongly copied from code blocks - - [Translate] Fix default to first supported reasoning effort when translating - - [Chat] Fix preserve thinking block in assistant messages - - [Web Search] Fix max search result limit - - [Embedding] Fix embedding dimensions retrieval for ModernAiProvider - - [Chat] Fix token calculation in prompt tool use plugin - - [Model] Fix Ollama provider options for Qwen model support - - [UI] Fix Chat component marginRight calculation for improved layout + - [Model] Correct typo in Gemini 3 Pro Image Preview model name + - [Installer] Auto-install VC++ Redistributable without user prompt + - [Notes] Fix notes directory validation and default path reset for cross-platform restore + - [OAuth] Bind OAuth callback server to localhost (127.0.0.1) for security - Cherry Studio 1.7.4 - 新增浏览器 MCP 与模型更新 + Cherry Studio 1.7.5 - 文件系统 MCP 重构与话题管理 - 本次更新新增强大的浏览器自动化 MCP 服务器、新的网页搜索提供商以及模型支持更新。 + 本次更新完全重写了文件系统 MCP 服务器,新增批量话题管理功能,并改进了助手管理。 ✨ 新功能 - - [MCP] 新增 @cherry/browser CDP MCP 服务器,支持会话管理的浏览器自动化 - - [网页搜索] 新增 ExaMCP 免费网页搜索提供商(无需 API 密钥) - - [模型] 支持 GPT 5.2 系列模型 - - [模型] 为豆包 Seed Code 模型添加能力支持(工具调用、推理、视觉) - - 🔧 功能改进 - - [翻译] 为翻译服务添加推理强度选项 - - [国际化] 改进繁体中文(zh-TW)本地化 - - [设置] 优化 MCP 设置布局和样式 + - [MCP] 重写文件系统 MCP 服务器,提供改进的工具集(glob、ls、grep、read、write、edit、delete) + - [话题] 新增话题管理模式,支持批量删除和移动操作,带搜索功能 + - [助手] 合并导入/订阅弹窗,并在助手管理中添加导出功能 + - [知识库] 使用提示词注入进行强制知识库搜索(响应更快) + - [设置] 在默认助手设置中添加工具使用模式设置(prompt/function) 🐛 问题修复 - - [聊天] 修复代码块中行号被错误复制的问题 - - [翻译] 修复翻译时默认使用第一个支持的推理强度 - - [聊天] 修复助手消息中思考块的保留问题 - - [网页搜索] 修复最大搜索结果数限制 - - [嵌入] 修复 ModernAiProvider 嵌入维度获取问题 - - [聊天] 修复提示词工具使用插件的 token 计算问题 - - [模型] 修复 Ollama 提供商对 Qwen 模型的支持选项 - - [界面] 修复聊天组件右边距计算以改善布局 + - [模型] 修正 Gemini 3 Pro Image Preview 模型名称的拼写错误 + - [安装程序] 自动安装 VC++ 运行库,无需用户确认 + - [笔记] 修复跨平台恢复场景下的笔记目录验证和默认路径重置逻辑 + - [OAuth] 将 OAuth 回调服务器绑定到 localhost (127.0.0.1) 以提高安全性 diff --git a/package.json b/package.json index 3fd6d1741f..b894030b58 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "CherryStudio", - "version": "1.7.4", + "version": "1.7.5", "private": true, "description": "A powerful AI assistant for producer.", "main": "./out/main/index.js", diff --git a/packages/shared/IpcChannel.ts b/packages/shared/IpcChannel.ts index 0ebe48266d..aec1d57b43 100644 --- a/packages/shared/IpcChannel.ts +++ b/packages/shared/IpcChannel.ts @@ -244,6 +244,7 @@ export enum IpcChannel { System_GetCpuName = 'system:getCpuName', System_CheckGitBash = 'system:checkGitBash', System_GetGitBashPath = 
'system:getGitBashPath',
+  System_GetGitBashPathInfo = 'system:getGitBashPathInfo',
   System_SetGitBashPath = 'system:setGitBashPath',
 
   // DevTools
diff --git a/packages/shared/config/constant.ts b/packages/shared/config/constant.ts
index 1e02ce7706..af0191f4fa 100644
--- a/packages/shared/config/constant.ts
+++ b/packages/shared/config/constant.ts
@@ -488,3 +488,11 @@ export const MACOS_TERMINALS_WITH_COMMANDS: TerminalConfigWithCommand[] = [
 // resources/scripts should be maintained manually
 
 export const HOME_CHERRY_DIR = '.cherrystudio'
+
+// Git Bash path configuration types
+export type GitBashPathSource = 'manual' | 'auto'
+
+export interface GitBashPathInfo {
+  path: string | null
+  source: GitBashPathSource | null
+}
diff --git a/packages/shared/provider/types.ts b/packages/shared/provider/types.ts
index 763ed210c4..3dd56376db 100644
--- a/packages/shared/provider/types.ts
+++ b/packages/shared/provider/types.ts
@@ -100,7 +100,8 @@ export const SystemProviderIdSchema = z.enum([
   'huggingface',
   'sophnet',
   'gateway',
-  'cerebras'
+  'cerebras',
+  'mimo'
 ])
 
 export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
@@ -169,7 +170,8 @@
   longcat: 'longcat',
   huggingface: 'huggingface',
   gateway: 'gateway',
-  cerebras: 'cerebras'
+  cerebras: 'cerebras',
+  mimo: 'mimo'
 } as const satisfies Record<SystemProviderId, SystemProviderId>
 
 export type SystemProviderIdTypeMap = typeof SystemProviderIds
diff --git a/src/main/ipc.ts b/src/main/ipc.ts
index d7e82ff875..4cb3402414 100644
--- a/src/main/ipc.ts
+++ b/src/main/ipc.ts
@@ -6,7 +6,14 @@ import { loggerService } from '@logger'
 import { isLinux, isMac, isPortable, isWin } from '@main/constant'
 import { generateSignature } from '@main/integration/cherryai'
 import anthropicService from '@main/services/AnthropicService'
-import { findGitBash, getBinaryPath, isBinaryExists, runInstallScript, validateGitBashPath } from '@main/utils/process'
+import {
+  autoDiscoverGitBash,
+  getBinaryPath,
+  getGitBashPathInfo,
+  isBinaryExists,
+  runInstallScript,
+  validateGitBashPath
+} from '@main/utils/process'
 import { handleZoomFactor } from '@main/utils/zoom'
 import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
 import type { UpgradeChannel } from '@shared/config/constant'
@@ -499,9 +506,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
     }
 
     try {
-      const customPath = configManager.get(ConfigKeys.GitBashPath) as string | undefined
-      const bashPath = findGitBash(customPath)
-
+      // Use autoDiscoverGitBash to handle auto-discovery and persistence
+      const bashPath = autoDiscoverGitBash()
       if (bashPath) {
         logger.info('Git Bash is available', { path: bashPath })
         return true
@@ -524,13 +530,22 @@
     return customPath ??
null }) + // Returns { path, source } where source is 'manual' | 'auto' | null + ipcMain.handle(IpcChannel.System_GetGitBashPathInfo, () => { + return getGitBashPathInfo() + }) + ipcMain.handle(IpcChannel.System_SetGitBashPath, (_, newPath: string | null) => { if (!isWin) { return false } if (!newPath) { + // Clear manual setting and re-run auto-discovery configManager.set(ConfigKeys.GitBashPath, null) + configManager.set(ConfigKeys.GitBashPathSource, null) + // Re-run auto-discovery to restore auto-discovered path if available + autoDiscoverGitBash() return true } @@ -539,7 +554,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { return false } + // Set path with 'manual' source configManager.set(ConfigKeys.GitBashPath, validated) + configManager.set(ConfigKeys.GitBashPathSource, 'manual') return true }) diff --git a/src/main/mcpServers/factory.ts b/src/main/mcpServers/factory.ts index ce736f6843..909901c1c8 100644 --- a/src/main/mcpServers/factory.ts +++ b/src/main/mcpServers/factory.ts @@ -36,7 +36,7 @@ export function createInMemoryMCPServer( return new FetchServer().server } case BuiltinMCPServerNames.filesystem: { - return new FileSystemServer(args).server + return new FileSystemServer(envs.WORKSPACE_ROOT).server } case BuiltinMCPServerNames.difyKnowledge: { const difyKey = envs.DIFY_KEY diff --git a/src/main/mcpServers/filesystem.ts b/src/main/mcpServers/filesystem.ts deleted file mode 100644 index ba10783881..0000000000 --- a/src/main/mcpServers/filesystem.ts +++ /dev/null @@ -1,652 +0,0 @@ -// port https://github.com/modelcontextprotocol/servers/blob/main/src/filesystem/index.ts - -import { loggerService } from '@logger' -import { Server } from '@modelcontextprotocol/sdk/server/index.js' -import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js' -import { createTwoFilesPatch } from 'diff' -import fs from 'fs/promises' -import { minimatch } from 'minimatch' -import os from 'os' -import path from 'path' -import * as z from 'zod' - -const logger = loggerService.withContext('MCP:FileSystemServer') - -// Normalize all paths consistently -function normalizePath(p: string): string { - return path.normalize(p) -} - -function expandHome(filepath: string): string { - if (filepath.startsWith('~/') || filepath === '~') { - return path.join(os.homedir(), filepath.slice(1)) - } - return filepath -} - -// Security utilities -async function validatePath(allowedDirectories: string[], requestedPath: string): Promise { - const expandedPath = expandHome(requestedPath) - const absolute = path.isAbsolute(expandedPath) - ? 
path.resolve(expandedPath) - : path.resolve(process.cwd(), expandedPath) - - const normalizedRequested = normalizePath(absolute) - - // Check if path is within allowed directories - const isAllowed = allowedDirectories.some((dir) => normalizedRequested.startsWith(dir)) - if (!isAllowed) { - throw new Error( - `Access denied - path outside allowed directories: ${absolute} not in ${allowedDirectories.join(', ')}` - ) - } - - // Handle symlinks by checking their real path - try { - const realPath = await fs.realpath(absolute) - const normalizedReal = normalizePath(realPath) - const isRealPathAllowed = allowedDirectories.some((dir) => normalizedReal.startsWith(dir)) - if (!isRealPathAllowed) { - throw new Error('Access denied - symlink target outside allowed directories') - } - return realPath - } catch (error) { - // For new files that don't exist yet, verify parent directory - const parentDir = path.dirname(absolute) - try { - const realParentPath = await fs.realpath(parentDir) - const normalizedParent = normalizePath(realParentPath) - const isParentAllowed = allowedDirectories.some((dir) => normalizedParent.startsWith(dir)) - if (!isParentAllowed) { - throw new Error('Access denied - parent directory outside allowed directories') - } - return absolute - } catch { - throw new Error(`Parent directory does not exist: ${parentDir}`) - } - } -} - -// Schema definitions -const ReadFileArgsSchema = z.object({ - path: z.string() -}) - -const ReadMultipleFilesArgsSchema = z.object({ - paths: z.array(z.string()) -}) - -const WriteFileArgsSchema = z.object({ - path: z.string(), - content: z.string() -}) - -const EditOperation = z.object({ - oldText: z.string().describe('Text to search for - must match exactly'), - newText: z.string().describe('Text to replace with') -}) - -const EditFileArgsSchema = z.object({ - path: z.string(), - edits: z.array(EditOperation), - dryRun: z.boolean().default(false).describe('Preview changes using git-style diff format') -}) - -const CreateDirectoryArgsSchema = z.object({ - path: z.string() -}) - -const ListDirectoryArgsSchema = z.object({ - path: z.string() -}) - -const DirectoryTreeArgsSchema = z.object({ - path: z.string() -}) - -const MoveFileArgsSchema = z.object({ - source: z.string(), - destination: z.string() -}) - -const SearchFilesArgsSchema = z.object({ - path: z.string(), - pattern: z.string(), - excludePatterns: z.array(z.string()).optional().default([]) -}) - -const GetFileInfoArgsSchema = z.object({ - path: z.string() -}) - -interface FileInfo { - size: number - created: Date - modified: Date - accessed: Date - isDirectory: boolean - isFile: boolean - permissions: string -} - -// Tool implementations -async function getFileStats(filePath: string): Promise { - const stats = await fs.stat(filePath) - return { - size: stats.size, - created: stats.birthtime, - modified: stats.mtime, - accessed: stats.atime, - isDirectory: stats.isDirectory(), - isFile: stats.isFile(), - permissions: stats.mode.toString(8).slice(-3) - } -} - -async function searchFiles( - allowedDirectories: string[], - rootPath: string, - pattern: string, - excludePatterns: string[] = [] -): Promise { - const results: string[] = [] - - async function search(currentPath: string) { - const entries = await fs.readdir(currentPath, { withFileTypes: true }) - - for (const entry of entries) { - const fullPath = path.join(currentPath, entry.name) - - try { - // Validate each path before processing - await validatePath(allowedDirectories, fullPath) - - // Check if path matches any exclude pattern - 
const relativePath = path.relative(rootPath, fullPath) - const shouldExclude = excludePatterns.some((pattern) => { - const globPattern = pattern.includes('*') ? pattern : `**/${pattern}/**` - return minimatch(relativePath, globPattern, { dot: true }) - }) - - if (shouldExclude) { - continue - } - - if (entry.name.toLowerCase().includes(pattern.toLowerCase())) { - results.push(fullPath) - } - - if (entry.isDirectory()) { - await search(fullPath) - } - } catch (error) { - // Skip invalid paths during search - } - } - } - - await search(rootPath) - return results -} - -// file editing and diffing utilities -function normalizeLineEndings(text: string): string { - return text.replace(/\r\n/g, '\n') -} - -function createUnifiedDiff(originalContent: string, newContent: string, filepath: string = 'file'): string { - // Ensure consistent line endings for diff - const normalizedOriginal = normalizeLineEndings(originalContent) - const normalizedNew = normalizeLineEndings(newContent) - - return createTwoFilesPatch(filepath, filepath, normalizedOriginal, normalizedNew, 'original', 'modified') -} - -async function applyFileEdits( - filePath: string, - edits: Array<{ oldText: string; newText: string }>, - dryRun = false -): Promise { - // Read file content and normalize line endings - const content = normalizeLineEndings(await fs.readFile(filePath, 'utf-8')) - - // Apply edits sequentially - let modifiedContent = content - for (const edit of edits) { - const normalizedOld = normalizeLineEndings(edit.oldText) - const normalizedNew = normalizeLineEndings(edit.newText) - - // If exact match exists, use it - if (modifiedContent.includes(normalizedOld)) { - modifiedContent = modifiedContent.replace(normalizedOld, normalizedNew) - continue - } - - // Otherwise, try line-by-line matching with flexibility for whitespace - const oldLines = normalizedOld.split('\n') - const contentLines = modifiedContent.split('\n') - let matchFound = false - - for (let i = 0; i <= contentLines.length - oldLines.length; i++) { - const potentialMatch = contentLines.slice(i, i + oldLines.length) - - // Compare lines with normalized whitespace - const isMatch = oldLines.every((oldLine, j) => { - const contentLine = potentialMatch[j] - return oldLine.trim() === contentLine.trim() - }) - - if (isMatch) { - // Preserve original indentation of first line - const originalIndent = contentLines[i].match(/^\s*/)?.[0] || '' - const newLines = normalizedNew.split('\n').map((line, j) => { - if (j === 0) return originalIndent + line.trimStart() - // For subsequent lines, try to preserve relative indentation - const oldIndent = oldLines[j]?.match(/^\s*/)?.[0] || '' - const newIndent = line.match(/^\s*/)?.[0] || '' - if (oldIndent && newIndent) { - const relativeIndent = newIndent.length - oldIndent.length - return originalIndent + ' '.repeat(Math.max(0, relativeIndent)) + line.trimStart() - } - return line - }) - - contentLines.splice(i, oldLines.length, ...newLines) - modifiedContent = contentLines.join('\n') - matchFound = true - break - } - } - - if (!matchFound) { - throw new Error(`Could not find exact match for edit:\n${edit.oldText}`) - } - } - - // Create unified diff - const diff = createUnifiedDiff(content, modifiedContent, filePath) - - // Format diff with appropriate number of backticks - let numBackticks = 3 - while (diff.includes('`'.repeat(numBackticks))) { - numBackticks++ - } - const formattedDiff = `${'`'.repeat(numBackticks)}diff\n${diff}${'`'.repeat(numBackticks)}\n\n` - - if (!dryRun) { - await fs.writeFile(filePath, 
modifiedContent, 'utf-8') - } - - return formattedDiff -} - -class FileSystemServer { - public server: Server - private allowedDirectories: string[] - constructor(allowedDirs: string[]) { - if (!Array.isArray(allowedDirs) || allowedDirs.length === 0) { - throw new Error('No allowed directories provided, please specify at least one directory in args') - } - - this.allowedDirectories = allowedDirs.map((dir) => normalizePath(path.resolve(expandHome(dir)))) - - // Validate that all directories exist and are accessible - this.validateDirs().catch((error) => { - logger.error('Error validating allowed directories:', error) - throw new Error(`Error validating allowed directories: ${error}`) - }) - - this.server = new Server( - { - name: 'secure-filesystem-server', - version: '0.2.0' - }, - { - capabilities: { - tools: {} - } - } - ) - this.initialize() - } - - async validateDirs() { - // Validate that all directories exist and are accessible - await Promise.all( - this.allowedDirectories.map(async (dir) => { - try { - const stats = await fs.stat(expandHome(dir)) - if (!stats.isDirectory()) { - logger.error(`Error: ${dir} is not a directory`) - throw new Error(`Error: ${dir} is not a directory`) - } - } catch (error: any) { - logger.error(`Error accessing directory ${dir}:`, error) - throw new Error(`Error accessing directory ${dir}:`, error) - } - }) - ) - } - - initialize() { - // Tool handlers - this.server.setRequestHandler(ListToolsRequestSchema, async () => { - return { - tools: [ - { - name: 'read_file', - description: - 'Read the complete contents of a file from the file system. ' + - 'Handles various text encodings and provides detailed error messages ' + - 'if the file cannot be read. Use this tool when you need to examine ' + - 'the contents of a single file. Only works within allowed directories.', - inputSchema: z.toJSONSchema(ReadFileArgsSchema) - }, - { - name: 'read_multiple_files', - description: - 'Read the contents of multiple files simultaneously. This is more ' + - 'efficient than reading files one by one when you need to analyze ' + - "or compare multiple files. Each file's content is returned with its " + - "path as a reference. Failed reads for individual files won't stop " + - 'the entire operation. Only works within allowed directories.', - inputSchema: z.toJSONSchema(ReadMultipleFilesArgsSchema) - }, - { - name: 'write_file', - description: - 'Create a new file or completely overwrite an existing file with new content. ' + - 'Use with caution as it will overwrite existing files without warning. ' + - 'Handles text content with proper encoding. Only works within allowed directories.', - inputSchema: z.toJSONSchema(WriteFileArgsSchema) - }, - { - name: 'edit_file', - description: - 'Make line-based edits to a text file. Each edit replaces exact line sequences ' + - 'with new content. Returns a git-style diff showing the changes made. ' + - 'Only works within allowed directories.', - inputSchema: z.toJSONSchema(EditFileArgsSchema) - }, - { - name: 'create_directory', - description: - 'Create a new directory or ensure a directory exists. Can create multiple ' + - 'nested directories in one operation. If the directory already exists, ' + - 'this operation will succeed silently. Perfect for setting up directory ' + - 'structures for projects or ensuring required paths exist. 
Only works within allowed directories.', - inputSchema: z.toJSONSchema(CreateDirectoryArgsSchema) - }, - { - name: 'list_directory', - description: - 'Get a detailed listing of all files and directories in a specified path. ' + - 'Results clearly distinguish between files and directories with [FILE] and [DIR] ' + - 'prefixes. This tool is essential for understanding directory structure and ' + - 'finding specific files within a directory. Only works within allowed directories.', - inputSchema: z.toJSONSchema(ListDirectoryArgsSchema) - }, - { - name: 'directory_tree', - description: - 'Get a recursive tree view of files and directories as a JSON structure. ' + - "Each entry includes 'name', 'type' (file/directory), and 'children' for directories. " + - 'Files have no children array, while directories always have a children array (which may be empty). ' + - 'The output is formatted with 2-space indentation for readability. Only works within allowed directories.', - inputSchema: z.toJSONSchema(DirectoryTreeArgsSchema) - }, - { - name: 'move_file', - description: - 'Move or rename files and directories. Can move files between directories ' + - 'and rename them in a single operation. If the destination exists, the ' + - 'operation will fail. Works across different directories and can be used ' + - 'for simple renaming within the same directory. Both source and destination must be within allowed directories.', - inputSchema: z.toJSONSchema(MoveFileArgsSchema) - }, - { - name: 'search_files', - description: - 'Recursively search for files and directories matching a pattern. ' + - 'Searches through all subdirectories from the starting path. The search ' + - 'is case-insensitive and matches partial names. Returns full paths to all ' + - "matching items. Great for finding files when you don't know their exact location. " + - 'Only searches within allowed directories.', - inputSchema: z.toJSONSchema(SearchFilesArgsSchema) - }, - { - name: 'get_file_info', - description: - 'Retrieve detailed metadata about a file or directory. Returns comprehensive ' + - 'information including size, creation time, last modified time, permissions, ' + - 'and type. This tool is perfect for understanding file characteristics ' + - 'without reading the actual content. Only works within allowed directories.', - inputSchema: z.toJSONSchema(GetFileInfoArgsSchema) - }, - { - name: 'list_allowed_directories', - description: - 'Returns the list of directories that this server is allowed to access. 
' + - 'Use this to understand which directories are available before trying to access files.', - inputSchema: { - type: 'object', - properties: {}, - required: [] - } - } - ] - } - }) - - this.server.setRequestHandler(CallToolRequestSchema, async (request) => { - try { - const { name, arguments: args } = request.params - - switch (name) { - case 'read_file': { - const parsed = ReadFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for read_file: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const content = await fs.readFile(validPath, 'utf-8') - return { - content: [{ type: 'text', text: content }] - } - } - - case 'read_multiple_files': { - const parsed = ReadMultipleFilesArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for read_multiple_files: ${parsed.error}`) - } - const results = await Promise.all( - parsed.data.paths.map(async (filePath: string) => { - try { - const validPath = await validatePath(this.allowedDirectories, filePath) - const content = await fs.readFile(validPath, 'utf-8') - return `${filePath}:\n${content}\n` - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error) - return `${filePath}: Error - ${errorMessage}` - } - }) - ) - return { - content: [{ type: 'text', text: results.join('\n---\n') }] - } - } - - case 'write_file': { - const parsed = WriteFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for write_file: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - await fs.writeFile(validPath, parsed.data.content, 'utf-8') - return { - content: [{ type: 'text', text: `Successfully wrote to ${parsed.data.path}` }] - } - } - - case 'edit_file': { - const parsed = EditFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for edit_file: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const result = await applyFileEdits(validPath, parsed.data.edits, parsed.data.dryRun) - return { - content: [{ type: 'text', text: result }] - } - } - - case 'create_directory': { - const parsed = CreateDirectoryArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for create_directory: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - await fs.mkdir(validPath, { recursive: true }) - return { - content: [{ type: 'text', text: `Successfully created directory ${parsed.data.path}` }] - } - } - - case 'list_directory': { - const parsed = ListDirectoryArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for list_directory: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const entries = await fs.readdir(validPath, { withFileTypes: true }) - const formatted = entries - .map((entry) => `${entry.isDirectory() ? 
'[DIR]' : '[FILE]'} ${entry.name}`) - .join('\n') - return { - content: [{ type: 'text', text: formatted }] - } - } - - case 'directory_tree': { - const parsed = DirectoryTreeArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for directory_tree: ${parsed.error}`) - } - - interface TreeEntry { - name: string - type: 'file' | 'directory' - children?: TreeEntry[] - } - - async function buildTree(allowedDirectories: string[], currentPath: string): Promise { - const validPath = await validatePath(allowedDirectories, currentPath) - const entries = await fs.readdir(validPath, { withFileTypes: true }) - const result: TreeEntry[] = [] - - for (const entry of entries) { - const entryData: TreeEntry = { - name: entry.name, - type: entry.isDirectory() ? 'directory' : 'file' - } - - if (entry.isDirectory()) { - const subPath = path.join(currentPath, entry.name) - entryData.children = await buildTree(allowedDirectories, subPath) - } - - result.push(entryData) - } - - return result - } - - const treeData = await buildTree(this.allowedDirectories, parsed.data.path) - return { - content: [ - { - type: 'text', - text: JSON.stringify(treeData, null, 2) - } - ] - } - } - - case 'move_file': { - const parsed = MoveFileArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for move_file: ${parsed.error}`) - } - const validSourcePath = await validatePath(this.allowedDirectories, parsed.data.source) - const validDestPath = await validatePath(this.allowedDirectories, parsed.data.destination) - await fs.rename(validSourcePath, validDestPath) - return { - content: [ - { type: 'text', text: `Successfully moved ${parsed.data.source} to ${parsed.data.destination}` } - ] - } - } - - case 'search_files': { - const parsed = SearchFilesArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for search_files: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const results = await searchFiles( - this.allowedDirectories, - validPath, - parsed.data.pattern, - parsed.data.excludePatterns - ) - return { - content: [{ type: 'text', text: results.length > 0 ? results.join('\n') : 'No matches found' }] - } - } - - case 'get_file_info': { - const parsed = GetFileInfoArgsSchema.safeParse(args) - if (!parsed.success) { - throw new Error(`Invalid arguments for get_file_info: ${parsed.error}`) - } - const validPath = await validatePath(this.allowedDirectories, parsed.data.path) - const info = await getFileStats(validPath) - return { - content: [ - { - type: 'text', - text: Object.entries(info) - .map(([key, value]) => `${key}: ${value}`) - .join('\n') - } - ] - } - } - - case 'list_allowed_directories': { - return { - content: [ - { - type: 'text', - text: `Allowed directories:\n${this.allowedDirectories.join('\n')}` - } - ] - } - } - - default: - throw new Error(`Unknown tool: ${name}`) - } - } catch (error) { - const errorMessage = error instanceof Error ? 
error.message : String(error) - return { - content: [{ type: 'text', text: `Error: ${errorMessage}` }], - isError: true - } - } - }) - } -} - -export default FileSystemServer diff --git a/src/main/mcpServers/filesystem/index.ts b/src/main/mcpServers/filesystem/index.ts new file mode 100644 index 0000000000..cec4c31cdf --- /dev/null +++ b/src/main/mcpServers/filesystem/index.ts @@ -0,0 +1,2 @@ +// Re-export FileSystemServer to maintain existing import pattern +export { default, FileSystemServer } from './server' diff --git a/src/main/mcpServers/filesystem/server.ts b/src/main/mcpServers/filesystem/server.ts new file mode 100644 index 0000000000..164ba0c9c4 --- /dev/null +++ b/src/main/mcpServers/filesystem/server.ts @@ -0,0 +1,118 @@ +import { Server } from '@modelcontextprotocol/sdk/server/index.js' +import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js' +import { app } from 'electron' +import fs from 'fs/promises' +import path from 'path' + +import { + deleteToolDefinition, + editToolDefinition, + globToolDefinition, + grepToolDefinition, + handleDeleteTool, + handleEditTool, + handleGlobTool, + handleGrepTool, + handleLsTool, + handleReadTool, + handleWriteTool, + lsToolDefinition, + readToolDefinition, + writeToolDefinition +} from './tools' +import { logger } from './types' + +export class FileSystemServer { + public server: Server + private baseDir: string + + constructor(baseDir?: string) { + if (baseDir && path.isAbsolute(baseDir)) { + this.baseDir = baseDir + logger.info(`Using provided baseDir for filesystem MCP: ${baseDir}`) + } else { + const userData = app.getPath('userData') + this.baseDir = path.join(userData, 'Data', 'Workspace') + logger.info(`Using default workspace for filesystem MCP baseDir: ${this.baseDir}`) + } + + this.server = new Server( + { + name: 'filesystem-server', + version: '2.0.0' + }, + { + capabilities: { + tools: {} + } + } + ) + + this.initialize() + } + + async initialize() { + try { + await fs.mkdir(this.baseDir, { recursive: true }) + } catch (error) { + logger.error('Failed to create filesystem MCP baseDir', { error, baseDir: this.baseDir }) + } + + // Register tool list handler + this.server.setRequestHandler(ListToolsRequestSchema, async () => { + return { + tools: [ + globToolDefinition, + lsToolDefinition, + grepToolDefinition, + readToolDefinition, + editToolDefinition, + writeToolDefinition, + deleteToolDefinition + ] + } + }) + + // Register tool call handler + this.server.setRequestHandler(CallToolRequestSchema, async (request) => { + try { + const { name, arguments: args } = request.params + + switch (name) { + case 'glob': + return await handleGlobTool(args, this.baseDir) + + case 'ls': + return await handleLsTool(args, this.baseDir) + + case 'grep': + return await handleGrepTool(args, this.baseDir) + + case 'read': + return await handleReadTool(args, this.baseDir) + + case 'edit': + return await handleEditTool(args, this.baseDir) + + case 'write': + return await handleWriteTool(args, this.baseDir) + + case 'delete': + return await handleDeleteTool(args, this.baseDir) + + default: + throw new Error(`Unknown tool: ${name}`) + } + } catch (error) { + const errorMessage = error instanceof Error ? 
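The refactored server splits each tool into a definition (name, description, JSON schema derived from the zod schema) plus a handler, then wires both through the two MCP request handlers. The same registration pattern reduced to a single hypothetical echo tool, as a sketch:

import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
import * as z from 'zod'

const EchoSchema = z.object({ text: z.string().describe('Text to echo back') })

const echoToolDefinition = {
  name: 'echo',
  description: 'Returns the provided text unchanged.',
  inputSchema: z.toJSONSchema(EchoSchema)
}

const server = new Server({ name: 'echo-server', version: '1.0.0' }, { capabilities: { tools: {} } })

// Advertise the tool catalog.
server.setRequestHandler(ListToolsRequestSchema, async () => ({ tools: [echoToolDefinition] }))

// Dispatch calls by tool name, validating arguments with the same zod schema.
server.setRequestHandler(CallToolRequestSchema, async (request) => {
  const { name, arguments: args } = request.params
  if (name !== 'echo') throw new Error(`Unknown tool: ${name}`)
  const parsed = EchoSchema.safeParse(args)
  if (!parsed.success) throw new Error(`Invalid arguments for echo: ${parsed.error}`)
  return { content: [{ type: 'text', text: parsed.data.text }] }
})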
error.message : String(error) + logger.error(`Tool execution error for ${request.params.name}:`, { error }) + return { + content: [{ type: 'text', text: `Error: ${errorMessage}` }], + isError: true + } + } + }) + } +} + +export default FileSystemServer diff --git a/src/main/mcpServers/filesystem/tools/delete.ts b/src/main/mcpServers/filesystem/tools/delete.ts new file mode 100644 index 0000000000..83becc4f17 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/delete.ts @@ -0,0 +1,93 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { logger, validatePath } from '../types' + +// Schema definition +export const DeleteToolSchema = z.object({ + path: z.string().describe('The path to the file or directory to delete'), + recursive: z.boolean().optional().describe('For directories, whether to delete recursively (default: false)') +}) + +// Tool definition with detailed description +export const deleteToolDefinition = { + name: 'delete', + description: `Deletes a file or directory from the filesystem. + +CAUTION: This operation cannot be undone! + +- For files: simply provide the path +- For empty directories: provide the path +- For non-empty directories: set recursive=true +- The path must be an absolute path, not a relative path +- Always verify the path before deleting to avoid data loss`, + inputSchema: z.toJSONSchema(DeleteToolSchema) +} + +// Handler implementation +export async function handleDeleteTool(args: unknown, baseDir: string) { + const parsed = DeleteToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for delete: ${parsed.error}`) + } + + const targetPath = parsed.data.path + const validPath = await validatePath(targetPath, baseDir) + const recursive = parsed.data.recursive || false + + // Check if path exists and get stats + let stats + try { + stats = await fs.stat(validPath) + } catch (error: any) { + if (error.code === 'ENOENT') { + throw new Error(`Path not found: ${targetPath}`) + } + throw error + } + + const isDirectory = stats.isDirectory() + const relativePath = path.relative(baseDir, validPath) + + // Perform deletion + try { + if (isDirectory) { + if (recursive) { + // Delete directory recursively + await fs.rm(validPath, { recursive: true, force: true }) + } else { + // Try to delete empty directory + await fs.rmdir(validPath) + } + } else { + // Delete file + await fs.unlink(validPath) + } + } catch (error: any) { + if (error.code === 'ENOTEMPTY') { + throw new Error(`Directory not empty: ${targetPath}. Use recursive=true to delete non-empty directories.`) + } + throw new Error(`Failed to delete: ${error.message}`) + } + + // Log the operation + logger.info('Path deleted', { + path: validPath, + type: isDirectory ? 'directory' : 'file', + recursive: isDirectory ? recursive : undefined + }) + + // Format output + const itemType = isDirectory ? 'Directory' : 'File' + const recursiveNote = isDirectory && recursive ? 
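The delete handler's dispatch is worth seeing in isolation: files go through fs.unlink, empty directories through fs.rmdir, recursive removal is gated behind an explicit flag, and ENOTEMPTY is translated into an actionable message. A condensed sketch (removePath is a hypothetical name):

import fs from 'fs/promises'

// Mirrors the delete tool's dispatch: unlink files, rmdir empty directories,
// and only fall back to recursive fs.rm when explicitly requested.
async function removePath(target: string, recursive = false): Promise<void> {
  const stats = await fs.stat(target)
  if (!stats.isDirectory()) {
    await fs.unlink(target)
    return
  }
  if (recursive) {
    await fs.rm(target, { recursive: true, force: true })
    return
  }
  try {
    await fs.rmdir(target) // fails with ENOTEMPTY when the directory has entries
  } catch (error: any) {
    if (error.code === 'ENOTEMPTY') {
      throw new Error(`Directory not empty: ${target}. Pass recursive=true to delete it.`)
    }
    throw error
  }
}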
' (recursive)' : '' + + return { + content: [ + { + type: 'text', + text: `${itemType} deleted${recursiveNote}: ${relativePath}` + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/edit.ts b/src/main/mcpServers/filesystem/tools/edit.ts new file mode 100644 index 0000000000..c1a0e637ce --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/edit.ts @@ -0,0 +1,130 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { logger, replaceWithFuzzyMatch, validatePath } from '../types' + +// Schema definition +export const EditToolSchema = z.object({ + file_path: z.string().describe('The path to the file to modify'), + old_string: z.string().describe('The text to replace'), + new_string: z.string().describe('The text to replace it with'), + replace_all: z.boolean().optional().default(false).describe('Replace all occurrences of old_string (default false)') +}) + +// Tool definition with detailed description +export const editToolDefinition = { + name: 'edit', + description: `Performs exact string replacements in files. + +- You must use the 'read' tool at least once before editing +- The file_path must be an absolute path, not a relative path +- Preserve exact indentation from read output (after the line number prefix) +- Never include line number prefixes in old_string or new_string +- ALWAYS prefer editing existing files over creating new ones +- The edit will FAIL if old_string is not found in the file +- The edit will FAIL if old_string appears multiple times (provide more context or use replace_all) +- The edit will FAIL if old_string equals new_string +- Use replace_all to rename variables or replace all occurrences`, + inputSchema: z.toJSONSchema(EditToolSchema) +} + +// Handler implementation +export async function handleEditTool(args: unknown, baseDir: string) { + const parsed = EditToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for edit: ${parsed.error}`) + } + + const { file_path: filePath, old_string: oldString, new_string: newString, replace_all: replaceAll } = parsed.data + + // Validate path + const validPath = await validatePath(filePath, baseDir) + + // Check if file exists + try { + const stats = await fs.stat(validPath) + if (!stats.isFile()) { + throw new Error(`Path is not a file: ${filePath}`) + } + } catch (error: any) { + if (error.code === 'ENOENT') { + // If old_string is empty, this is a create new file operation + if (oldString === '') { + // Create parent directory if needed + const parentDir = path.dirname(validPath) + await fs.mkdir(parentDir, { recursive: true }) + + // Write the new content + await fs.writeFile(validPath, newString, 'utf-8') + + logger.info('File created', { path: validPath }) + + const relativePath = path.relative(baseDir, validPath) + return { + content: [ + { + type: 'text', + text: `Created new file: ${relativePath}\nLines: ${newString.split('\n').length}` + } + ] + } + } + throw new Error(`File not found: ${filePath}`) + } + throw error + } + + // Read current content + const content = await fs.readFile(validPath, 'utf-8') + + // Handle special case: old_string is empty (create file with content) + if (oldString === '') { + await fs.writeFile(validPath, newString, 'utf-8') + + logger.info('File overwritten', { path: validPath }) + + const relativePath = path.relative(baseDir, validPath) + return { + content: [ + { + type: 'text', + text: `Overwrote file: ${relativePath}\nLines: ${newString.split('\n').length}` + } + ] + } + } + + // Perform the 
replacement with fuzzy matching + const newContent = replaceWithFuzzyMatch(content, oldString, newString, replaceAll) + + // Write the modified content + await fs.writeFile(validPath, newContent, 'utf-8') + + logger.info('File edited', { + path: validPath, + replaceAll + }) + + // Generate a simple diff summary + const oldLines = content.split('\n').length + const newLines = newContent.split('\n').length + const lineDiff = newLines - oldLines + + const relativePath = path.relative(baseDir, validPath) + let diffSummary = `Edited: ${relativePath}` + if (lineDiff > 0) { + diffSummary += `\n+${lineDiff} lines` + } else if (lineDiff < 0) { + diffSummary += `\n${lineDiff} lines` + } + + return { + content: [ + { + type: 'text', + text: diffSummary + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/glob.ts b/src/main/mcpServers/filesystem/tools/glob.ts new file mode 100644 index 0000000000..d6a6b4a757 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/glob.ts @@ -0,0 +1,149 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import type { FileInfo } from '../types' +import { logger, MAX_FILES_LIMIT, runRipgrep, validatePath } from '../types' + +// Schema definition +export const GlobToolSchema = z.object({ + pattern: z.string().describe('The glob pattern to match files against'), + path: z + .string() + .optional() + .describe('The directory to search in (must be absolute path). Defaults to the base directory') +}) + +// Tool definition with detailed description +export const globToolDefinition = { + name: 'glob', + description: `Fast file pattern matching tool that works with any codebase size. + +- Supports glob patterns like "**/*.js" or "src/**/*.ts" +- Returns matching absolute file paths sorted by modification time (newest first) +- Use this when you need to find files by name patterns +- Patterns without "/" (e.g., "*.txt") match files at ANY depth in the directory tree +- Patterns with "/" (e.g., "src/*.ts") match relative to the search path +- Pattern syntax: * (any chars), ** (any path), {a,b} (alternatives), ? 
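All of the edit tool's matching guarantees come from replaceWithFuzzyMatch, defined in types.ts further below: the old string must match exactly once unless replace_all is set, and a missing match fails loudly rather than guessing. A quick usage sketch of just that helper:

import { replaceWithFuzzyMatch } from '../types'

const source = ['function greet() {', "  console.log('hi')", '}'].join('\n')

// A unique occurrence is replaced via the first (exact substring) replacer.
const edited = replaceWithFuzzyMatch(source, "console.log('hi')", "console.log('hello')")
console.log(edited.includes("'hello'")) // true

// Missing matches throw instead of silently doing nothing.
try {
  replaceWithFuzzyMatch(source, 'not in the file', 'x')
} catch (error) {
  console.log((error as Error).message) // 'old_string not found in content'
}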
(single char) +- Results are limited to 100 files +- The path parameter must be an absolute path if specified +- If path is not specified, defaults to the base directory +- IMPORTANT: Omit the path field for the default directory (don't use "undefined" or "null")`, + inputSchema: z.toJSONSchema(GlobToolSchema) +} + +// Handler implementation +export async function handleGlobTool(args: unknown, baseDir: string) { + const parsed = GlobToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for glob: ${parsed.error}`) + } + + const searchPath = parsed.data.path || baseDir + const validPath = await validatePath(searchPath, baseDir) + + // Verify the search directory exists + try { + const stats = await fs.stat(validPath) + if (!stats.isDirectory()) { + throw new Error(`Path is not a directory: ${validPath}`) + } + } catch (error: unknown) { + if (error && typeof error === 'object' && 'code' in error && error.code === 'ENOENT') { + throw new Error(`Directory not found: ${validPath}`) + } + throw error + } + + // Validate pattern + const pattern = parsed.data.pattern.trim() + if (!pattern) { + throw new Error('Pattern cannot be empty') + } + + const files: FileInfo[] = [] + let truncated = false + + // Build ripgrep arguments for file listing using --glob=pattern format + const rgArgs: string[] = [ + '--files', + '--follow', + '--hidden', + `--glob=${pattern}`, + '--glob=!.git/*', + '--glob=!node_modules/*', + '--glob=!dist/*', + '--glob=!build/*', + '--glob=!__pycache__/*', + validPath + ] + + // Use ripgrep for file listing + logger.debug('Running ripgrep with args', { rgArgs }) + const rgResult = await runRipgrep(rgArgs) + logger.debug('Ripgrep result', { + ok: rgResult.ok, + exitCode: rgResult.exitCode, + stdoutLength: rgResult.stdout.length, + stdoutPreview: rgResult.stdout.slice(0, 500) + }) + + // Process results if we have stdout content + // Exit code 2 can indicate partial errors (e.g., permission denied on some dirs) but still have valid results + if (rgResult.ok && rgResult.stdout.length > 0) { + const lines = rgResult.stdout.split('\n').filter(Boolean) + logger.debug('Parsed lines from ripgrep', { lineCount: lines.length, lines }) + + for (const line of lines) { + if (files.length >= MAX_FILES_LIMIT) { + truncated = true + break + } + + const filePath = line.trim() + if (!filePath) continue + + const absolutePath = path.isAbsolute(filePath) ? filePath : path.resolve(validPath, filePath) + + try { + const stats = await fs.stat(absolutePath) + files.push({ + path: absolutePath, + type: 'file', // ripgrep --files only returns files + size: stats.size, + modified: stats.mtime + }) + } catch (error) { + logger.debug('Failed to stat file from ripgrep output, skipping', { file: absolutePath, error }) + } + } + } + + // Sort by modification time (newest first) + files.sort((a, b) => { + const aTime = a.modified ? a.modified.getTime() : 0 + const bTime = b.modified ? b.modified.getTime() : 0 + return bTime - aTime + }) + + // Format output - always use absolute paths + const output: string[] = [] + if (files.length === 0) { + output.push(`No files found matching pattern "${parsed.data.pattern}" in ${validPath}`) + } else { + output.push(...files.map((f) => f.path)) + if (truncated) { + output.push('') + output.push(`(Results truncated to ${MAX_FILES_LIMIT} files. 
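Because ripgrep does the actual matching, the glob tool reduces to argument assembly: --files lists candidate paths, the caller's pattern rides in as a --glob, and negated globs carve out the usual noise directories. That assembly as a pure function (buildGlobArgs is a hypothetical name):

// Builds the ripgrep argv used for glob-style file listing. Negated globs
// exclude the common noise directories; the caller's pattern is passed
// through unchanged, so "*.ts" matches at any depth.
function buildGlobArgs(pattern: string, searchPath: string): string[] {
  const ignored = ['.git', 'node_modules', 'dist', 'build', '__pycache__']
  return [
    '--files', // list file paths instead of searching contents
    '--follow', // resolve symlinks
    '--hidden', // include dotfiles (the negated globs still apply)
    `--glob=${pattern}`,
    ...ignored.map((dir) => `--glob=!${dir}/*`),
    searchPath
  ]
}

// e.g. ['--files', '--follow', '--hidden', '--glob=**/*.ts', '--glob=!.git/*', ..., '/workspace']
console.log(buildGlobArgs('**/*.ts', '/workspace'))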
Consider using a more specific pattern.)`) + } + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/grep.ts b/src/main/mcpServers/filesystem/tools/grep.ts new file mode 100644 index 0000000000..d822db9d88 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/grep.ts @@ -0,0 +1,266 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import type { GrepMatch } from '../types' +import { isBinaryFile, MAX_GREP_MATCHES, MAX_LINE_LENGTH, runRipgrep, validatePath } from '../types' + +// Schema definition +export const GrepToolSchema = z.object({ + pattern: z.string().describe('The regex pattern to search for in file contents'), + path: z + .string() + .optional() + .describe('The directory to search in (must be absolute path). Defaults to the base directory'), + include: z.string().optional().describe('File pattern to include in the search (e.g. "*.js", "*.{ts,tsx}")') +}) + +// Tool definition with detailed description +export const grepToolDefinition = { + name: 'grep', + description: `Fast content search tool that works with any codebase size. + +- Searches file contents using regular expressions +- Supports full regex syntax (e.g., "log.*Error", "function\\s+\\w+") +- Filter files by pattern with include (e.g., "*.js", "*.{ts,tsx}") +- Returns absolute file paths and line numbers with matching content +- Results are limited to 100 matches +- Binary files are automatically skipped +- Common directories (node_modules, .git, dist) are excluded +- The path parameter must be an absolute path if specified +- If path is not specified, defaults to the base directory`, + inputSchema: z.toJSONSchema(GrepToolSchema) +} + +// Handler implementation +export async function handleGrepTool(args: unknown, baseDir: string) { + const parsed = GrepToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for grep: ${parsed.error}`) + } + + const data = parsed.data + + if (!data.pattern) { + throw new Error('Pattern is required for grep') + } + + const searchPath = data.path || baseDir + const validPath = await validatePath(searchPath, baseDir) + + const matches: GrepMatch[] = [] + let truncated = false + let regex: RegExp + + // Build ripgrep arguments + const rgArgs: string[] = [ + '--no-heading', + '--line-number', + '--color', + 'never', + '--ignore-case', + '--glob', + '!.git/**', + '--glob', + '!node_modules/**', + '--glob', + '!dist/**', + '--glob', + '!build/**', + '--glob', + '!__pycache__/**' + ] + + if (data.include) { + for (const pat of data.include + .split(',') + .map((p) => p.trim()) + .filter(Boolean)) { + rgArgs.push('--glob', pat) + } + } + + rgArgs.push(data.pattern) + rgArgs.push(validPath) + + try { + regex = new RegExp(data.pattern, 'gi') + } catch (error) { + throw new Error(`Invalid regex pattern: ${data.pattern}`) + } + + async function searchFile(filePath: string): Promise { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + return + } + + try { + // Skip binary files + if (await isBinaryFile(filePath)) { + return + } + + const content = await fs.readFile(filePath, 'utf-8') + const lines = content.split('\n') + + lines.forEach((line, index) => { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + return + } + + if (regex.test(line)) { + // Truncate long lines + const truncatedLine = line.length > MAX_LINE_LENGTH ? line.substring(0, MAX_LINE_LENGTH) + '...' 
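One caveat in the pure-JS fallback path: the compiled pattern keeps the 'g' flag, and RegExp.prototype.test is stateful for global regexes. lastIndex carries over between calls, so testing successive lines with the same regex can silently skip matches. A minimal reproduction and two possible fixes, as a sketch:

const stateful = /foo/g

// Both inputs contain 'foo', but the second call resumes from the
// lastIndex left behind by the first and reports a false negative.
console.log(stateful.test('foo')) // true  (lastIndex is now 3)
console.log(stateful.test('foo')) // false (search resumed past the end)

// Fix 1: drop 'g' when only per-line existence is being checked.
const perLine = /foo/i
console.log(perLine.test('foo'), perLine.test('foo')) // true true

// Fix 2: keep 'g' but reset state before each line.
stateful.lastIndex = 0
console.log(stateful.test('foo')) // true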
: line + + matches.push({ + file: filePath, + line: index + 1, + content: truncatedLine.trim() + }) + } + }) + } catch (error) { + // Skip files we can't read + } + } + + async function searchDirectory(dir: string): Promise { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + return + } + + try { + const entries = await fs.readdir(dir, { withFileTypes: true }) + + for (const entry of entries) { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + break + } + + const fullPath = path.join(dir, entry.name) + + // Skip common ignore patterns + if (entry.name.startsWith('.') && entry.name !== '.env.example') { + continue + } + if (['node_modules', 'dist', 'build', '__pycache__', '.git'].includes(entry.name)) { + continue + } + + if (entry.isFile()) { + // Check if file matches include pattern + if (data.include) { + const includePatterns = data.include.split(',').map((p) => p.trim()) + const fileName = path.basename(fullPath) + const matchesInclude = includePatterns.some((pattern) => { + // Simple glob pattern matching + const regexPattern = pattern + .replace(/\*/g, '.*') + .replace(/\?/g, '.') + .replace(/\{([^}]+)\}/g, (_, group) => `(${group.split(',').join('|')})`) + return new RegExp(`^${regexPattern}$`).test(fileName) + }) + if (!matchesInclude) { + continue + } + } + + await searchFile(fullPath) + } else if (entry.isDirectory()) { + await searchDirectory(fullPath) + } + } + } catch (error) { + // Skip directories we can't read + } + } + + // Perform the search + let usedRipgrep = false + try { + const rgResult = await runRipgrep(rgArgs) + if (rgResult.ok && rgResult.exitCode !== null && rgResult.exitCode !== 2) { + usedRipgrep = true + const lines = rgResult.stdout.split('\n').filter(Boolean) + for (const line of lines) { + if (matches.length >= MAX_GREP_MATCHES) { + truncated = true + break + } + + const firstColon = line.indexOf(':') + const secondColon = line.indexOf(':', firstColon + 1) + if (firstColon === -1 || secondColon === -1) continue + + const filePart = line.slice(0, firstColon) + const linePart = line.slice(firstColon + 1, secondColon) + const contentPart = line.slice(secondColon + 1) + const lineNum = Number.parseInt(linePart, 10) + if (!Number.isFinite(lineNum)) continue + + const absoluteFilePath = path.isAbsolute(filePart) ? filePart : path.resolve(baseDir, filePart) + const truncatedLine = + contentPart.length > MAX_LINE_LENGTH ? contentPart.substring(0, MAX_LINE_LENGTH) + '...' : contentPart + + matches.push({ + file: absoluteFilePath, + line: lineNum, + content: truncatedLine.trim() + }) + } + } + } catch { + usedRipgrep = false + } + + if (!usedRipgrep) { + const stats = await fs.stat(validPath) + if (stats.isFile()) { + await searchFile(validPath) + } else { + await searchDirectory(validPath) + } + } + + // Format output + const output: string[] = [] + + if (matches.length === 0) { + output.push('No matches found') + } else { + // Group matches by file + const fileGroups = new Map() + matches.forEach((match) => { + if (!fileGroups.has(match.file)) { + fileGroups.set(match.file, []) + } + fileGroups.get(match.file)!.push(match) + }) + + // Format grouped matches - always use absolute paths + fileGroups.forEach((fileMatches, filePath) => { + output.push(`\n${filePath}:`) + fileMatches.forEach((match) => { + output.push(` ${match.line}: ${match.content}`) + }) + }) + + if (truncated) { + output.push('') + output.push(`(Results truncated to ${MAX_GREP_MATCHES} matches. 
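The colon-splitting above assumes the file path itself contains no ':'. On Windows an absolute path starts with a drive letter such as C:\, so taking the first two colons mis-splits every match and the parseInt guard then drops it. A drive-letter-tolerant parse, assuming ripgrep's default --no-heading path:line:content layout (parseGrepLine is a hypothetical helper; rg --json would sidestep the ambiguity entirely):

// Splits a ripgrep `--no-heading --line-number` output line into its parts.
// The lazy (.*?) group stops at the first ':<digits>:' boundary, which skips
// a leading 'C:' drive prefix because no digits follow that colon.
function parseGrepLine(line: string): { file: string; line: number; content: string } | null {
  const match = /^(.*?):(\d+):(.*)$/.exec(line)
  if (!match) return null
  return { file: match[1], line: Number(match[2]), content: match[3] }
}

console.log(parseGrepLine('C:\\work\\a.ts:12:const x = 1'))
// -> { file: 'C:\\work\\a.ts', line: 12, content: 'const x = 1' }
console.log(parseGrepLine('/srv/app/b.ts:3:TODO'))
// -> { file: '/srv/app/b.ts', line: 3, content: 'TODO' }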
Consider using a more specific pattern or path.)`) + } + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/index.ts b/src/main/mcpServers/filesystem/tools/index.ts new file mode 100644 index 0000000000..2e02d613c4 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/index.ts @@ -0,0 +1,8 @@ +// Export all tool definitions and handlers +export { deleteToolDefinition, handleDeleteTool } from './delete' +export { editToolDefinition, handleEditTool } from './edit' +export { globToolDefinition, handleGlobTool } from './glob' +export { grepToolDefinition, handleGrepTool } from './grep' +export { handleLsTool, lsToolDefinition } from './ls' +export { handleReadTool, readToolDefinition } from './read' +export { handleWriteTool, writeToolDefinition } from './write' diff --git a/src/main/mcpServers/filesystem/tools/ls.ts b/src/main/mcpServers/filesystem/tools/ls.ts new file mode 100644 index 0000000000..22672c9fb9 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/ls.ts @@ -0,0 +1,150 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { MAX_FILES_LIMIT, validatePath } from '../types' + +// Schema definition +export const LsToolSchema = z.object({ + path: z.string().optional().describe('The directory to list (must be absolute path). Defaults to the base directory'), + recursive: z.boolean().optional().describe('Whether to list directories recursively (default: false)') +}) + +// Tool definition with detailed description +export const lsToolDefinition = { + name: 'ls', + description: `Lists files and directories in a specified path. + +- Returns a tree-like structure with icons (📁 directories, 📄 files) +- Shows the absolute directory path in the header +- Entries are sorted alphabetically with directories first +- Can list recursively with recursive=true (up to 5 levels deep) +- Common directories (node_modules, dist, .git) are excluded +- Hidden files (starting with .) 
are excluded except .env.example +- Results are limited to 100 entries +- The path parameter must be an absolute path if specified +- If path is not specified, defaults to the base directory`, + inputSchema: z.toJSONSchema(LsToolSchema) +} + +// Handler implementation +export async function handleLsTool(args: unknown, baseDir: string) { + const parsed = LsToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for ls: ${parsed.error}`) + } + + const targetPath = parsed.data.path || baseDir + const validPath = await validatePath(targetPath, baseDir) + const recursive = parsed.data.recursive || false + + interface TreeNode { + name: string + type: 'file' | 'directory' + children?: TreeNode[] + } + + let fileCount = 0 + let truncated = false + + async function buildTree(dirPath: string, depth: number = 0): Promise { + if (fileCount >= MAX_FILES_LIMIT) { + truncated = true + return [] + } + + try { + const entries = await fs.readdir(dirPath, { withFileTypes: true }) + const nodes: TreeNode[] = [] + + // Sort entries: directories first, then files, alphabetically + entries.sort((a, b) => { + if (a.isDirectory() && !b.isDirectory()) return -1 + if (!a.isDirectory() && b.isDirectory()) return 1 + return a.name.localeCompare(b.name) + }) + + for (const entry of entries) { + if (fileCount >= MAX_FILES_LIMIT) { + truncated = true + break + } + + // Skip hidden files and common ignore patterns + if (entry.name.startsWith('.') && entry.name !== '.env.example') { + continue + } + if (['node_modules', 'dist', 'build', '__pycache__'].includes(entry.name)) { + continue + } + + fileCount++ + const node: TreeNode = { + name: entry.name, + type: entry.isDirectory() ? 'directory' : 'file' + } + + if (entry.isDirectory() && recursive && depth < 5) { + // Limit depth to prevent infinite recursion + const childPath = path.join(dirPath, entry.name) + node.children = await buildTree(childPath, depth + 1) + } + + nodes.push(node) + } + + return nodes + } catch (error) { + return [] + } + } + + // Build the tree + const tree = await buildTree(validPath) + + // Format as text output + function formatTree(nodes: TreeNode[], prefix: string = ''): string[] { + const lines: string[] = [] + + nodes.forEach((node, index) => { + const isLastNode = index === nodes.length - 1 + const connector = isLastNode ? '└── ' : '├── ' + const icon = node.type === 'directory' ? '📁 ' : '📄 ' + + lines.push(prefix + connector + icon + node.name) + + if (node.children && node.children.length > 0) { + const childPrefix = prefix + (isLastNode ? ' ' : '│ ') + lines.push(...formatTree(node.children, childPrefix)) + } + }) + + return lines + } + + // Generate output + const output: string[] = [] + output.push(`Directory: ${validPath}`) + output.push('') + + if (tree.length === 0) { + output.push('(empty directory)') + } else { + const treeLines = formatTree(tree, '') + output.push(...treeLines) + + if (truncated) { + output.push('') + output.push(`(Results truncated to ${MAX_FILES_LIMIT} files. 
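The connector logic in formatTree is the classic box-drawing recursion: '└── ' marks a last sibling, '├── ' any other, and a '│' rail continues while an ancestor still has siblings below. A condensed, icon-free variant as a sketch (renderTree is hypothetical):

interface TreeNode {
  name: string
  type: 'file' | 'directory'
  children?: TreeNode[]
}

// Renders nodes with the same connectors the ls tool uses: '└── ' for the
// last sibling, '├── ' otherwise, and a '│   ' rail for open ancestors.
function renderTree(nodes: TreeNode[], prefix = ''): string[] {
  return nodes.flatMap((node, i) => {
    const last = i === nodes.length - 1
    const line = prefix + (last ? '└── ' : '├── ') + node.name
    const childPrefix = prefix + (last ? '    ' : '│   ')
    return [line, ...renderTree(node.children ?? [], childPrefix)]
  })
}

console.log(
  renderTree([
    { name: 'src', type: 'directory', children: [{ name: 'index.ts', type: 'file' }] },
    { name: 'README.md', type: 'file' }
  ]).join('\n')
)
// ├── src
// │   └── index.ts
// └── README.md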
Consider listing a more specific directory.)`) + } + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/read.ts b/src/main/mcpServers/filesystem/tools/read.ts new file mode 100644 index 0000000000..460c88dda4 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/read.ts @@ -0,0 +1,101 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { DEFAULT_READ_LIMIT, isBinaryFile, MAX_LINE_LENGTH, validatePath } from '../types' + +// Schema definition +export const ReadToolSchema = z.object({ + file_path: z.string().describe('The path to the file to read'), + offset: z.number().optional().describe('The line number to start reading from (1-based)'), + limit: z.number().optional().describe('The number of lines to read (defaults to 2000)') +}) + +// Tool definition with detailed description +export const readToolDefinition = { + name: 'read', + description: `Reads a file from the local filesystem. + +- Assumes this tool can read all files on the machine +- The file_path parameter must be an absolute path, not a relative path +- By default, reads up to 2000 lines starting from the beginning +- You can optionally specify a line offset and limit for long files +- Any lines longer than 2000 characters will be truncated +- Results are returned with line numbers starting at 1 +- Binary files are detected and rejected with an error +- Empty files return a warning`, + inputSchema: z.toJSONSchema(ReadToolSchema) +} + +// Handler implementation +export async function handleReadTool(args: unknown, baseDir: string) { + const parsed = ReadToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for read: ${parsed.error}`) + } + + const filePath = parsed.data.file_path + const validPath = await validatePath(filePath, baseDir) + + // Check if file exists + try { + const stats = await fs.stat(validPath) + if (!stats.isFile()) { + throw new Error(`Path is not a file: ${filePath}`) + } + } catch (error: any) { + if (error.code === 'ENOENT') { + throw new Error(`File not found: ${filePath}`) + } + throw error + } + + // Check if file is binary + if (await isBinaryFile(validPath)) { + throw new Error(`Cannot read binary file: ${filePath}`) + } + + // Read file content + const content = await fs.readFile(validPath, 'utf-8') + const lines = content.split('\n') + + // Apply offset and limit + const offset = (parsed.data.offset || 1) - 1 // Convert to 0-based + const limit = parsed.data.limit || DEFAULT_READ_LIMIT + + if (offset < 0 || offset >= lines.length) { + throw new Error(`Invalid offset: ${offset + 1}. File has ${lines.length} lines.`) + } + + const selectedLines = lines.slice(offset, offset + limit) + + // Format output with line numbers and truncate long lines + const output: string[] = [] + const relativePath = path.relative(baseDir, validPath) + + output.push(`File: ${relativePath}`) + if (offset > 0 || limit < lines.length) { + output.push(`Lines ${offset + 1} to ${Math.min(offset + limit, lines.length)} of ${lines.length}`) + } + output.push('') + + selectedLines.forEach((line, index) => { + const lineNumber = offset + index + 1 + const truncatedLine = line.length > MAX_LINE_LENGTH ? line.substring(0, MAX_LINE_LENGTH) + '...' 
: line + output.push(`${lineNumber.toString().padStart(6)}\t${truncatedLine}`) + }) + + if (offset + limit < lines.length) { + output.push('') + output.push(`(${lines.length - (offset + limit)} more lines not shown)`) + } + + return { + content: [ + { + type: 'text', + text: output.join('\n') + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/tools/write.ts b/src/main/mcpServers/filesystem/tools/write.ts new file mode 100644 index 0000000000..2898f2f874 --- /dev/null +++ b/src/main/mcpServers/filesystem/tools/write.ts @@ -0,0 +1,83 @@ +import fs from 'fs/promises' +import path from 'path' +import * as z from 'zod' + +import { logger, validatePath } from '../types' + +// Schema definition +export const WriteToolSchema = z.object({ + file_path: z.string().describe('The path to the file to write'), + content: z.string().describe('The content to write to the file') +}) + +// Tool definition with detailed description +export const writeToolDefinition = { + name: 'write', + description: `Writes a file to the local filesystem. + +- This tool will overwrite the existing file if one exists at the path +- You MUST use the read tool first to understand what you're overwriting +- ALWAYS prefer using the 'edit' tool for existing files +- NEVER proactively create documentation files unless explicitly requested +- Parent directories will be created automatically if they don't exist +- The file_path must be an absolute path, not a relative path`, + inputSchema: z.toJSONSchema(WriteToolSchema) +} + +// Handler implementation +export async function handleWriteTool(args: unknown, baseDir: string) { + const parsed = WriteToolSchema.safeParse(args) + if (!parsed.success) { + throw new Error(`Invalid arguments for write: ${parsed.error}`) + } + + const filePath = parsed.data.file_path + const validPath = await validatePath(filePath, baseDir) + + // Create parent directory if it doesn't exist + const parentDir = path.dirname(validPath) + try { + await fs.mkdir(parentDir, { recursive: true }) + } catch (error: any) { + if (error.code !== 'EEXIST') { + throw new Error(`Failed to create parent directory: ${error.message}`) + } + } + + // Check if file exists (for logging) + let isOverwrite = false + try { + await fs.stat(validPath) + isOverwrite = true + } catch { + // File doesn't exist, that's fine + } + + // Write the file + try { + await fs.writeFile(validPath, parsed.data.content, 'utf-8') + } catch (error: any) { + throw new Error(`Failed to write file: ${error.message}`) + } + + // Log the operation + logger.info('File written', { + path: validPath, + overwrite: isOverwrite, + size: parsed.data.content.length + }) + + // Format output + const relativePath = path.relative(baseDir, validPath) + const action = isOverwrite ? 
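The read tool's line format, a right-aligned 1-based number, a tab, then the possibly truncated line, is exactly the prefix the edit tool's description tells the model to strip before quoting text back. The same numbering in isolation (numberLines is a hypothetical name):

// Formats file lines the way the read tool does: 1-based numbers padded to
// six columns, separated from the content by a tab.
function numberLines(text: string, offset = 1): string {
  return text
    .split('\n')
    .map((line, i) => `${String(offset + i).padStart(6)}\t${line}`)
    .join('\n')
}

console.log(numberLines('alpha\nbeta', 10))
//     10	alpha
//     11	beta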
'Updated' : 'Created' + const lines = parsed.data.content.split('\n').length + + return { + content: [ + { + type: 'text', + text: `${action} file: ${relativePath}\n` + `Size: ${parsed.data.content.length} bytes\n` + `Lines: ${lines}` + } + ] + } +} diff --git a/src/main/mcpServers/filesystem/types.ts b/src/main/mcpServers/filesystem/types.ts new file mode 100644 index 0000000000..922fe0b23a --- /dev/null +++ b/src/main/mcpServers/filesystem/types.ts @@ -0,0 +1,627 @@ +import { loggerService } from '@logger' +import { isMac, isWin } from '@main/constant' +import { spawn } from 'child_process' +import fs from 'fs/promises' +import os from 'os' +import path from 'path' + +export const logger = loggerService.withContext('MCP:FileSystemServer') + +// Constants +export const MAX_LINE_LENGTH = 2000 +export const DEFAULT_READ_LIMIT = 2000 +export const MAX_FILES_LIMIT = 100 +export const MAX_GREP_MATCHES = 100 + +// Common types +export interface FileInfo { + path: string + type: 'file' | 'directory' + size?: number + modified?: Date +} + +export interface GrepMatch { + file: string + line: number + content: string +} + +// Utility functions for path handling +export function normalizePath(p: string): string { + return path.normalize(p) +} + +export function expandHome(filepath: string): string { + if (filepath.startsWith('~/') || filepath === '~') { + return path.join(os.homedir(), filepath.slice(1)) + } + return filepath +} + +// Security validation +export async function validatePath(requestedPath: string, baseDir?: string): Promise { + const expandedPath = expandHome(requestedPath) + const root = baseDir ?? process.cwd() + const absolute = path.isAbsolute(expandedPath) ? path.resolve(expandedPath) : path.resolve(root, expandedPath) + + // Handle symlinks by checking their real path + try { + const realPath = await fs.realpath(absolute) + return normalizePath(realPath) + } catch (error) { + // For new files that don't exist yet, verify parent directory + const parentDir = path.dirname(absolute) + try { + const realParentPath = await fs.realpath(parentDir) + normalizePath(realParentPath) + return normalizePath(absolute) + } catch { + return normalizePath(absolute) + } + } +} + +// ============================================================================ +// Edit Tool Utilities - Fuzzy matching replacers from opencode +// ============================================================================ + +export type Replacer = (content: string, find: string) => Generator + +// Similarity thresholds for block anchor fallback matching +const SINGLE_CANDIDATE_SIMILARITY_THRESHOLD = 0.0 +const MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD = 0.3 + +/** + * Levenshtein distance algorithm implementation + */ +function levenshtein(a: string, b: string): number { + if (a === '' || b === '') { + return Math.max(a.length, b.length) + } + const matrix = Array.from({ length: a.length + 1 }, (_, i) => + Array.from({ length: b.length + 1 }, (_, j) => (i === 0 ? j : j === 0 ? i : 0)) + ) + + for (let i = 1; i <= a.length; i++) { + for (let j = 1; j <= b.length; j++) { + const cost = a[i - 1] === b[j - 1] ? 
0 : 1 + matrix[i][j] = Math.min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + cost) + } + } + return matrix[a.length][b.length] +} + +export const SimpleReplacer: Replacer = function* (_content, find) { + yield find +} + +export const LineTrimmedReplacer: Replacer = function* (content, find) { + const originalLines = content.split('\n') + const searchLines = find.split('\n') + + if (searchLines[searchLines.length - 1] === '') { + searchLines.pop() + } + + for (let i = 0; i <= originalLines.length - searchLines.length; i++) { + let matches = true + + for (let j = 0; j < searchLines.length; j++) { + const originalTrimmed = originalLines[i + j].trim() + const searchTrimmed = searchLines[j].trim() + + if (originalTrimmed !== searchTrimmed) { + matches = false + break + } + } + + if (matches) { + let matchStartIndex = 0 + for (let k = 0; k < i; k++) { + matchStartIndex += originalLines[k].length + 1 + } + + let matchEndIndex = matchStartIndex + for (let k = 0; k < searchLines.length; k++) { + matchEndIndex += originalLines[i + k].length + if (k < searchLines.length - 1) { + matchEndIndex += 1 + } + } + + yield content.substring(matchStartIndex, matchEndIndex) + } + } +} + +export const BlockAnchorReplacer: Replacer = function* (content, find) { + const originalLines = content.split('\n') + const searchLines = find.split('\n') + + if (searchLines.length < 3) { + return + } + + if (searchLines[searchLines.length - 1] === '') { + searchLines.pop() + } + + const firstLineSearch = searchLines[0].trim() + const lastLineSearch = searchLines[searchLines.length - 1].trim() + const searchBlockSize = searchLines.length + + const candidates: Array<{ startLine: number; endLine: number }> = [] + for (let i = 0; i < originalLines.length; i++) { + if (originalLines[i].trim() !== firstLineSearch) { + continue + } + + for (let j = i + 2; j < originalLines.length; j++) { + if (originalLines[j].trim() === lastLineSearch) { + candidates.push({ startLine: i, endLine: j }) + break + } + } + } + + if (candidates.length === 0) { + return + } + + if (candidates.length === 1) { + const { startLine, endLine } = candidates[0] + const actualBlockSize = endLine - startLine + 1 + + let similarity = 0 + const linesToCheck = Math.min(searchBlockSize - 2, actualBlockSize - 2) + + if (linesToCheck > 0) { + for (let j = 1; j < searchBlockSize - 1 && j < actualBlockSize - 1; j++) { + const originalLine = originalLines[startLine + j].trim() + const searchLine = searchLines[j].trim() + const maxLen = Math.max(originalLine.length, searchLine.length) + if (maxLen === 0) { + continue + } + const distance = levenshtein(originalLine, searchLine) + similarity += (1 - distance / maxLen) / linesToCheck + + if (similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD) { + break + } + } + } else { + similarity = 1.0 + } + + if (similarity >= SINGLE_CANDIDATE_SIMILARITY_THRESHOLD) { + let matchStartIndex = 0 + for (let k = 0; k < startLine; k++) { + matchStartIndex += originalLines[k].length + 1 + } + let matchEndIndex = matchStartIndex + for (let k = startLine; k <= endLine; k++) { + matchEndIndex += originalLines[k].length + if (k < endLine) { + matchEndIndex += 1 + } + } + yield content.substring(matchStartIndex, matchEndIndex) + } + return + } + + let bestMatch: { startLine: number; endLine: number } | null = null + let maxSimilarity = -1 + + for (const candidate of candidates) { + const { startLine, endLine } = candidate + const actualBlockSize = endLine - startLine + 1 + + let similarity = 0 + const linesToCheck = 
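The Levenshtein routine here fills the full (a.length+1) x (b.length+1) matrix, which is fine for the short trimmed lines the block-anchor scoring compares. For longer inputs, the standard two-row refinement keeps O(n*m) time at linear memory; a sketch, not a drop-in replacement:

// Two-row Levenshtein: row i depends only on row i-1, so two buffers suffice.
function levenshtein2(a: string, b: string): number {
  let prev = Array.from({ length: b.length + 1 }, (_, j) => j)
  let curr = new Array<number>(b.length + 1)
  for (let i = 1; i <= a.length; i++) {
    curr[0] = i
    for (let j = 1; j <= b.length; j++) {
      const cost = a[i - 1] === b[j - 1] ? 0 : 1
      curr[j] = Math.min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost)
    }
    ;[prev, curr] = [curr, prev]
  }
  return prev[b.length]
}

console.log(levenshtein2('kitten', 'sitting')) // 3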
Math.min(searchBlockSize - 2, actualBlockSize - 2) + + if (linesToCheck > 0) { + for (let j = 1; j < searchBlockSize - 1 && j < actualBlockSize - 1; j++) { + const originalLine = originalLines[startLine + j].trim() + const searchLine = searchLines[j].trim() + const maxLen = Math.max(originalLine.length, searchLine.length) + if (maxLen === 0) { + continue + } + const distance = levenshtein(originalLine, searchLine) + similarity += 1 - distance / maxLen + } + similarity /= linesToCheck + } else { + similarity = 1.0 + } + + if (similarity > maxSimilarity) { + maxSimilarity = similarity + bestMatch = candidate + } + } + + if (maxSimilarity >= MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD && bestMatch) { + const { startLine, endLine } = bestMatch + let matchStartIndex = 0 + for (let k = 0; k < startLine; k++) { + matchStartIndex += originalLines[k].length + 1 + } + let matchEndIndex = matchStartIndex + for (let k = startLine; k <= endLine; k++) { + matchEndIndex += originalLines[k].length + if (k < endLine) { + matchEndIndex += 1 + } + } + yield content.substring(matchStartIndex, matchEndIndex) + } +} + +export const WhitespaceNormalizedReplacer: Replacer = function* (content, find) { + const normalizeWhitespace = (text: string) => text.replace(/\s+/g, ' ').trim() + const normalizedFind = normalizeWhitespace(find) + + const lines = content.split('\n') + for (let i = 0; i < lines.length; i++) { + const line = lines[i] + if (normalizeWhitespace(line) === normalizedFind) { + yield line + } else { + const normalizedLine = normalizeWhitespace(line) + if (normalizedLine.includes(normalizedFind)) { + const words = find.trim().split(/\s+/) + if (words.length > 0) { + const pattern = words.map((word) => word.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')).join('\\s+') + try { + const regex = new RegExp(pattern) + const match = line.match(regex) + if (match) { + yield match[0] + } + } catch { + // Invalid regex pattern, skip + } + } + } + } + } + + const findLines = find.split('\n') + if (findLines.length > 1) { + for (let i = 0; i <= lines.length - findLines.length; i++) { + const block = lines.slice(i, i + findLines.length) + if (normalizeWhitespace(block.join('\n')) === normalizedFind) { + yield block.join('\n') + } + } + } +} + +export const IndentationFlexibleReplacer: Replacer = function* (content, find) { + const removeIndentation = (text: string) => { + const lines = text.split('\n') + const nonEmptyLines = lines.filter((line) => line.trim().length > 0) + if (nonEmptyLines.length === 0) return text + + const minIndent = Math.min( + ...nonEmptyLines.map((line) => { + const match = line.match(/^(\s*)/) + return match ? match[1].length : 0 + }) + ) + + return lines.map((line) => (line.trim().length === 0 ? 
line : line.slice(minIndent))).join('\n') + } + + const normalizedFind = removeIndentation(find) + const contentLines = content.split('\n') + const findLines = find.split('\n') + + for (let i = 0; i <= contentLines.length - findLines.length; i++) { + const block = contentLines.slice(i, i + findLines.length).join('\n') + if (removeIndentation(block) === normalizedFind) { + yield block + } + } +} + +export const EscapeNormalizedReplacer: Replacer = function* (content, find) { + const unescapeString = (str: string): string => { + return str.replace(/\\(n|t|r|'|"|`|\\|\n|\$)/g, (match, capturedChar) => { + switch (capturedChar) { + case 'n': + return '\n' + case 't': + return '\t' + case 'r': + return '\r' + case "'": + return "'" + case '"': + return '"' + case '`': + return '`' + case '\\': + return '\\' + case '\n': + return '\n' + case '$': + return '$' + default: + return match + } + }) + } + + const unescapedFind = unescapeString(find) + + if (content.includes(unescapedFind)) { + yield unescapedFind + } + + const lines = content.split('\n') + const findLines = unescapedFind.split('\n') + + for (let i = 0; i <= lines.length - findLines.length; i++) { + const block = lines.slice(i, i + findLines.length).join('\n') + const unescapedBlock = unescapeString(block) + + if (unescapedBlock === unescapedFind) { + yield block + } + } +} + +export const TrimmedBoundaryReplacer: Replacer = function* (content, find) { + const trimmedFind = find.trim() + + if (trimmedFind === find) { + return + } + + if (content.includes(trimmedFind)) { + yield trimmedFind + } + + const lines = content.split('\n') + const findLines = find.split('\n') + + for (let i = 0; i <= lines.length - findLines.length; i++) { + const block = lines.slice(i, i + findLines.length).join('\n') + + if (block.trim() === trimmedFind) { + yield block + } + } +} + +export const ContextAwareReplacer: Replacer = function* (content, find) { + const findLines = find.split('\n') + if (findLines.length < 3) { + return + } + + if (findLines[findLines.length - 1] === '') { + findLines.pop() + } + + const contentLines = content.split('\n') + + const firstLine = findLines[0].trim() + const lastLine = findLines[findLines.length - 1].trim() + + for (let i = 0; i < contentLines.length; i++) { + if (contentLines[i].trim() !== firstLine) continue + + for (let j = i + 2; j < contentLines.length; j++) { + if (contentLines[j].trim() === lastLine) { + const blockLines = contentLines.slice(i, j + 1) + const block = blockLines.join('\n') + + if (blockLines.length === findLines.length) { + let matchingLines = 0 + let totalNonEmptyLines = 0 + + for (let k = 1; k < blockLines.length - 1; k++) { + const blockLine = blockLines[k].trim() + const findLine = findLines[k].trim() + + if (blockLine.length > 0 || findLine.length > 0) { + totalNonEmptyLines++ + if (blockLine === findLine) { + matchingLines++ + } + } + } + + if (totalNonEmptyLines === 0 || matchingLines / totalNonEmptyLines >= 0.5) { + yield block + break + } + } + break + } + } + } +} + +export const MultiOccurrenceReplacer: Replacer = function* (content, find) { + let startIndex = 0 + + while (true) { + const index = content.indexOf(find, startIndex) + if (index === -1) break + + yield find + startIndex = index + find.length + } +} + +/** + * All replacers in order of specificity + */ +export const ALL_REPLACERS: Replacer[] = [ + SimpleReplacer, + LineTrimmedReplacer, + BlockAnchorReplacer, + WhitespaceNormalizedReplacer, + IndentationFlexibleReplacer, + EscapeNormalizedReplacer, + TrimmedBoundaryReplacer, 
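Each replacer is a generator that proposes candidate substrings, and replaceWithFuzzyMatch below walks them from strictest to loosest, accepting the first candidate that occurs exactly once in the content. Extending the pipeline is therefore just another generator; a hypothetical case-insensitive strategy for illustration:

type Replacer = (content: string, find: string) => Generator<string, void, unknown>

// Hypothetical extra strategy: yield the exact substring of `content` whose
// lowercased form matches `find`, so the downstream indexOf check still works.
const CaseInsensitiveReplacer: Replacer = function* (content, find) {
  const haystack = content.toLowerCase()
  const needle = find.toLowerCase()
  let from = 0
  while (true) {
    const index = haystack.indexOf(needle, from)
    if (index === -1) return
    yield content.substring(index, index + find.length)
    from = index + needle.length
  }
}

// Candidates for 'hello' in mixed-case content: ['Hello', 'HELLO']
console.log([...CaseInsensitiveReplacer('Hello world, HELLO', 'hello')])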
+
+/**
+ * Replace oldString with newString in content using fuzzy matching
+ */
+export function replaceWithFuzzyMatch(
+  content: string,
+  oldString: string,
+  newString: string,
+  replaceAll = false
+): string {
+  if (oldString === newString) {
+    throw new Error('old_string and new_string must be different')
+  }
+
+  let notFound = true
+
+  for (const replacer of ALL_REPLACERS) {
+    for (const search of replacer(content, oldString)) {
+      const index = content.indexOf(search)
+      if (index === -1) continue
+      notFound = false
+      if (replaceAll) {
+        return content.replaceAll(search, newString)
+      }
+      const lastIndex = content.lastIndexOf(search)
+      if (index !== lastIndex) continue
+      return content.substring(0, index) + newString + content.substring(index + search.length)
+    }
+  }
+
+  if (notFound) {
+    throw new Error('old_string not found in content')
+  }
+  throw new Error(
+    'Found multiple matches for old_string. Provide more surrounding lines in old_string to identify the correct match.'
+  )
+}
+
+// ============================================================================
+// Binary File Detection
+// ============================================================================
+
+// Check if a file is likely binary
+export async function isBinaryFile(filePath: string): Promise<boolean> {
+  try {
+    const buffer = Buffer.alloc(4096)
+    const fd = await fs.open(filePath, 'r')
+    const { bytesRead } = await fd.read(buffer, 0, buffer.length, 0)
+    await fd.close()
+
+    if (bytesRead === 0) return false
+
+    const view = buffer.subarray(0, bytesRead)
+
+    let zeroBytes = 0
+    let evenZeros = 0
+    let oddZeros = 0
+    let nonPrintable = 0
+
+    for (let i = 0; i < view.length; i++) {
+      const b = view[i]
+
+      if (b === 0) {
+        zeroBytes++
+        if (i % 2 === 0) evenZeros++
+        else oddZeros++
+        continue
+      }
+
+      // treat common whitespace as printable
+      if (b === 9 || b === 10 || b === 13) continue
+
+      // basic ASCII printable range
+      if (b >= 32 && b <= 126) continue
+
+      // bytes >= 128 are likely part of UTF-8 sequences; count as printable
+      if (b >= 128) continue
+
+      nonPrintable++
+    }
+
+    // If there are lots of null bytes, it's probably binary unless it looks like UTF-16 text.
+    if (zeroBytes > 0) {
+      const evenSlots = Math.ceil(view.length / 2)
+      const oddSlots = Math.floor(view.length / 2)
+      const evenZeroRatio = evenSlots > 0 ? evenZeros / evenSlots : 0
+      const oddZeroRatio = oddSlots > 0 ? oddZeros / oddSlots : 0
+
+      // UTF-16LE/BE tends to have zeros on every other byte.
+      if (evenZeroRatio > 0.7 || oddZeroRatio > 0.7) return false
+
+      if (zeroBytes / view.length > 0.05) return true
+    }
+
+    // Heuristic: too many non-printable bytes => binary.
+    return nonPrintable / view.length > 0.3
+  } catch {
+    return false
+  }
+}
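End to end, the fuzzy-replace entry point tries each replacer in order and only commits when the located match is unambiguous. A usage sketch with made-up inputs (the import path is hypothetical):

import { replaceWithFuzzyMatch } from './fuzzyReplace' // hypothetical module path

const content = ['if (ok) {', '    doWork()', '}'].join('\n')

// The caller's old_string uses 2-space indentation while the file uses 4 spaces,
// so an exact indexOf misses, but one of the looser replacers (trimmed-line
// comparison) still locates the block and the edit succeeds.
const oldString = ['if (ok) {', '  doWork()', '}'].join('\n')
const newString = ['if (ok) {', '  doWorkSafely()', '}'].join('\n')

replaceWithFuzzyMatch(content, oldString, newString)
// => 'if (ok) {\n  doWorkSafely()\n}'
// If the block occurred twice, the call would throw the multiple-matches error
// instead of guessing, unless replaceAll were passed.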
+
+// ============================================================================
+// Ripgrep Utilities
+// ============================================================================
+
+export interface RipgrepResult {
+  ok: boolean
+  stdout: string
+  exitCode: number | null
+}
+
+export function getRipgrepAddonPath(): string {
+  const pkgJsonPath = require.resolve('@anthropic-ai/claude-agent-sdk/package.json')
+  const pkgRoot = path.dirname(pkgJsonPath)
+  const platform = isMac ? 'darwin' : isWin ? 'win32' : 'linux'
+  const arch = process.arch === 'arm64' ? 'arm64' : 'x64'
+  return path.join(pkgRoot, 'vendor', 'ripgrep', `${arch}-${platform}`, 'ripgrep.node')
+}
+
+export async function runRipgrep(args: string[]): Promise<RipgrepResult> {
+  const addonPath = getRipgrepAddonPath()
+  const childScript = `const { ripgrepMain } = require(process.env.RIPGREP_ADDON_PATH); process.exit(ripgrepMain(process.argv.slice(1)));`
+
+  return new Promise((resolve) => {
+    const child = spawn(process.execPath, ['--eval', childScript, 'rg', ...args], {
+      cwd: process.cwd(),
+      env: {
+        ...process.env,
+        ELECTRON_RUN_AS_NODE: '1',
+        RIPGREP_ADDON_PATH: addonPath
+      },
+      stdio: ['ignore', 'pipe', 'pipe']
+    })
+
+    let stdout = ''
+
+    child.stdout?.on('data', (chunk) => {
+      stdout += chunk.toString('utf-8')
+    })
+
+    child.on('error', () => {
+      resolve({ ok: false, stdout: '', exitCode: null })
+    })
+
+    child.on('close', (code) => {
+      resolve({ ok: true, stdout, exitCode: code })
+    })
+  })
+}
diff --git a/src/main/services/ConfigManager.ts b/src/main/services/ConfigManager.ts
index c693d4b05a..6f2bbd44a4 100644
--- a/src/main/services/ConfigManager.ts
+++ b/src/main/services/ConfigManager.ts
@@ -32,7 +32,8 @@ export enum ConfigKeys {
   Proxy = 'proxy',
   EnableDeveloperMode = 'enableDeveloperMode',
   ClientId = 'clientId',
-  GitBashPath = 'gitBashPath'
+  GitBashPath = 'gitBashPath',
+  GitBashPathSource = 'gitBashPathSource' // 'manual' | 'auto' | null
 }
 
 export class ConfigManager {
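For callers the helper behaves like invoking rg on the command line, with ripgrep's usual exit-code convention (1 means no matches, not failure). A small consumption sketch (the arguments are illustrative; --json is ripgrep's standard structured output):

// Hypothetical caller of the runRipgrep helper above.
async function countTodos(): Promise<number> {
  const result = await runRipgrep(['--json', 'TODO', 'src'])

  if (!result.ok) return 0 // spawning the child process failed
  if (result.exitCode === 1) return 0 // ripgrep found nothing

  // `rg --json` emits one JSON object per line; count only match records.
  return result.stdout
    .split('\n')
    .filter(Boolean)
    .map((line) => JSON.parse(line))
    .filter((event) => event.type === 'match').length
}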
diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts
index cc6bbaa366..ebdc2247fc 100644
--- a/src/main/services/MCPService.ts
+++ b/src/main/services/MCPService.ts
@@ -249,6 +249,26 @@ class McpService {
     StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
   > => {
     // Create appropriate transport based on configuration
+
+    // Special case for nowledgeMem - uses HTTP transport instead of in-memory
+    if (isBuiltinMCPServer(server) && server.name === BuiltinMCPServerNames.nowledgeMem) {
+      const nowledgeMemUrl = 'http://127.0.0.1:14242/mcp'
+      const options: StreamableHTTPClientTransportOptions = {
+        fetch: async (url, init) => {
+          return net.fetch(typeof url === 'string' ? url : url.toString(), init)
+        },
+        requestInit: {
+          headers: {
+            ...defaultAppHeaders(),
+            APP: 'Cherry Studio'
+          }
+        },
+        authProvider
+      }
+      getServerLogger(server).debug(`Using StreamableHTTPClientTransport for ${server.name}`)
+      return new StreamableHTTPClientTransport(new URL(nowledgeMemUrl), options)
+    }
+
     if (isBuiltinMCPServer(server) && server.name !== BuiltinMCPServerNames.mcpAutoInstall) {
       getServerLogger(server).debug(`Using in-memory transport`)
       const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts
index 689c177ff5..50dd5a6d3d 100644
--- a/src/main/services/agents/services/claudecode/index.ts
+++ b/src/main/services/agents/services/claudecode/index.ts
@@ -15,8 +15,8 @@ import { query } from '@anthropic-ai/claude-agent-sdk'
 import { loggerService } from '@logger'
 import { config as apiConfigService } from '@main/apiServer/config'
 import { validateModelId } from '@main/apiServer/utils'
-import { ConfigKeys, configManager } from '@main/services/ConfigManager'
-import { validateGitBashPath } from '@main/utils/process'
+import { isWin } from '@main/constant'
+import { autoDiscoverGitBash } from '@main/utils/process'
 import getLoginShellEnvironment from '@main/utils/shell-env'
 import { app } from 'electron'
 
@@ -105,7 +105,8 @@ class ClaudeCodeService implements AgentServiceInterface {
       Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
     ) as Record<string, string>
 
-    const customGitBashPath = validateGitBashPath(configManager.get(ConfigKeys.GitBashPath) as string | undefined)
+    // Auto-discover Git Bash path on Windows (already logs internally)
+    const customGitBashPath = isWin ? autoDiscoverGitBash() : null
 
     // Route through local API Server which handles format conversion via unified adapter
     // This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.)
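Both hunks above move resolution to call time: the MCP transport is chosen per server, and Git Bash is discovered only when the agent actually starts. A sketch of how the resolved pieces plausibly feed the child environment for the SDK session (the ANTHROPIC_* variable names are the standard Anthropic client overrides, and CLAUDE_CODE_GIT_BASH_PATH is the override read back by autoDiscoverGitBash later in this patch; the exact wiring in this file may differ):

// Sketch only: assembling the agent child-process environment.
function buildAgentEnv(
  loginShellEnv: Record<string, string>,
  apiHost: string,
  apiKey: string
): Record<string, string> {
  const env: Record<string, string> = {
    ...loginShellEnv,
    ANTHROPIC_BASE_URL: apiHost, // the local API server doing format conversion
    ANTHROPIC_API_KEY: apiKey
  }
  const gitBash = isWin ? autoDiscoverGitBash() : null
  if (gitBash) {
    env.CLAUDE_CODE_GIT_BASH_PATH = gitBash // Claude Code's POSIX shell on Windows
  }
  return env
}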
diff --git a/src/main/utils/__tests__/process.test.ts b/src/main/utils/__tests__/process.test.ts index 0485ec5fad..a1ac2fd9a5 100644 --- a/src/main/utils/__tests__/process.test.ts +++ b/src/main/utils/__tests__/process.test.ts @@ -1,9 +1,21 @@ +import { configManager } from '@main/services/ConfigManager' import { execFileSync } from 'child_process' import fs from 'fs' import path from 'path' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { findExecutable, findGitBash, validateGitBashPath } from '../process' +import { autoDiscoverGitBash, findExecutable, findGitBash, validateGitBashPath } from '../process' + +// Mock configManager +vi.mock('@main/services/ConfigManager', () => ({ + ConfigKeys: { + GitBashPath: 'gitBashPath' + }, + configManager: { + get: vi.fn(), + set: vi.fn() + } +})) // Mock dependencies vi.mock('child_process') @@ -695,4 +707,284 @@ describe.skipIf(process.platform !== 'win32')('process utilities', () => { }) }) }) + + describe('autoDiscoverGitBash', () => { + const originalEnvVar = process.env.CLAUDE_CODE_GIT_BASH_PATH + + beforeEach(() => { + vi.mocked(configManager.get).mockReset() + vi.mocked(configManager.set).mockReset() + delete process.env.CLAUDE_CODE_GIT_BASH_PATH + }) + + afterEach(() => { + // Restore original environment variable + if (originalEnvVar !== undefined) { + process.env.CLAUDE_CODE_GIT_BASH_PATH = originalEnvVar + } else { + delete process.env.CLAUDE_CODE_GIT_BASH_PATH + } + }) + + /** + * Helper to mock fs.existsSync with a set of valid paths + */ + const mockExistingPaths = (...validPaths: string[]) => { + vi.mocked(fs.existsSync).mockImplementation((p) => validPaths.includes(p as string)) + } + + describe('with no existing config path', () => { + it('should discover and persist Git Bash path when not configured', () => { + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(bashPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should return null and not persist when Git Bash is not found', () => { + vi.mocked(configManager.get).mockReturnValue(undefined) + vi.mocked(fs.existsSync).mockReturnValue(false) + vi.mocked(execFileSync).mockImplementation(() => { + throw new Error('Not found') + }) + + const result = autoDiscoverGitBash() + + expect(result).toBeNull() + expect(configManager.set).not.toHaveBeenCalled() + }) + }) + + describe('environment variable precedence', () => { + it('should use env var over valid config path', () => { + const envPath = 'C:\\EnvGit\\bin\\bash.exe' + const configPath = 'C:\\ConfigGit\\bin\\bash.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + vi.mocked(configManager.get).mockReturnValue(configPath) + mockExistingPaths(envPath, configPath) + + const result = autoDiscoverGitBash() + + // Env var should take precedence + expect(result).toBe(envPath) + // Should not persist env var path (it's a runtime override) + expect(configManager.set).not.toHaveBeenCalled() + }) + + it('should fall back to config path when env var is invalid', () => { + const envPath = 'C:\\Invalid\\bash.exe' + const configPath = 'C:\\ConfigGit\\bin\\bash.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + vi.mocked(configManager.get).mockReturnValue(configPath) + // Env path is invalid (doesn't 
exist), only config path exists + mockExistingPaths(configPath) + + const result = autoDiscoverGitBash() + + // Should fall back to config path + expect(result).toBe(configPath) + expect(configManager.set).not.toHaveBeenCalled() + }) + + it('should fall back to auto-discovery when both env var and config are invalid', () => { + const envPath = 'C:\\InvalidEnv\\bash.exe' + const configPath = 'C:\\InvalidConfig\\bash.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + process.env.CLAUDE_CODE_GIT_BASH_PATH = envPath + process.env.ProgramFiles = 'C:\\Program Files' + vi.mocked(configManager.get).mockReturnValue(configPath) + // Both env and config paths are invalid, only standard Git exists + mockExistingPaths(gitPath, discoveredPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(discoveredPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath) + }) + }) + + describe('with valid existing config path', () => { + it('should validate and return existing path without re-discovering', () => { + const existingPath = 'C:\\CustomGit\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + mockExistingPaths(existingPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(existingPath) + // Should not call findGitBash or persist again + expect(configManager.set).not.toHaveBeenCalled() + // Should not call execFileSync (which findGitBash would use for discovery) + expect(execFileSync).not.toHaveBeenCalled() + }) + + it('should not override existing valid config with auto-discovery', () => { + const existingPath = 'C:\\CustomGit\\bin\\bash.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + mockExistingPaths(existingPath, discoveredPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(existingPath) + expect(configManager.set).not.toHaveBeenCalled() + }) + }) + + describe('with invalid existing config path', () => { + it('should attempt auto-discovery when existing path does not exist', () => { + const existingPath = 'C:\\NonExistent\\bin\\bash.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + process.env.ProgramFiles = 'C:\\Program Files' + // Invalid path doesn't exist, but Git is installed at standard location + mockExistingPaths(gitPath, discoveredPath) + + const result = autoDiscoverGitBash() + + // Should discover and return the new path + expect(result).toBe(discoveredPath) + // Should persist the discovered path (overwrites invalid) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath) + }) + + it('should attempt auto-discovery when existing path is not bash.exe', () => { + const existingPath = 'C:\\CustomGit\\bin\\git.exe' + const discoveredPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + process.env.ProgramFiles = 'C:\\Program Files' + // Invalid path exists but is not bash.exe (validation will fail) + // Git is installed at standard location + mockExistingPaths(existingPath, gitPath, discoveredPath) + + const result = autoDiscoverGitBash() + + // Should discover and return the new path + expect(result).toBe(discoveredPath) + // Should persist the 
discovered path (overwrites invalid) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', discoveredPath) + }) + + it('should return null when existing path is invalid and discovery fails', () => { + const existingPath = 'C:\\NonExistent\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(existingPath) + vi.mocked(fs.existsSync).mockReturnValue(false) + vi.mocked(execFileSync).mockImplementation(() => { + throw new Error('Not found') + }) + + const result = autoDiscoverGitBash() + + // Both validation and discovery failed + expect(result).toBeNull() + // Should not persist when discovery fails + expect(configManager.set).not.toHaveBeenCalled() + }) + }) + + describe('config persistence verification', () => { + it('should persist discovered path with correct config key', () => { + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + autoDiscoverGitBash() + + // Verify the exact call to configManager.set + expect(configManager.set).toHaveBeenCalledTimes(1) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should persist on each discovery when config remains undefined', () => { + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + autoDiscoverGitBash() + autoDiscoverGitBash() + + // Each call discovers and persists since config remains undefined (mocked) + expect(configManager.set).toHaveBeenCalledTimes(2) + }) + }) + + describe('real-world scenarios', () => { + it('should discover and persist standard Git for Windows installation', () => { + const gitPath = 'C:\\Program Files\\Git\\cmd\\git.exe' + const bashPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + process.env.ProgramFiles = 'C:\\Program Files' + mockExistingPaths(gitPath, bashPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(bashPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should discover portable Git via where.exe and persist', () => { + const gitPath = 'D:\\PortableApps\\Git\\bin\\git.exe' + const bashPath = 'D:\\PortableApps\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(undefined) + + vi.mocked(fs.existsSync).mockImplementation((p) => { + const pathStr = p?.toString() || '' + // Common git paths don't exist + if (pathStr.includes('Program Files\\Git\\cmd\\git.exe')) return false + if (pathStr.includes('Program Files (x86)\\Git\\cmd\\git.exe')) return false + // Portable bash path exists + if (pathStr === bashPath) return true + return false + }) + + vi.mocked(execFileSync).mockReturnValue(gitPath) + + const result = autoDiscoverGitBash() + + expect(result).toBe(bashPath) + expect(configManager.set).toHaveBeenCalledWith('gitBashPath', bashPath) + }) + + it('should respect user-configured path over auto-discovery', () => { + const userConfiguredPath = 'D:\\MyGit\\bin\\bash.exe' + const systemPath = 'C:\\Program Files\\Git\\bin\\bash.exe' + + vi.mocked(configManager.get).mockReturnValue(userConfiguredPath) + mockExistingPaths(userConfiguredPath, systemPath) + + const result = autoDiscoverGitBash() + + 
expect(result).toBe(userConfiguredPath) + expect(configManager.set).not.toHaveBeenCalled() + // Verify findGitBash was not called for discovery + expect(execFileSync).not.toHaveBeenCalled() + }) + }) + }) }) diff --git a/src/main/utils/process.ts b/src/main/utils/process.ts index 7175af7e75..ccc0f66535 100644 --- a/src/main/utils/process.ts +++ b/src/main/utils/process.ts @@ -1,4 +1,5 @@ import { loggerService } from '@logger' +import type { GitBashPathInfo, GitBashPathSource } from '@shared/config/constant' import { HOME_CHERRY_DIR } from '@shared/config/constant' import { execFileSync, spawn } from 'child_process' import fs from 'fs' @@ -6,6 +7,7 @@ import os from 'os' import path from 'path' import { isWin } from '../constant' +import { ConfigKeys, configManager } from '../services/ConfigManager' import { getResourcePath } from '.' const logger = loggerService.withContext('Utils:Process') @@ -59,7 +61,7 @@ export async function getBinaryPath(name?: string): Promise { export async function isBinaryExists(name: string): Promise { const cmd = await getBinaryPath(name) - return await fs.existsSync(cmd) + return fs.existsSync(cmd) } /** @@ -225,3 +227,77 @@ export function validateGitBashPath(customPath?: string | null): string | null { logger.debug('Validated custom Git Bash path', { path: resolved }) return resolved } + +/** + * Auto-discover and persist Git Bash path if not already configured + * Only called when Git Bash is actually needed + * + * Precedence order: + * 1. CLAUDE_CODE_GIT_BASH_PATH environment variable (highest - runtime override) + * 2. Configured path from settings (manual or auto) + * 3. Auto-discovery via findGitBash (only if no valid config exists) + */ +export function autoDiscoverGitBash(): string | null { + if (!isWin) { + return null + } + + // 1. Check environment variable override first (highest priority) + const envOverride = process.env.CLAUDE_CODE_GIT_BASH_PATH + if (envOverride) { + const validated = validateGitBashPath(envOverride) + if (validated) { + logger.debug('Using CLAUDE_CODE_GIT_BASH_PATH override', { path: validated }) + return validated + } + logger.warn('CLAUDE_CODE_GIT_BASH_PATH provided but path is invalid', { path: envOverride }) + } + + // 2. Check if a path is already configured + const existingPath = configManager.get(ConfigKeys.GitBashPath) + const existingSource = configManager.get(ConfigKeys.GitBashPathSource) + + if (existingPath) { + const validated = validateGitBashPath(existingPath) + if (validated) { + return validated + } + // Existing path is invalid, try to auto-discover + logger.warn('Existing Git Bash path is invalid, attempting auto-discovery', { + path: existingPath, + source: existingSource + }) + } + + // 3. Try to find Git Bash via auto-discovery + const discoveredPath = findGitBash() + if (discoveredPath) { + // Persist the discovered path with 'auto' source + configManager.set(ConfigKeys.GitBashPath, discoveredPath) + configManager.set(ConfigKeys.GitBashPathSource, 'auto') + logger.info('Auto-discovered Git Bash path', { path: discoveredPath }) + } + + return discoveredPath +} + +/** + * Get Git Bash path info including source + * If no path is configured, triggers auto-discovery first + */ +export function getGitBashPathInfo(): GitBashPathInfo { + if (!isWin) { + return { path: null, source: null } + } + + let path = configManager.get(ConfigKeys.GitBashPath) ?? null + let source = configManager.get(ConfigKeys.GitBashPathSource) ?? 
null
+
+  // If no path configured, trigger auto-discovery (handles upgrade from old versions)
+  if (!path) {
+    path = autoDiscoverGitBash()
+    source = path ? 'auto' : null
+  }
+
+  return { path, source }
+}
diff --git a/src/preload/index.ts b/src/preload/index.ts
index 117bec3b91..dc08e9a2df 100644
--- a/src/preload/index.ts
+++ b/src/preload/index.ts
@@ -2,7 +2,7 @@ import type { PermissionUpdate } from '@anthropic-ai/claude-agent-sdk'
 import { electronAPI } from '@electron-toolkit/preload'
 import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
 import type { SpanContext } from '@opentelemetry/api'
-import type { TerminalConfig, UpgradeChannel } from '@shared/config/constant'
+import type { GitBashPathInfo, TerminalConfig, UpgradeChannel } from '@shared/config/constant'
 import type { LogLevel, LogSourceWithContext } from '@shared/config/logger'
 import type { FileChangeEvent, WebviewKeyEvent } from '@shared/config/types'
 import type { MCPServerLogEntry } from '@shared/config/types'
@@ -126,6 +126,7 @@ const api = {
     getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName),
     checkGitBash: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.System_CheckGitBash),
     getGitBashPath: (): Promise<string | null> => ipcRenderer.invoke(IpcChannel.System_GetGitBashPath),
+    getGitBashPathInfo: (): Promise<GitBashPathInfo> => ipcRenderer.invoke(IpcChannel.System_GetGitBashPathInfo),
     setGitBashPath: (newPath: string | null): Promise<void> =>
       ipcRenderer.invoke(IpcChannel.System_SetGitBashPath, newPath)
   },
diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts
index d839da8964..73a5bed4fe 100644
--- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIApiClient.ts
@@ -142,6 +142,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
       return { thinking: { type: reasoningEffort ?
'enabled' : 'disabled' } } } + if (reasoningEffort === 'default') { + return {} + } + if (!reasoningEffort) { // DeepSeek hybrid inference models, v3.1 and maybe more in the future // 不同的 provider 有不同的思考控制方式,在这里统一解决 @@ -303,7 +307,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient< // Grok models/Perplexity models/OpenAI models if (isSupportedReasoningEffortModel(model)) { // 检查模型是否支持所选选项 - const supportedOptions = getModelSupportedReasoningEffortOptions(model) + const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default') if (supportedOptions?.includes(reasoningEffort)) { return { reasoning_effort: reasoningEffort diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts index 6f1ec709b8..ae83df4f3f 100644 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -8,14 +8,12 @@ import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/uti import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware' import type { LanguageModelMiddleware } from 'ai' import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' -import { isEmpty } from 'lodash' import { getAiSdkProviderId } from '../provider/factory' import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' import { noThinkMiddleware } from './noThinkMiddleware' import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware' import { qwenThinkingMiddleware } from './qwenThinkingMiddleware' -import { toolChoiceMiddleware } from './toolChoiceMiddleware' const logger = loggerService.withContext('AiSdkMiddlewareBuilder') @@ -135,15 +133,6 @@ export class AiSdkMiddlewareBuilder { export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] { const builder = new AiSdkMiddlewareBuilder() - // 0. 知识库强制调用中间件(必须在最前面,确保第一轮强制调用知识库) - if (!isEmpty(config.assistant?.knowledge_bases?.map((base) => base.id)) && config.knowledgeRecognition !== 'on') { - builder.add({ - name: 'force-knowledge-first', - middleware: toolChoiceMiddleware('builtin_knowledge_search') - }) - logger.debug('Added toolChoice middleware to force knowledge base search on first round') - } - // 1. 
根据provider添加特定中间件 if (config.provider) { addProviderSpecificMiddlewares(builder, config) diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts index 6be577f194..5b095a4461 100644 --- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -31,7 +31,7 @@ import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool' const logger = loggerService.withContext('SearchOrchestrationPlugin') -const getMessageContent = (message: ModelMessage) => { +export const getMessageContent = (message: ModelMessage) => { if (typeof message.content === 'string') return message.content return message.content.reduce((acc, part) => { if (part.type === 'text') { @@ -266,14 +266,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) // 判断是否需要各种搜索 const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) - const knowledgeRecognition = assistant.knowledgeRecognition || 'on' + const knowledgeRecognition = assistant.knowledgeRecognition || 'off' const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) const shouldWebSearch = !!assistant.webSearchProviderId const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on' const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory // 执行意图分析 - if (shouldWebSearch || hasKnowledgeBase) { + if (shouldWebSearch || shouldKnowledgeSearch) { const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, { shouldWebSearch, shouldKnowledgeSearch, @@ -330,41 +330,25 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) // 📚 知识库搜索工具配置 const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) const hasKnowledgeBase = !isEmpty(knowledgeBaseIds) - const knowledgeRecognition = assistant.knowledgeRecognition || 'on' + const knowledgeRecognition = assistant.knowledgeRecognition || 'off' + const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on' - if (hasKnowledgeBase) { - if (knowledgeRecognition === 'off') { - // off 模式:直接添加知识库搜索工具,使用用户消息作为搜索关键词 + if (shouldKnowledgeSearch) { + // on 模式:根据意图识别结果决定是否添加工具 + const needsKnowledgeSearch = + analysisResult?.knowledge && + analysisResult.knowledge.question && + analysisResult.knowledge.question[0] !== 'not_needed' + + if (needsKnowledgeSearch && analysisResult.knowledge) { + // logger.info('📚 Adding knowledge search tool (intent-based)') const userMessage = userMessages[context.requestId] - const fallbackKeywords = { - question: [getMessageContent(userMessage) || 'search'], - rewrite: getMessageContent(userMessage) || 'search' - } - // logger.info('📚 Adding knowledge search tool (force mode)') params.tools['builtin_knowledge_search'] = knowledgeSearchTool( assistant, - fallbackKeywords, + analysisResult.knowledge, getMessageContent(userMessage), topicId ) - // params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' } - } else { - // on 模式:根据意图识别结果决定是否添加工具 - const needsKnowledgeSearch = - analysisResult?.knowledge && - analysisResult.knowledge.question && - analysisResult.knowledge.question[0] !== 'not_needed' - - if (needsKnowledgeSearch && analysisResult.knowledge) { - // logger.info('📚 Adding knowledge search tool (intent-based)') - const userMessage = userMessages[context.requestId] - params.tools['builtin_knowledge_search'] = 
knowledgeSearchTool( - assistant, - analysisResult.knowledge, - getMessageContent(userMessage), - topicId - ) - } } } diff --git a/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts b/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts index 70b4ac84b7..a4f345e3e5 100644 --- a/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts +++ b/src/renderer/src/aiCore/prepareParams/__tests__/model-parameters.test.ts @@ -18,7 +18,7 @@ vi.mock('@renderer/services/AssistantService', () => ({ toolUseMode: assistant.settings?.toolUseMode ?? 'prompt', defaultModel: assistant.defaultModel, customParameters: assistant.settings?.customParameters ?? [], - reasoning_effort: assistant.settings?.reasoning_effort, + reasoning_effort: assistant.settings?.reasoning_effort ?? 'default', reasoning_effort_cache: assistant.settings?.reasoning_effort_cache, qwenThinkMode: assistant.settings?.qwenThinkMode }) diff --git a/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts index fec4d197e3..df7d69d0c2 100644 --- a/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts @@ -11,6 +11,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { getAnthropicReasoningParams, + getAnthropicThinkingBudget, getBedrockReasoningParams, getCustomParameters, getGeminiReasoningParams, @@ -89,7 +90,8 @@ vi.mock('@renderer/config/models', async (importOriginal) => { isQwenAlwaysThinkModel: vi.fn(() => false), isSupportedThinkingTokenHunyuanModel: vi.fn(() => false), isSupportedThinkingTokenModel: vi.fn(() => false), - isGPT51SeriesModel: vi.fn(() => false) + isGPT51SeriesModel: vi.fn(() => false), + findTokenLimit: vi.fn(actual.findTokenLimit) } }) @@ -596,7 +598,7 @@ describe('reasoning utils', () => { expect(result).toEqual({}) }) - it('should return disabled thinking when no reasoning effort', async () => { + it('should return disabled thinking when reasoning effort is none', async () => { const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models') vi.mocked(isReasoningModel).mockReturnValue(true) @@ -611,7 +613,9 @@ describe('reasoning utils', () => { const assistant: Assistant = { id: 'test', name: 'Test', - settings: {} + settings: { + reasoning_effort: 'none' + } } as Assistant const result = getAnthropicReasoningParams(assistant, model) @@ -647,7 +651,7 @@ describe('reasoning utils', () => { expect(result).toEqual({ thinking: { type: 'enabled', - budgetTokens: 2048 + budgetTokens: 4096 } }) }) @@ -675,7 +679,7 @@ describe('reasoning utils', () => { expect(result).toEqual({}) }) - it('should disable thinking for Flash models without reasoning effort', async () => { + it('should disable thinking for Flash models when reasoning effort is none', async () => { const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models') vi.mocked(isReasoningModel).mockReturnValue(true) @@ -690,7 +694,9 @@ describe('reasoning utils', () => { const assistant: Assistant = { id: 'test', name: 'Test', - settings: {} + settings: { + reasoning_effort: 'none' + } } as Assistant const result = getGeminiReasoningParams(assistant, model) @@ -725,7 +731,7 @@ describe('reasoning utils', () => { const result = getGeminiReasoningParams(assistant, model) expect(result).toEqual({ thinkingConfig: { - thinkingBudget: 16448, + thinkingBudget: expect.any(Number), 
includeThoughts: true } }) @@ -889,7 +895,7 @@ describe('reasoning utils', () => { expect(result).toEqual({ reasoningConfig: { type: 'enabled', - budgetTokens: 2048 + budgetTokens: 4096 } }) }) @@ -990,4 +996,89 @@ describe('reasoning utils', () => { }) }) }) + + describe('getAnthropicThinkingBudget', () => { + it('should return undefined when reasoningEffort is undefined', async () => { + const result = getAnthropicThinkingBudget(4096, undefined, 'claude-3-7-sonnet') + expect(result).toBeUndefined() + }) + + it('should return undefined when reasoningEffort is none', async () => { + const result = getAnthropicThinkingBudget(4096, 'none', 'claude-3-7-sonnet') + expect(result).toBeUndefined() + }) + + it('should return undefined when tokenLimit is not found', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue(undefined) + + const result = getAnthropicThinkingBudget(4096, 'medium', 'unknown-model') + expect(result).toBeUndefined() + }) + + it('should calculate budget correctly when maxTokens is provided', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(4096, 'medium', 'claude-3-7-sonnet') + // EFFORT_RATIO['medium'] = 0.5 + // budget = Math.floor((32768 - 1024) * 0.5 + 1024) + // = Math.floor(31744 * 0.5 + 1024) = Math.floor(15872 + 1024) = 16896 + // budgetTokens = Math.min(16896, 4096) = 4096 + // result = Math.max(1024, 4096) = 4096 + expect(result).toBe(4096) + }) + + it('should use tokenLimit.max when maxTokens is undefined', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(undefined, 'medium', 'claude-3-7-sonnet') + // When maxTokens is undefined, budget is not constrained by maxTokens + // EFFORT_RATIO['medium'] = 0.5 + // budget = Math.floor((32768 - 1024) * 0.5 + 1024) + // = Math.floor(31744 * 0.5 + 1024) = Math.floor(15872 + 1024) = 16896 + // result = Math.max(1024, 16896) = 16896 + expect(result).toBe(16896) + }) + + it('should enforce minimum budget of 1024', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 100, max: 1000 }) + + const result = getAnthropicThinkingBudget(500, 'low', 'claude-3-7-sonnet') + // EFFORT_RATIO['low'] = 0.05 + // budget = Math.floor((1000 - 100) * 0.05 + 100) + // = Math.floor(900 * 0.05 + 100) = Math.floor(45 + 100) = 145 + // budgetTokens = Math.min(145, 500) = 145 + // result = Math.max(1024, 145) = 1024 + expect(result).toBe(1024) + }) + + it('should respect effort ratio for high reasoning effort', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(8192, 'high', 'claude-3-7-sonnet') + // EFFORT_RATIO['high'] = 0.8 + // budget = Math.floor((32768 - 1024) * 0.8 + 1024) + // = Math.floor(31744 * 0.8 + 1024) = Math.floor(25395.2 + 1024) = 26419 + // budgetTokens = Math.min(26419, 8192) = 8192 + // result = Math.max(1024, 8192) = 8192 + expect(result).toBe(8192) + }) + + it('should use full token limit when maxTokens is undefined and reasoning effort is high', async () => { + const { findTokenLimit } = await import('@renderer/config/models') + 
vi.mocked(findTokenLimit).mockReturnValue({ min: 1024, max: 32768 }) + + const result = getAnthropicThinkingBudget(undefined, 'high', 'claude-3-7-sonnet') + // When maxTokens is undefined, budget is not constrained by maxTokens + // EFFORT_RATIO['high'] = 0.8 + // budget = Math.floor((32768 - 1024) * 0.8 + 1024) + // = Math.floor(31744 * 0.8 + 1024) = Math.floor(25395.2 + 1024) = 26419 + // result = Math.max(1024, 26419) = 26419 + expect(result).toBe(26419) + }) + }) }) diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts index f182405714..a2364d97e1 100644 --- a/src/renderer/src/aiCore/utils/reasoning.ts +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -10,6 +10,7 @@ import { GEMINI_FLASH_MODEL_REGEX, getModelSupportedReasoningEffortOptions, isDeepSeekHybridInferenceModel, + isDoubaoSeed18Model, isDoubaoSeedAfter251015, isDoubaoThinkingAutoModel, isGemini3ThinkingTokenModel, @@ -28,6 +29,7 @@ import { isSupportedThinkingTokenDoubaoModel, isSupportedThinkingTokenGeminiModel, isSupportedThinkingTokenHunyuanModel, + isSupportedThinkingTokenMiMoModel, isSupportedThinkingTokenModel, isSupportedThinkingTokenQwenModel, isSupportedThinkingTokenZhipuModel @@ -64,7 +66,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // reasoningEffort is not set, no extra reasoning setting // Generally, for every model which supports reasoning control, the reasoning effort won't be undefined. // It's for some reasoning models that don't support reasoning control, such as deepseek reasoner. - if (!reasoningEffort) { + if (!reasoningEffort || reasoningEffort === 'default') { return {} } @@ -329,7 +331,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // Grok models/Perplexity models/OpenAI models, use reasoning_effort if (isSupportedReasoningEffortModel(model)) { // 检查模型是否支持所选选项 - const supportedOptions = getModelSupportedReasoningEffortOptions(model) + const supportedOptions = getModelSupportedReasoningEffortOptions(model)?.filter((option) => option !== 'default') if (supportedOptions?.includes(reasoningEffort)) { return { reasoningEffort @@ -389,7 +391,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin // Use thinking, doubao, zhipu, etc. 
 if (isSupportedThinkingTokenDoubaoModel(model)) {
-    if (isDoubaoSeedAfter251015(model)) {
+    if (isDoubaoSeedAfter251015(model) || isDoubaoSeed18Model(model)) {
       return { reasoningEffort }
     }
     if (reasoningEffort === 'high') {
@@ -408,6 +410,12 @@
     return { thinking: { type: 'enabled' } }
   }
 
+  if (isSupportedThinkingTokenMiMoModel(model)) {
+    return {
+      thinking: { type: 'enabled' }
+    }
+  }
+
   // Default case: no special thinking settings
   return {}
 }
@@ -427,7 +435,7 @@
 
   let reasoningEffort = assistant?.settings?.reasoning_effort
 
-  if (!reasoningEffort) {
+  if (!reasoningEffort || reasoningEffort === 'default') {
     return {}
   }
 
@@ -479,16 +487,14 @@
     return undefined
   }
 
-  const budgetTokens = Math.max(
-    1024,
-    Math.floor(
-      Math.min(
-        (tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min,
-        (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
-      )
-    )
-  )
-  return budgetTokens
+  const budget = Math.floor((tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min)
+
+  let budgetTokens = budget
+  if (maxTokens !== undefined) {
+    budgetTokens = Math.min(budget, maxTokens)
+  }
+
+  return Math.max(1024, budgetTokens)
 }
 
 /**
@@ -505,7 +511,11 @@
 
   const reasoningEffort = assistant?.settings?.reasoning_effort
 
-  if (reasoningEffort === undefined || reasoningEffort === 'none') {
+  if (!reasoningEffort || reasoningEffort === 'default') {
+    return {}
+  }
+
+  if (reasoningEffort === 'none') {
     return {
       thinking: {
         type: 'disabled'
@@ -560,6 +570,10 @@ export function getGeminiReasoningParams(
 
   const reasoningEffort = assistant?.settings?.reasoning_effort
 
+  if (!reasoningEffort || reasoningEffort === 'default') {
+    return {}
+  }
+
   // Gemini reasoning parameters
   if (isSupportedThinkingTokenGeminiModel(model)) {
     if (reasoningEffort === undefined || reasoningEffort === 'none') {
@@ -620,10 +634,6 @@ export function getXAIReasoningParams(assistant: Assistant, model: Model): Pick<
 
   const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
 
-  if (!reasoningEffort || reasoningEffort === 'none') {
-    return {}
-  }
-
   switch (reasoningEffort) {
     case 'auto':
     case 'minimal':
@@ -634,6 +644,10 @@
       return { reasoningEffort }
     case 'xhigh':
       return { reasoningEffort: 'high' }
+    case 'default':
+    case 'none':
+    default:
+      return {}
   }
 }
 
@@ -650,7 +664,7 @@ export function getBedrockReasoningParams(
 
   const reasoningEffort = assistant?.settings?.reasoning_effort
 
-  if (reasoningEffort === undefined) {
+  if (reasoningEffort === undefined || reasoningEffort === 'default') {
     return {}
   }
diff --git a/src/renderer/src/assets/images/models/mimo.svg b/src/renderer/src/assets/images/models/mimo.svg
new file mode 100644
index 0000000000..82370fece3
--- /dev/null
+++ b/src/renderer/src/assets/images/models/mimo.svg
@@ -0,0 +1,17 @@
+[17 lines of SVG markup lost in extraction]
diff --git a/src/renderer/src/assets/images/providers/mimo.svg b/src/renderer/src/assets/images/providers/mimo.svg
new file mode 100644
index 0000000000..82370fece3
--- /dev/null
+++ b/src/renderer/src/assets/images/providers/mimo.svg
@@ -0,0 +1,17 @@
+[17 lines of SVG markup lost in extraction]
diff --git a/src/renderer/src/components/Icons/SVGIcon.tsx b/src/renderer/src/components/Icons/SVGIcon.tsx
index ad503f0e38..82be6b340e 100644
--- a/src/renderer/src/components/Icons/SVGIcon.tsx
+++ b/src/renderer/src/components/Icons/SVGIcon.tsx
@@ -113,6 +113,18 @@ export function MdiLightbulbOn(props: SVGProps<SVGSVGElement>) {
   )
 }
 
+export function MdiLightbulbQuestion(props: SVGProps<SVGSVGElement>) {
+  // {/* Icon from Material Design Icons by Pictogrammers - https://github.com/Templarian/MaterialDesign/blob/master/LICENSE */}
+  return (
+    [SVG markup lost in extraction]
+  )
+}
+
 export function BingLogo(props: SVGProps<SVGSVGElement>) {
   return (
[BingLogo's SVG body, the remainder of this hunk, and the diff header plus import hunk of the agent modal component were lost in extraction]
 = ({ agent, afterSubmit, resolve }) => {
   const isEditing = (agent?: AgentWithTools) => agent !== undefined
   const [form, setForm] = useState(() => buildAgentForm(agent))
-  const [hasGitBash, setHasGitBash] = useState(true)
-  const [customGitBashPath, setCustomGitBashPath] = useState('')
+  const [gitBashPathInfo, setGitBashPathInfo] = useState<GitBashPathInfo>({ path: null, source: null })
 
   useEffect(() => {
     if (open) {
@@ -68,29 +69,15 @@
     }
   }, [agent, open])
 
-  const checkGitBash = useCallback(
-    async (showToast = false) => {
-      try {
-        const [gitBashInstalled, savedPath] = await Promise.all([
-          window.api.system.checkGitBash(),
-          window.api.system.getGitBashPath().catch(() => null)
-        ])
-        setCustomGitBashPath(savedPath ?? '')
-        setHasGitBash(gitBashInstalled)
-        if (showToast) {
-          if (gitBashInstalled) {
-            window.toast.success(t('agent.gitBash.success', 'Git Bash detected successfully!'))
-          } else {
-            window.toast.error(t('agent.gitBash.notFound', 'Git Bash not found. Please install it first.'))
-          }
-        }
-      } catch (error) {
-        logger.error('Failed to check Git Bash:', error as Error)
-        setHasGitBash(true) // Default to true on error to avoid false warnings
-      }
-    },
-    [t]
-  )
+  const checkGitBash = useCallback(async () => {
+    if (!isWin) return
+    try {
+      const pathInfo = await window.api.system.getGitBashPathInfo()
+      setGitBashPathInfo(pathInfo)
+    } catch (error) {
+      logger.error('Failed to check Git Bash:', error as Error)
+    }
+  }, [])
 
   useEffect(() => {
     checkGitBash()
@@ -119,24 +106,22 @@
         return
       }
 
-      setCustomGitBashPath(pickedPath)
-      await checkGitBash(true)
+      await checkGitBash()
     } catch (error) {
       logger.error('Failed to pick Git Bash path', error as Error)
       window.toast.error(t('agent.gitBash.pick.failed', 'Failed to set Git Bash path'))
     }
   }, [checkGitBash, t])
 
-  const handleClearGitBash = useCallback(async () => {
+  const handleResetGitBash = useCallback(async () => {
     try {
+      // Clear manual setting and re-run auto-discovery
       await window.api.system.setGitBashPath(null)
-      setCustomGitBashPath('')
-      await checkGitBash(true)
+      await checkGitBash()
     } catch (error) {
-      logger.error('Failed to clear Git Bash path', error as Error)
-      window.toast.error(t('agent.gitBash.pick.failed', 'Failed to set Git Bash path'))
+      logger.error('Failed to reset Git Bash path', error as Error)
     }
-  }, [checkGitBash, t])
+  }, [checkGitBash])
 
   const onPermissionModeChange = useCallback((value: PermissionMode) => {
     setForm((prev) => {
@@ -268,6 +253,12 @@
       return
     }
 
+    if (isWin && !gitBashPathInfo.path) {
+      window.toast.error(t('agent.gitBash.error.required', 'Git Bash path is required on Windows'))
+      loadingRef.current = false
+      return
+    }
+
     if (isEditing(agent)) {
       if (!agent) {
         loadingRef.current = false
@@ -327,7 +318,8 @@
       t,
       updateAgent,
       afterSubmit,
-      addAgent
+      addAgent,
+      gitBashPathInfo.path
     ]
   )
 
@@ -346,66 +338,6 @@ const PopupContainer: React.FC = ({ agent, afterSubmit, resolve }) => {
       footer={null}>
-        {!hasGitBash && (
-          [error Alert markup partially lost in extraction]
-                {t(
-                  'agent.gitBash.error.description',
-                  'Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from'
-                )}{' '}
-                [link tag lost in extraction] onClick={(e) => {
-                    e.preventDefault()
-                    window.api.openWebsite('https://git-scm.com/download/win')
-                  }}
-                  style={{ textDecoration: 'underline' }}>
-                  git-scm.com
-                [closing markup lost in extraction]
-            }
-            type="error"
-            showIcon
-            style={{ marginBottom: 16 }}
-          />
-        )}
-
-        {hasGitBash && customGitBashPath && (
-          [success Alert markup partially lost in extraction]
-              {t('agent.gitBash.customPath', {
-                defaultValue: 'Using custom path: {{path}}',
-                path: customGitBashPath
-              })}
-            [action button markup lost in extraction]
-            }
-            type="success"
-            showIcon
-            style={{ marginBottom: 16 }}
-          />
-        )}
+        {isWin && (
+          [container markup partially lost in extraction]
+            {gitBashPathInfo.source === 'manual' && (
+              [reset button markup lost in extraction]
+            )}
+            {gitBashPathInfo.path && gitBashPathInfo.source === 'auto' && (
+              {t('agent.gitBash.autoDiscoveredHint', 'Auto-discovered')}
+            )}
+        )}
+
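Taken together, the main-process discovery, the new GitBashPathSource config key, and the preload bridge give the renderer a single call that answers with both the path and how it was obtained. A minimal consumption sketch (the hook name is illustrative; window.api.system is the bridge registered in the preload hunk above):

import { useCallback, useEffect, useState } from 'react'
import type { GitBashPathInfo } from '@shared/config/constant'

// Illustrative hook mirroring what the agent modal above does inline.
export function useGitBashPathInfo() {
  const [info, setInfo] = useState<GitBashPathInfo>({ path: null, source: null })

  const refresh = useCallback(async () => {
    // Triggers auto-discovery in the main process if nothing is configured yet.
    setInfo(await window.api.system.getGitBashPathInfo())
  }, [])

  useEffect(() => {
    refresh()
  }, [refresh])

  return { ...info, refresh }
}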