diff --git a/.gitignore b/.gitignore index a8107fa93..9322c8717 100644 --- a/.gitignore +++ b/.gitignore @@ -73,3 +73,5 @@ test-results YOUR_MEMORY_FILE_PATH .sessions/ +.next/ +*.tsbuildinfo diff --git a/electron.vite.config.ts b/electron.vite.config.ts index 172d48ca9..da471c9fc 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -25,7 +25,10 @@ export default defineConfig({ '@shared': resolve('packages/shared'), '@logger': resolve('src/main/services/LoggerService'), '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'), - '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node') + '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'), + '@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'), + '@cherrystudio/ai-core': resolve('packages/aiCore/src'), + '@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src') } }, build: { diff --git a/packages/shared/ai-sdk-middlewares/index.ts b/packages/shared/ai-sdk-middlewares/index.ts new file mode 100644 index 000000000..a4db5ad2d --- /dev/null +++ b/packages/shared/ai-sdk-middlewares/index.ts @@ -0,0 +1,15 @@ +/** + * Shared AI SDK Middlewares + * + * Environment-agnostic middlewares that can be used in both + * renderer process and main process (API server). + */ + +export { + buildSharedMiddlewares, + getReasoningTagName, + isGemini3ModelId, + openrouterReasoningMiddleware, + type SharedMiddlewareConfig, + skipGeminiThoughtSignatureMiddleware +} from './middlewares' diff --git a/packages/shared/ai-sdk-middlewares/middlewares.ts b/packages/shared/ai-sdk-middlewares/middlewares.ts new file mode 100644 index 000000000..de857699f --- /dev/null +++ b/packages/shared/ai-sdk-middlewares/middlewares.ts @@ -0,0 +1,205 @@ +/** + * Shared AI SDK Middlewares + * + * These middlewares are environment-agnostic and can be used in both + * renderer process and main process (API server). + */ +import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider' +import { extractReasoningMiddleware } from 'ai' + +/** + * Configuration for building shared middlewares + */ +export interface SharedMiddlewareConfig { + /** + * Whether to enable reasoning extraction + */ + enableReasoning?: boolean + + /** + * Tag name for reasoning extraction + * Defaults based on model ID + */ + reasoningTagName?: string + + /** + * Model ID - used to determine default reasoning tag and model detection + */ + modelId?: string + + /** + * Provider ID (Cherry Studio provider ID) + * Used for provider-specific middlewares like OpenRouter + */ + providerId?: string + + /** + * AI SDK Provider ID + * Used for Gemini thought signature middleware + * e.g., 'google', 'google-vertex' + */ + aiSdkProviderId?: string +} + +/** + * Check if model ID represents a Gemini 3 (2.5) model + * that requires thought signature handling + * + * @param modelId - The model ID string (not Model object) + */ +export function isGemini3ModelId(modelId?: string): boolean { + if (!modelId) return false + const lowerModelId = modelId.toLowerCase() + return lowerModelId.includes('gemini-3') +} + +/** + * Get the default reasoning tag name based on model ID + * + * Different models use different tags for reasoning content: + * - Most models: 'think' + * - GPT-OSS models: 'reasoning' + * - Gemini models: 'thought' + * - Seed models: 'seed:think' + */ +export function getReasoningTagName(modelId?: string): string { + if (!modelId) return 'think' + const lowerModelId = modelId.toLowerCase() + if (lowerModelId.includes('gpt-oss')) return 'reasoning' + if (lowerModelId.includes('gemini')) return 'thought' + if (lowerModelId.includes('seed-oss-36b')) return 'seed:think' + return 'think' +} + +/** + * Skip Gemini Thought Signature Middleware + * + * Due to the complexity of multi-model client requests (which can switch + * to other models mid-process), this middleware skips all Gemini 3 + * thinking signatures validation. + * + * @param aiSdkId - AI SDK Provider ID (e.g., 'google', 'google-vertex') + * @returns LanguageModelV2Middleware + */ +export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware { + const MAGIC_STRING = 'skip_thought_signature_validator' + return { + middlewareVersion: 'v2', + + transformParams: async ({ params }) => { + const transformedParams = { ...params } + // Process messages in prompt + if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) { + transformedParams.prompt = transformedParams.prompt.map((message) => { + if (typeof message.content !== 'string') { + for (const part of message.content) { + const googleOptions = part?.providerOptions?.[aiSdkId] + if (googleOptions?.thoughtSignature) { + googleOptions.thoughtSignature = MAGIC_STRING + } + } + } + return message + }) + } + + return transformedParams + } + } +} + +/** + * OpenRouter Reasoning Middleware + * + * Filters out [REDACTED] blocks from OpenRouter reasoning responses. + * OpenRouter may include [REDACTED] markers in reasoning content that + * should be removed for cleaner output. + * + * @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens + * @returns LanguageModelV2Middleware + */ +export function openrouterReasoningMiddleware(): LanguageModelV2Middleware { + const REDACTED_BLOCK = '[REDACTED]' + return { + middlewareVersion: 'v2', + wrapGenerate: async ({ doGenerate }) => { + const { content, ...rest } = await doGenerate() + const modifiedContent = content.map((part) => { + if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) { + return { + ...part, + text: part.text.replace(REDACTED_BLOCK, '') + } + } + return part + }) + return { content: modifiedContent, ...rest } + }, + wrapStream: async ({ doStream }) => { + const { stream, ...rest } = await doStream() + return { + stream: stream.pipeThrough( + new TransformStream({ + transform( + chunk: LanguageModelV2StreamPart, + controller: TransformStreamDefaultController + ) { + if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) { + controller.enqueue({ + ...chunk, + delta: chunk.delta.replace(REDACTED_BLOCK, '') + }) + } else { + controller.enqueue(chunk) + } + } + }) + ), + ...rest + } + } + } +} + +/** + * Build shared middlewares based on configuration + * + * This function builds a set of middlewares that are commonly needed + * across different environments (renderer, API server). + * + * @param config - Configuration for middleware building + * @returns Array of AI SDK middlewares + * + * @example + * ```typescript + * import { buildSharedMiddlewares } from '@shared/middleware' + * + * const middlewares = buildSharedMiddlewares({ + * enableReasoning: true, + * modelId: 'gemini-2.5-pro', + * providerId: 'openrouter', + * aiSdkProviderId: 'google' + * }) + * ``` + */ +export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] { + const middlewares: LanguageModelV2Middleware[] = [] + + // 1. Reasoning extraction middleware + if (config.enableReasoning) { + const tagName = config.reasoningTagName || getReasoningTagName(config.modelId) + middlewares.push(extractReasoningMiddleware({ tagName })) + } + + // 2. OpenRouter-specific: filter [REDACTED] blocks + if (config.providerId === 'openrouter' && config.enableReasoning) { + middlewares.push(openrouterReasoningMiddleware()) + } + + // 3. Gemini 3 (2.5) specific: skip thought signature validation + if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) { + middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId)) + } + + return middlewares +} diff --git a/packages/shared/anthropic/index.ts b/packages/shared/anthropic/index.ts index b9e9cb884..e3c4a37cb 100644 --- a/packages/shared/anthropic/index.ts +++ b/packages/shared/anthropic/index.ts @@ -9,13 +9,27 @@ */ import Anthropic from '@anthropic-ai/sdk' -import type { TextBlockParam } from '@anthropic-ai/sdk/resources' +import type { MessageCreateParams, TextBlockParam, Tool as AnthropicTool } from '@anthropic-ai/sdk/resources' import { loggerService } from '@logger' -import type { Provider } from '@types' +import { type Provider, SystemProviderIds } from '@types' import type { ModelMessage } from 'ai' const logger = loggerService.withContext('anthropic-sdk') +/** + * Context for Anthropic SDK client creation. + * This allows the shared module to be used in different environments + * by providing environment-specific implementations. + */ +export interface AnthropicSdkContext { + /** + * Custom fetch function to use for HTTP requests. + * In Electron main process, this should be `net.fetch`. + * In other environments, can use the default fetch or a custom implementation. + */ + fetch?: typeof globalThis.fetch +} + const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.` const defaultClaudeCodeSystem: Array = [ @@ -58,8 +72,11 @@ const defaultClaudeCodeSystem: Array = [ export function getSdkClient( provider: Provider, oauthToken?: string | null, - extraHeaders?: Record + extraHeaders?: Record, + context?: AnthropicSdkContext ): Anthropic { + const customFetch = context?.fetch + if (provider.authType === 'oauth') { if (!oauthToken) { throw new Error('OAuth token is not available') @@ -85,7 +102,8 @@ export function getSdkClient( 'x-stainless-runtime': 'node', 'x-stainless-runtime-version': 'v22.18.0', ...extraHeaders - } + }, + fetch: customFetch }) } const baseURL = @@ -101,11 +119,12 @@ export function getSdkClient( baseURL, dangerouslyAllowBrowser: true, defaultHeaders: { - 'anthropic-beta': 'output-128k-2025-02-19', + 'anthropic-beta': 'interleaved-thinking-2025-05-14', 'APP-Code': 'MLTG2087', ...provider.extra_headers, ...extraHeaders - } + }, + fetch: customFetch }) } @@ -115,9 +134,11 @@ export function getSdkClient( baseURL, dangerouslyAllowBrowser: true, defaultHeaders: { - 'anthropic-beta': 'output-128k-2025-02-19', + 'anthropic-beta': 'interleaved-thinking-2025-05-14', + Authorization: provider.id === SystemProviderIds.longcat ? `Bearer ${provider.apiKey}` : undefined, ...provider.extra_headers - } + }, + fetch: customFetch }) } @@ -168,3 +189,31 @@ export function buildClaudeCodeSystemModelMessage(system?: string | Array { + if ('type' in tool && tool.type !== 'custom') return tool + + // oxlint-disable-next-line no-unused-vars + const { input_examples, ...sanitizedTool } = tool as AnthropicTool & { input_examples?: unknown } + + return sanitizedTool as typeof tool + }) +} diff --git a/packages/shared/config/providers.ts b/packages/shared/config/providers.ts index f7744150e..6490c61cc 100644 --- a/packages/shared/config/providers.ts +++ b/packages/shared/config/providers.ts @@ -43,6 +43,35 @@ export function isSiliconAnthropicCompatibleModel(modelId: string): boolean { } /** - * Silicon provider's Anthropic API host URL. + * PPIO provider models that support Anthropic API endpoint. + * These models can be used with Claude Code via the Anthropic-compatible API. + * + * @see https://ppio.com/docs/model/llm-anthropic-compatibility */ -export const SILICON_ANTHROPIC_API_HOST = 'https://api.siliconflow.cn' +export const PPIO_ANTHROPIC_COMPATIBLE_MODELS: readonly string[] = [ + 'moonshotai/kimi-k2-thinking', + 'minimax/minimax-m2', + 'deepseek/deepseek-v3.2-exp', + 'deepseek/deepseek-v3.1-terminus', + 'zai-org/glm-4.6', + 'moonshotai/kimi-k2-0905', + 'deepseek/deepseek-v3.1', + 'moonshotai/kimi-k2-instruct', + 'qwen/qwen3-next-80b-a3b-instruct', + 'qwen/qwen3-next-80b-a3b-thinking' +] + +/** + * Creates a Set for efficient lookup of PPIO Anthropic-compatible model IDs. + */ +const PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(PPIO_ANTHROPIC_COMPATIBLE_MODELS) + +/** + * Checks if a model ID is compatible with Anthropic API on PPIO provider. + * + * @param modelId - The model ID to check + * @returns true if the model supports Anthropic API endpoint + */ +export function isPpioAnthropicCompatibleModel(modelId: string): boolean { + return PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId) +} diff --git a/src/renderer/src/aiCore/provider/config/aihubmix.ts b/packages/shared/provider/config/aihubmix.ts similarity index 53% rename from src/renderer/src/aiCore/provider/config/aihubmix.ts rename to packages/shared/provider/config/aihubmix.ts index 8feed8990..5214e8d06 100644 --- a/src/renderer/src/aiCore/provider/config/aihubmix.ts +++ b/packages/shared/provider/config/aihubmix.ts @@ -1,13 +1,13 @@ /** * AiHubMix规则集 */ -import { isOpenAILLMModel } from '@renderer/config/models' -import type { Provider } from '@renderer/types' +import { getLowerBaseModelName } from '@shared/utils/naming' +import type { MinimalModel, MinimalProvider } from '../types' import { provider2Provider, startsWith } from './helper' import type { RuleSet } from './types' -const extraProviderConfig = (provider: Provider) => { +const extraProviderConfig =

(provider: P) => { return { ...provider, extra_headers: { @@ -17,11 +17,23 @@ const extraProviderConfig = (provider: Provider) => { } } +function isOpenAILLMModel(model: M): boolean { + const modelId = getLowerBaseModelName(model.id) + const reasonings = ['o1', 'o3', 'o4', 'gpt-oss'] + if (reasonings.some((r) => modelId.includes(r))) { + return true + } + if (modelId.includes('gpt')) { + return true + } + return false +} + const AIHUBMIX_RULES: RuleSet = { rules: [ { match: startsWith('claude'), - provider: (provider: Provider) => { + provider: (provider) => { return extraProviderConfig({ ...provider, type: 'anthropic' @@ -34,7 +46,7 @@ const AIHUBMIX_RULES: RuleSet = { !model.id.endsWith('-nothink') && !model.id.endsWith('-search') && !model.id.includes('embedding'), - provider: (provider: Provider) => { + provider: (provider) => { return extraProviderConfig({ ...provider, type: 'gemini', @@ -44,7 +56,7 @@ const AIHUBMIX_RULES: RuleSet = { }, { match: isOpenAILLMModel, - provider: (provider: Provider) => { + provider: (provider) => { return extraProviderConfig({ ...provider, type: 'openai-response' @@ -52,7 +64,8 @@ const AIHUBMIX_RULES: RuleSet = { } } ], - fallbackRule: (provider: Provider) => extraProviderConfig(provider) + fallbackRule: (provider) => extraProviderConfig(provider) } -export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES) +export const aihubmixProviderCreator =

(model: MinimalModel, provider: P): P => + provider2Provider(AIHUBMIX_RULES, model, provider) diff --git a/packages/shared/provider/config/azure-anthropic.ts b/packages/shared/provider/config/azure-anthropic.ts new file mode 100644 index 000000000..e176614df --- /dev/null +++ b/packages/shared/provider/config/azure-anthropic.ts @@ -0,0 +1,22 @@ +import type { MinimalModel, MinimalProvider, ProviderType } from '../types' +import { provider2Provider, startsWith } from './helper' +import type { RuleSet } from './types' + +// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry +const AZURE_ANTHROPIC_RULES: RuleSet = { + rules: [ + { + match: startsWith('claude'), + provider: (provider: MinimalProvider) => ({ + ...provider, + type: 'anthropic' as ProviderType, + apiHost: provider.apiHost + 'anthropic/v1', + id: 'azure-anthropic' + }) + } + ], + fallbackRule: (provider: MinimalProvider) => provider +} + +export const azureAnthropicProviderCreator =

(model: MinimalModel, provider: P): P => + provider2Provider(AZURE_ANTHROPIC_RULES, model, provider) diff --git a/packages/shared/provider/config/helper.ts b/packages/shared/provider/config/helper.ts new file mode 100644 index 000000000..95f53f885 --- /dev/null +++ b/packages/shared/provider/config/helper.ts @@ -0,0 +1,32 @@ +import type { MinimalModel, MinimalProvider } from '../types' +import type { RuleSet } from './types' + +export const startsWith = + (prefix: string) => + (model: M) => + model.id.toLowerCase().startsWith(prefix.toLowerCase()) + +export const endpointIs = + (type: string) => + (model: M) => + model.endpoint_type === type + +/** + * 解析模型对应的Provider + * @param ruleSet 规则集对象 + * @param model 模型对象 + * @param provider 原始provider对象 + * @returns 解析出的provider对象 + */ +export function provider2Provider( + ruleSet: RuleSet, + model: M, + provider: P +): P { + for (const rule of ruleSet.rules) { + if (rule.match(model)) { + return rule.provider(provider) as P + } + } + return ruleSet.fallbackRule(provider) as P +} diff --git a/packages/shared/provider/config/index.ts b/packages/shared/provider/config/index.ts new file mode 100644 index 000000000..1273319ec --- /dev/null +++ b/packages/shared/provider/config/index.ts @@ -0,0 +1,6 @@ +export { aihubmixProviderCreator } from './aihubmix' +export { azureAnthropicProviderCreator } from './azure-anthropic' +export { endpointIs, provider2Provider, startsWith } from './helper' +export { newApiResolverCreator } from './newApi' +export type { RuleSet } from './types' +export { vertexAnthropicProviderCreator } from './vertex-anthropic' diff --git a/src/renderer/src/aiCore/provider/config/newApi.ts b/packages/shared/provider/config/newApi.ts similarity index 52% rename from src/renderer/src/aiCore/provider/config/newApi.ts rename to packages/shared/provider/config/newApi.ts index 97de62597..fd1b74085 100644 --- a/src/renderer/src/aiCore/provider/config/newApi.ts +++ b/packages/shared/provider/config/newApi.ts @@ -1,8 +1,7 @@ /** * NewAPI规则集 */ -import type { Provider } from '@renderer/types' - +import type { MinimalModel, MinimalProvider, ProviderType } from '../types' import { endpointIs, provider2Provider } from './helper' import type { RuleSet } from './types' @@ -10,42 +9,43 @@ const NEWAPI_RULES: RuleSet = { rules: [ { match: endpointIs('anthropic'), - provider: (provider: Provider) => { + provider: (provider) => { return { ...provider, - type: 'anthropic' + type: 'anthropic' as ProviderType } } }, { match: endpointIs('gemini'), - provider: (provider: Provider) => { + provider: (provider) => { return { ...provider, - type: 'gemini' + type: 'gemini' as ProviderType } } }, { match: endpointIs('openai-response'), - provider: (provider: Provider) => { + provider: (provider) => { return { ...provider, - type: 'openai-response' + type: 'openai-response' as ProviderType } } }, { match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model), - provider: (provider: Provider) => { + provider: (provider) => { return { ...provider, - type: 'openai' + type: 'openai' as ProviderType } } } ], - fallbackRule: (provider: Provider) => provider + fallbackRule: (provider) => provider } -export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES) +export const newApiResolverCreator =

(model: MinimalModel, provider: P): P => + provider2Provider(NEWAPI_RULES, model, provider) diff --git a/packages/shared/provider/config/types.ts b/packages/shared/provider/config/types.ts new file mode 100644 index 000000000..fdb130986 --- /dev/null +++ b/packages/shared/provider/config/types.ts @@ -0,0 +1,9 @@ +import type { MinimalModel, MinimalProvider } from '../types' + +export interface RuleSet { + rules: Array<{ + match: (model: M) => boolean + provider: (provider: P) => P + }> + fallbackRule: (provider: P) => P +} diff --git a/packages/shared/provider/config/vertex-anthropic.ts b/packages/shared/provider/config/vertex-anthropic.ts new file mode 100644 index 000000000..242ba2a9f --- /dev/null +++ b/packages/shared/provider/config/vertex-anthropic.ts @@ -0,0 +1,19 @@ +import type { MinimalModel, MinimalProvider } from '../types' +import { provider2Provider, startsWith } from './helper' +import type { RuleSet } from './types' + +const VERTEX_ANTHROPIC_RULES: RuleSet = { + rules: [ + { + match: startsWith('claude'), + provider: (provider: MinimalProvider) => ({ + ...provider, + id: 'google-vertex-anthropic' + }) + } + ], + fallbackRule: (provider: MinimalProvider) => provider +} + +export const vertexAnthropicProviderCreator =

(model: MinimalModel, provider: P): P => + provider2Provider(VERTEX_ANTHROPIC_RULES, model, provider) diff --git a/packages/shared/provider/constant.ts b/packages/shared/provider/constant.ts new file mode 100644 index 000000000..c449c9f63 --- /dev/null +++ b/packages/shared/provider/constant.ts @@ -0,0 +1,26 @@ +import { getLowerBaseModelName } from '@shared/utils/naming' + +import type { MinimalModel } from './types' + +export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1' +export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7' +export const COPILOT_INTEGRATION_ID = 'vscode-chat' +export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7' + +export const COPILOT_DEFAULT_HEADERS = { + 'Copilot-Integration-Id': COPILOT_INTEGRATION_ID, + 'User-Agent': COPILOT_USER_AGENT, + 'Editor-Version': COPILOT_EDITOR_VERSION, + 'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION, + 'editor-version': COPILOT_EDITOR_VERSION, + 'editor-plugin-version': COPILOT_PLUGIN_VERSION, + 'copilot-vision-request': 'true' +} as const + +// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560) +const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex', 'gpt-5.1-codex', 'gpt-5.1-codex-mini'] + +export function isCopilotResponsesModel(model: M): boolean { + const normalizedId = getLowerBaseModelName(model.id) + return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target) +} diff --git a/packages/shared/provider/detection.ts b/packages/shared/provider/detection.ts new file mode 100644 index 000000000..8e76218fc --- /dev/null +++ b/packages/shared/provider/detection.ts @@ -0,0 +1,101 @@ +/** + * Provider Type Detection Utilities + * + * Functions to detect provider types based on provider configuration. + * These are pure functions that only depend on provider.type and provider.id. + * + * NOTE: These functions should match the logic in @renderer/utils/provider.ts + */ + +import type { MinimalProvider } from './types' + +/** + * Check if provider is Anthropic type + */ +export function isAnthropicProvider

(provider: P): boolean { + return provider.type === 'anthropic' +} + +/** + * Check if provider is OpenAI Response type (openai-response) + * NOTE: This matches isOpenAIProvider in renderer/utils/provider.ts + */ +export function isOpenAIProvider

(provider: P): boolean { + return provider.type === 'openai-response' +} + +/** + * Check if provider is Gemini type + */ +export function isGeminiProvider

(provider: P): boolean { + return provider.type === 'gemini' +} + +/** + * Check if provider is Azure OpenAI type + */ +export function isAzureOpenAIProvider

(provider: P): boolean { + return provider.type === 'azure-openai' +} + +/** + * Check if provider is Vertex AI type + */ +export function isVertexProvider

(provider: P): boolean { + return provider.type === 'vertexai' +} + +/** + * Check if provider is AWS Bedrock type + */ +export function isAwsBedrockProvider

(provider: P): boolean { + return provider.type === 'aws-bedrock' +} + +export function isAIGatewayProvider

(provider: P): boolean { + return provider.type === 'gateway' +} + +export function isOllamaProvider

(provider: P): boolean { + return provider.type === 'ollama' +} + +/** + * Check if Azure OpenAI provider uses responses endpoint + * Matches isAzureResponsesEndpoint in renderer/utils/provider.ts + */ +export function isAzureResponsesEndpoint

(provider: P): boolean { + return provider.apiVersion === 'preview' || provider.apiVersion === 'v1' +} + +/** + * Check if provider is Cherry AI type + * Matches isCherryAIProvider in renderer/utils/provider.ts + */ +export function isCherryAIProvider

(provider: P): boolean { + return provider.id === 'cherryai' +} + +/** + * Check if provider is Perplexity type + * Matches isPerplexityProvider in renderer/utils/provider.ts + */ +export function isPerplexityProvider

(provider: P): boolean { + return provider.id === 'perplexity' +} + +/** + * Check if provider is new-api type (supports multiple backends) + * Matches isNewApiProvider in renderer/utils/provider.ts + */ +export function isNewApiProvider

(provider: P): boolean { + return ['new-api', 'cherryin'].includes(provider.id) || provider.type === ('new-api' as string) +} + +/** + * Check if provider is OpenAI compatible + * Matches isOpenAICompatibleProvider in renderer/utils/provider.ts + */ +export function isOpenAICompatibleProvider

(provider: P): boolean { + return ['openai', 'new-api', 'mistral'].includes(provider.type) +} diff --git a/packages/shared/provider/format.ts b/packages/shared/provider/format.ts new file mode 100644 index 000000000..3a1fc637a --- /dev/null +++ b/packages/shared/provider/format.ts @@ -0,0 +1,141 @@ +/** + * Provider API Host Formatting + * + * Utilities for formatting provider API hosts to work with AI SDK. + * These handle the differences between how Cherry Studio stores API hosts + * and how AI SDK expects them. + */ + +import { + formatApiHost, + formatAzureOpenAIApiHost, + formatOllamaApiHost, + formatVertexApiHost, + isWithTrailingSharp, + routeToEndpoint, + withoutTrailingSlash +} from '../utils/url' +import { + isAnthropicProvider, + isAzureOpenAIProvider, + isCherryAIProvider, + isGeminiProvider, + isOllamaProvider, + isPerplexityProvider, + isVertexProvider +} from './detection' +import type { MinimalProvider } from './types' +import { SystemProviderIds } from './types' + +/** + * Interface for environment-specific implementations + * Renderer and Main process can provide their own implementations + */ +export interface ProviderFormatContext { + vertex: { + project: string + location: string + } +} + +/** + * Default Azure OpenAI API host formatter + */ +export function defaultFormatAzureOpenAIApiHost(host: string): string { + const normalizedHost = withoutTrailingSlash(host) + ?.replace(/\/v1$/, '') + .replace(/\/openai$/, '') + // AI SDK will add /v1 + return formatApiHost(normalizedHost + '/openai', false) +} + +/** + * Format provider API host for AI SDK + * + * This function normalizes the apiHost to work with AI SDK. + * Different providers have different requirements: + * - Most providers: add /v1 suffix + * - Gemini: add /v1beta suffix + * - Some providers: no suffix needed + * + * @param provider - The provider to format + * @param context - Optional context with environment-specific implementations + * @returns Provider with formatted apiHost (and anthropicApiHost if applicable) + */ +export function formatProviderApiHost(provider: T, context: ProviderFormatContext): T { + const formatted = { ...provider } + const appendApiVersion = !isWithTrailingSharp(provider.apiHost) + // Format anthropicApiHost if present + if (formatted.anthropicApiHost) { + formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion) + } + + // Format based on provider type + if (isAnthropicProvider(provider)) { + const baseHost = formatted.anthropicApiHost || formatted.apiHost + // AI SDK needs /v1 in baseURL + formatted.apiHost = formatApiHost(baseHost, appendApiVersion) + if (!formatted.anthropicApiHost) { + formatted.anthropicApiHost = formatted.apiHost + } + } else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) { + formatted.apiHost = formatApiHost(formatted.apiHost, false) + } else if (isOllamaProvider(formatted)) { + formatted.apiHost = formatOllamaApiHost(formatted.apiHost) + } else if (isGeminiProvider(formatted)) { + formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion, 'v1beta') + } else if (isAzureOpenAIProvider(formatted)) { + formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost) + } else if (isVertexProvider(formatted)) { + formatted.apiHost = formatVertexApiHost(formatted, context.vertex.project, context.vertex.location) + } else if (isCherryAIProvider(formatted)) { + formatted.apiHost = formatApiHost(formatted.apiHost, false) + } else if (isPerplexityProvider(formatted)) { + formatted.apiHost = formatApiHost(formatted.apiHost, false) + } else { + formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion) + } + + return formatted +} + +/** + * Get the base URL for AI SDK from a formatted provider + * + * This extracts the baseURL that AI SDK expects, handling + * the '#' endpoint routing format if present. + * + * @param formattedApiHost - The formatted apiHost (after formatProviderApiHost) + * @returns The baseURL for AI SDK + */ +export function getBaseUrlForAiSdk(formattedApiHost: string): string { + const { baseURL } = routeToEndpoint(formattedApiHost) + return baseURL +} + +/** + * Get rotated API key from comma-separated keys + * + * This is the interface for API key rotation. The actual implementation + * depends on the environment (renderer uses window.keyv, main uses its own storage). + */ +export interface ApiKeyRotator { + /** + * Get the next API key in rotation + * @param providerId - The provider ID for tracking rotation + * @param keys - Comma-separated API keys + * @returns The next API key to use + */ + getRotatedKey(providerId: string, keys: string): string +} + +/** + * Simple API key rotator that always returns the first key + * Use this when rotation is not needed + */ +export const simpleKeyRotator: ApiKeyRotator = { + getRotatedKey(_providerId: string, keys: string): string { + const keyList = keys.split(',').map((k) => k.trim()) + return keyList[0] || keys + } +} diff --git a/packages/shared/provider/index.ts b/packages/shared/provider/index.ts new file mode 100644 index 000000000..53f132acf --- /dev/null +++ b/packages/shared/provider/index.ts @@ -0,0 +1,49 @@ +/** + * Shared Provider Utilities + * + * This module exports utilities for working with AI providers + * that can be shared between main process and renderer process. + */ + +// Type definitions +export type { MinimalProvider, ProviderType, SystemProviderId } from './types' +export { SystemProviderIds } from './types' + +// Provider type detection +export { + isAIGatewayProvider, + isAnthropicProvider, + isAwsBedrockProvider, + isAzureOpenAIProvider, + isAzureResponsesEndpoint, + isCherryAIProvider, + isGeminiProvider, + isNewApiProvider, + isOllamaProvider, + isOpenAICompatibleProvider, + isOpenAIProvider, + isPerplexityProvider, + isVertexProvider +} from './detection' + +// API host formatting +export type { ApiKeyRotator, ProviderFormatContext } from './format' +export { + defaultFormatAzureOpenAIApiHost, + formatProviderApiHost, + getBaseUrlForAiSdk, + simpleKeyRotator +} from './format' + +// Provider ID mapping +export { getAiSdkProviderId, STATIC_PROVIDER_MAPPING, tryResolveProviderId } from './mapping' + +// AI SDK configuration +export type { AiSdkConfig, AiSdkConfigContext } from './sdk-config' +export { providerToAiSdkConfig } from './sdk-config' + +// Provider resolution +export { resolveActualProvider } from './resolve' + +// Provider initialization +export { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from './initialization' diff --git a/packages/shared/provider/initialization.ts b/packages/shared/provider/initialization.ts new file mode 100644 index 000000000..c21ecd873 --- /dev/null +++ b/packages/shared/provider/initialization.ts @@ -0,0 +1,114 @@ +import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider' + +type ProviderInitializationLogger = { + warn?: (message: string) => void + error?: (message: string, error: Error) => void +} + +export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [ + { + id: 'openrouter', + name: 'OpenRouter', + import: () => import('@openrouter/ai-sdk-provider'), + creatorFunctionName: 'createOpenRouter', + supportsImageGeneration: true, + aliases: ['openrouter'] + }, + { + id: 'google-vertex', + name: 'Google Vertex AI', + import: () => import('@ai-sdk/google-vertex/edge'), + creatorFunctionName: 'createVertex', + supportsImageGeneration: true, + aliases: ['vertexai'] + }, + { + id: 'google-vertex-anthropic', + name: 'Google Vertex AI Anthropic', + import: () => import('@ai-sdk/google-vertex/anthropic/edge'), + creatorFunctionName: 'createVertexAnthropic', + supportsImageGeneration: true, + aliases: ['vertexai-anthropic'] + }, + { + id: 'azure-anthropic', + name: 'Azure AI Anthropic', + import: () => import('@ai-sdk/anthropic'), + creatorFunctionName: 'createAnthropic', + supportsImageGeneration: false, + aliases: ['azure-anthropic'] + }, + { + id: 'github-copilot-openai-compatible', + name: 'GitHub Copilot OpenAI Compatible', + import: () => import('@opeoginni/github-copilot-openai-compatible'), + creatorFunctionName: 'createGitHubCopilotOpenAICompatible', + supportsImageGeneration: false, + aliases: ['copilot', 'github-copilot'] + }, + { + id: 'bedrock', + name: 'Amazon Bedrock', + import: () => import('@ai-sdk/amazon-bedrock'), + creatorFunctionName: 'createAmazonBedrock', + supportsImageGeneration: true, + aliases: ['aws-bedrock'] + }, + { + id: 'perplexity', + name: 'Perplexity', + import: () => import('@ai-sdk/perplexity'), + creatorFunctionName: 'createPerplexity', + supportsImageGeneration: false, + aliases: ['perplexity'] + }, + { + id: 'mistral', + name: 'Mistral', + import: () => import('@ai-sdk/mistral'), + creatorFunctionName: 'createMistral', + supportsImageGeneration: false, + aliases: ['mistral'] + }, + { + id: 'huggingface', + name: 'HuggingFace', + import: () => import('@ai-sdk/huggingface'), + creatorFunctionName: 'createHuggingFace', + supportsImageGeneration: true, + aliases: ['hf', 'hugging-face'] + }, + { + id: 'gateway', + name: 'Vercel AI Gateway', + import: () => import('@ai-sdk/gateway'), + creatorFunctionName: 'createGateway', + supportsImageGeneration: true, + aliases: ['ai-gateway'] + }, + { + id: 'cerebras', + name: 'Cerebras', + import: () => import('@ai-sdk/cerebras'), + creatorFunctionName: 'createCerebras', + supportsImageGeneration: false + }, + { + id: 'ollama', + name: 'Ollama', + import: () => import('ollama-ai-provider-v2'), + creatorFunctionName: 'createOllama', + supportsImageGeneration: false + } +] as const + +export function initializeSharedProviders(logger?: ProviderInitializationLogger): void { + try { + const successCount = registerMultipleProviderConfigs(SHARED_PROVIDER_CONFIGS) + if (successCount < SHARED_PROVIDER_CONFIGS.length) { + logger?.warn?.('Some providers failed to register. Check previous error logs.') + } + } catch (error) { + logger?.error?.('Failed to initialize shared providers', error as Error) + } +} diff --git a/packages/shared/provider/mapping.ts b/packages/shared/provider/mapping.ts new file mode 100644 index 000000000..20e2e10c3 --- /dev/null +++ b/packages/shared/provider/mapping.ts @@ -0,0 +1,95 @@ +/** + * Provider ID Mapping + * + * Maps Cherry Studio provider IDs/types to AI SDK provider IDs. + * This logic should match @renderer/aiCore/provider/factory.ts + */ + +import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider' + +import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from './detection' +import type { MinimalProvider } from './types' + +/** + * Static mapping from Cherry Studio provider ID/type to AI SDK provider ID + * Matches STATIC_PROVIDER_MAPPING in @renderer/aiCore/provider/factory.ts + */ +export const STATIC_PROVIDER_MAPPING: Record = { + gemini: 'google', // Google Gemini -> google + 'azure-openai': 'azure', // Azure OpenAI -> azure + 'openai-response': 'openai', // OpenAI Responses -> openai + grok: 'xai', // Grok -> xai + copilot: 'github-copilot-openai-compatible' +} + +/** + * Try to resolve a provider identifier to an AI SDK provider ID + * Matches tryResolveProviderId in @renderer/aiCore/provider/factory.ts + * + * @param identifier - The provider ID or type to resolve + * @param checker - Provider config checker (defaults to static mapping only) + * @returns The resolved AI SDK provider ID, or null if not found + */ +export function tryResolveProviderId(identifier: string): ProviderId | null { + // 1. 检查静态映射 + const staticMapping = STATIC_PROVIDER_MAPPING[identifier] + if (staticMapping) { + return staticMapping + } + + // 2. 检查AiCore是否支持(包括别名支持) + if (hasProviderConfigByAlias(identifier)) { + // 解析为真实的Provider ID + return resolveProviderConfigId(identifier) as ProviderId + } + + return null +} + +/** + * Get the AI SDK Provider ID for a Cherry Studio provider + * Matches getAiSdkProviderId in @renderer/aiCore/provider/factory.ts + * + * Logic: + * 1. Handle Azure OpenAI specially (check responses endpoint) + * 2. Try to resolve from provider.id + * 3. Try to resolve from provider.type (but not for generic 'openai' type) + * 4. Check for OpenAI API host pattern + * 5. Fallback to provider's own ID + * + * @param provider - The Cherry Studio provider + * @param checker - Provider config checker (defaults to static mapping only) + * @returns The AI SDK provider ID to use + */ +export function getAiSdkProviderId(provider: MinimalProvider): ProviderId { + // 1. Handle Azure OpenAI specially - check this FIRST before other resolution + if (isAzureOpenAIProvider(provider)) { + if (isAzureResponsesEndpoint(provider)) { + return 'azure-responses' + } + return 'azure' + } + + // 2. 尝试解析provider.id + const resolvedFromId = tryResolveProviderId(provider.id) + if (resolvedFromId) { + return resolvedFromId + } + + // 3. 尝试解析provider.type + // 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上 + if (provider.type !== 'openai') { + const resolvedFromType = tryResolveProviderId(provider.type) + if (resolvedFromType) { + return resolvedFromType + } + } + + // 4. Check for OpenAI API host pattern + if (provider.apiHost.includes('api.openai.com')) { + return 'openai-chat' + } + + // 5. 最后的fallback(使用provider本身的id) + return provider.id +} diff --git a/packages/shared/provider/resolve.ts b/packages/shared/provider/resolve.ts new file mode 100644 index 000000000..385da6a58 --- /dev/null +++ b/packages/shared/provider/resolve.ts @@ -0,0 +1,43 @@ +import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config' +import { azureAnthropicProviderCreator } from './config/azure-anthropic' +import { isAzureOpenAIProvider, isNewApiProvider } from './detection' +import type { MinimalModel, MinimalProvider } from './types' + +export interface ResolveActualProviderOptions

{ + isSystemProvider?: (provider: P) => boolean +} + +const defaultIsSystemProvider =

(provider: P): boolean => { + if ('isSystem' in provider) { + return Boolean((provider as unknown as { isSystem?: boolean }).isSystem) + } + return false +} + +export function resolveActualProvider( + provider: P, + model: M, + options: ResolveActualProviderOptions

= {} +): P { + let resolvedProvider = provider + + if (isNewApiProvider(resolvedProvider)) { + resolvedProvider = newApiResolverCreator(model, resolvedProvider) + } + + const isSystemProvider = options.isSystemProvider?.(resolvedProvider) ?? defaultIsSystemProvider(resolvedProvider) + + if (isSystemProvider && resolvedProvider.id === 'aihubmix') { + resolvedProvider = aihubmixProviderCreator(model, resolvedProvider) + } + + if (isSystemProvider && resolvedProvider.id === 'vertexai') { + resolvedProvider = vertexAnthropicProviderCreator(model, resolvedProvider) + } + + if (isAzureOpenAIProvider(resolvedProvider)) { + resolvedProvider = azureAnthropicProviderCreator(model, resolvedProvider) + } + + return resolvedProvider +} diff --git a/packages/shared/provider/sdk-config.ts b/packages/shared/provider/sdk-config.ts new file mode 100644 index 000000000..df4b52ca4 --- /dev/null +++ b/packages/shared/provider/sdk-config.ts @@ -0,0 +1,289 @@ +/** + * AI SDK Configuration + * + * Shared utilities for converting Cherry Studio Provider to AI SDK configuration. + * Environment-specific logic (renderer/main) is injected via context interfaces. + */ + +import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider' +import { defaultAppHeaders } from '@shared/utils' +import { isEmpty } from 'lodash' + +import { routeToEndpoint } from '../utils/url' +import { isAzureOpenAIProvider, isOllamaProvider } from './detection' +import { getAiSdkProviderId } from './mapping' +import type { MinimalProvider } from './types' +import { SystemProviderIds } from './types' + +/** + * AI SDK configuration result + */ +export interface AiSdkConfig { + providerId: string + options: Record +} + +/** + * Context for environment-specific implementations + */ +export interface AiSdkConfigContext { + /** + * Check if a model uses chat completion only (for OpenAI response mode) + * Default: returns false + */ + isOpenAIChatCompletionOnlyModel?: (modelId: string) => boolean + + /** + * Check if provider supports stream options + * Default: returns true + */ + isSupportStreamOptionsProvider?: (provider: MinimalProvider) => boolean + + /** + * Get includeUsage setting for stream options + * Default: returns undefined + */ + getIncludeUsageSetting?: () => boolean | undefined | Promise + + /** + * Get Copilot default headers (constants) + * Default: returns empty object + */ + getCopilotDefaultHeaders?: () => Record + + /** + * Get Copilot stored headers from state + * Default: returns empty object + */ + getCopilotStoredHeaders?: () => Record + + /** + * Get AWS Bedrock configuration + * Default: returns undefined (not configured) + */ + getAwsBedrockConfig?: () => + | { + authType: 'apiKey' | 'iam' + region: string + apiKey?: string + accessKeyId?: string + secretAccessKey?: string + } + | undefined + + /** + * Get Vertex AI configuration + * Default: returns undefined (not configured) + */ + getVertexConfig?: (provider: MinimalProvider) => + | { + project: string + location: string + googleCredentials: { + privateKey: string + clientEmail: string + } + } + | undefined + + /** + * Get endpoint type for cherryin provider + */ + getEndpointType?: (modelId: string) => string | undefined + + /** + * Custom fetch implementation + * Main process: use Electron net.fetch + * Renderer process: use browser fetch (default) + */ + fetch?: typeof globalThis.fetch + + /** + * Get CherryAI signed fetch wrapper + * Returns a fetch function that adds signature headers to requests + */ + getCherryAISignedFetch?: () => typeof globalThis.fetch +} + +/** + * Convert Cherry Studio Provider to AI SDK configuration + * + * @param provider - The formatted provider (after formatProviderApiHost) + * @param modelId - The model ID to use + * @param context - Environment-specific implementations + * @returns AI SDK configuration + */ +export function providerToAiSdkConfig( + provider: MinimalProvider, + modelId: string, + context: AiSdkConfigContext = {} +): AiSdkConfig { + const isOpenAIChatCompletionOnlyModel = context.isOpenAIChatCompletionOnlyModel || (() => false) + const isSupportStreamOptionsProvider = context.isSupportStreamOptionsProvider || (() => true) + const getIncludeUsageSetting = context.getIncludeUsageSetting || (() => undefined) + + const aiSdkProviderId = getAiSdkProviderId(provider) + + // Build base config + const { baseURL, endpoint } = routeToEndpoint(provider.apiHost) + const baseConfig = { + baseURL, + apiKey: provider.apiKey + } + + let includeUsage: boolean | undefined = undefined + if (isSupportStreamOptionsProvider(provider)) { + const setting = getIncludeUsageSetting() + includeUsage = setting instanceof Promise ? undefined : setting + } + + // Handle Copilot specially + if (provider.id === SystemProviderIds.copilot) { + const defaultHeaders = context.getCopilotDefaultHeaders?.() ?? {} + const storedHeaders = context.getCopilotStoredHeaders?.() ?? {} + const copilotExtraOptions: Record = { + headers: { + ...defaultHeaders, + ...storedHeaders, + ...provider.extra_headers + }, + name: provider.id, + includeUsage + } + if (context.fetch) { + copilotExtraOptions.fetch = context.fetch + } + const options = ProviderConfigFactory.fromProvider( + 'github-copilot-openai-compatible', + baseConfig, + copilotExtraOptions + ) + + return { + providerId: 'github-copilot-openai-compatible', + options + } + } + + if (isOllamaProvider(provider)) { + return { + providerId: 'ollama', + options: { + ...baseConfig, + headers: { + ...provider.extra_headers, + Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined + } + } + } + } + + // Build extra options + const extraOptions: Record = {} + if (endpoint) { + extraOptions.endpoint = endpoint + } + + // Handle OpenAI mode + if (provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(modelId)) { + extraOptions.mode = 'responses' + } else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && provider.type === 'openai')) { + extraOptions.mode = 'chat' + } + + // Add extra headers + const headers: Record = { + ...defaultAppHeaders(), + ...provider.extra_headers + } + + if (aiSdkProviderId === 'openai') { + headers['X-Api-Key'] = baseConfig.apiKey + } + + extraOptions.headers = headers + + // Handle Azure modes + if (aiSdkProviderId === 'azure-responses') { + extraOptions.mode = 'responses' + } else if (aiSdkProviderId === 'azure') { + extraOptions.mode = 'chat' + } + + if (isAzureOpenAIProvider(provider)) { + const apiVersion = provider.apiVersion?.trim() + if (apiVersion) { + extraOptions.apiVersion = apiVersion + if (!['preview', 'v1'].includes(apiVersion)) { + extraOptions.useDeploymentBasedUrls = true + } + } + } + + // Handle AWS Bedrock + if (aiSdkProviderId === 'bedrock') { + const bedrockConfig = context.getAwsBedrockConfig?.() + if (bedrockConfig) { + extraOptions.region = bedrockConfig.region + if (bedrockConfig.authType === 'apiKey') { + extraOptions.apiKey = bedrockConfig.apiKey + } else { + extraOptions.accessKeyId = bedrockConfig.accessKeyId + extraOptions.secretAccessKey = bedrockConfig.secretAccessKey + } + } + } + + // Handle Vertex AI + if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') { + const vertexConfig = context.getVertexConfig?.(provider) + if (vertexConfig) { + extraOptions.project = vertexConfig.project + extraOptions.location = vertexConfig.location + extraOptions.googleCredentials = { + ...vertexConfig.googleCredentials, + privateKey: formatPrivateKey(vertexConfig.googleCredentials.privateKey) + } + baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models' + } + } + + // Handle cherryin endpoint type + if (aiSdkProviderId === 'cherryin') { + const endpointType = context.getEndpointType?.(modelId) + if (endpointType) { + extraOptions.endpointType = endpointType + } + } + + // Handle cherryai signed fetch + if (provider.id === 'cherryai') { + const signedFetch = context.getCherryAISignedFetch?.() + if (signedFetch) { + extraOptions.fetch = signedFetch + } + } else if (context.fetch) { + extraOptions.fetch = context.fetch + } + + // Check if AI SDK supports this provider natively + if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { + const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions) + return { + providerId: aiSdkProviderId, + options + } + } + + // Fallback to openai-compatible + const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey) + return { + providerId: 'openai-compatible', + options: { + ...options, + name: provider.id, + ...extraOptions, + includeUsage + } + } +} diff --git a/packages/shared/provider/types.ts b/packages/shared/provider/types.ts new file mode 100644 index 000000000..3dd56376d --- /dev/null +++ b/packages/shared/provider/types.ts @@ -0,0 +1,177 @@ +import * as z from 'zod' + +export const ProviderTypeSchema = z.enum([ + 'openai', + 'openai-response', + 'anthropic', + 'gemini', + 'azure-openai', + 'vertexai', + 'mistral', + 'aws-bedrock', + 'vertex-anthropic', + 'new-api', + 'gateway', + 'ollama' +]) + +export type ProviderType = z.infer + +/** + * Minimal provider interface for shared utilities + * This is the subset of Provider that shared code needs + */ +export type MinimalProvider = { + id: string + type: ProviderType + apiKey: string + apiHost: string + anthropicApiHost?: string + apiVersion?: string + extra_headers?: Record +} + +/** + * Minimal model interface for shared utilities + * This is the subset of Model that shared code needs + */ +export type MinimalModel = { + id: string + endpoint_type?: string +} + +export const SystemProviderIdSchema = z.enum([ + 'cherryin', + 'silicon', + 'aihubmix', + 'ocoolai', + 'deepseek', + 'ppio', + 'alayanew', + 'qiniu', + 'dmxapi', + 'burncloud', + 'tokenflux', + '302ai', + 'cephalon', + 'lanyun', + 'ph8', + 'openrouter', + 'ollama', + 'ovms', + 'new-api', + 'lmstudio', + 'anthropic', + 'openai', + 'azure-openai', + 'gemini', + 'vertexai', + 'github', + 'copilot', + 'zhipu', + 'yi', + 'moonshot', + 'baichuan', + 'dashscope', + 'stepfun', + 'doubao', + 'infini', + 'minimax', + 'groq', + 'together', + 'fireworks', + 'nvidia', + 'grok', + 'hyperbolic', + 'mistral', + 'jina', + 'perplexity', + 'modelscope', + 'xirang', + 'hunyuan', + 'tencent-cloud-ti', + 'baidu-cloud', + 'gpustack', + 'voyageai', + 'aws-bedrock', + 'poe', + 'aionly', + 'longcat', + 'huggingface', + 'sophnet', + 'gateway', + 'cerebras', + 'mimo' +]) + +export type SystemProviderId = z.infer + +export const isSystemProviderId = (id: string): id is SystemProviderId => { + return SystemProviderIdSchema.safeParse(id).success +} + +export const SystemProviderIds = { + cherryin: 'cherryin', + silicon: 'silicon', + aihubmix: 'aihubmix', + ocoolai: 'ocoolai', + deepseek: 'deepseek', + ppio: 'ppio', + alayanew: 'alayanew', + qiniu: 'qiniu', + dmxapi: 'dmxapi', + burncloud: 'burncloud', + tokenflux: 'tokenflux', + '302ai': '302ai', + cephalon: 'cephalon', + lanyun: 'lanyun', + ph8: 'ph8', + sophnet: 'sophnet', + openrouter: 'openrouter', + ollama: 'ollama', + ovms: 'ovms', + 'new-api': 'new-api', + lmstudio: 'lmstudio', + anthropic: 'anthropic', + openai: 'openai', + 'azure-openai': 'azure-openai', + gemini: 'gemini', + vertexai: 'vertexai', + github: 'github', + copilot: 'copilot', + zhipu: 'zhipu', + yi: 'yi', + moonshot: 'moonshot', + baichuan: 'baichuan', + dashscope: 'dashscope', + stepfun: 'stepfun', + doubao: 'doubao', + infini: 'infini', + minimax: 'minimax', + groq: 'groq', + together: 'together', + fireworks: 'fireworks', + nvidia: 'nvidia', + grok: 'grok', + hyperbolic: 'hyperbolic', + mistral: 'mistral', + jina: 'jina', + perplexity: 'perplexity', + modelscope: 'modelscope', + xirang: 'xirang', + hunyuan: 'hunyuan', + 'tencent-cloud-ti': 'tencent-cloud-ti', + 'baidu-cloud': 'baidu-cloud', + gpustack: 'gpustack', + voyageai: 'voyageai', + 'aws-bedrock': 'aws-bedrock', + poe: 'poe', + aionly: 'aionly', + longcat: 'longcat', + huggingface: 'huggingface', + gateway: 'gateway', + cerebras: 'cerebras', + mimo: 'mimo' +} as const satisfies Record + +export type SystemProviderIdTypeMap = typeof SystemProviderIds diff --git a/packages/shared/utils.ts b/packages/shared/utils/headers.ts similarity index 100% rename from packages/shared/utils.ts rename to packages/shared/utils/headers.ts diff --git a/packages/shared/utils/index.ts b/packages/shared/utils/index.ts new file mode 100644 index 000000000..11cefe0c9 --- /dev/null +++ b/packages/shared/utils/index.ts @@ -0,0 +1,3 @@ +export { defaultAppHeaders } from './headers' +export { getBaseModelName, getLowerBaseModelName } from './naming' +export * from './url' diff --git a/packages/shared/utils/naming.ts b/packages/shared/utils/naming.ts new file mode 100644 index 000000000..c9aaf55c3 --- /dev/null +++ b/packages/shared/utils/naming.ts @@ -0,0 +1,36 @@ +/** + * 从模型 ID 中提取基础名称。 + * 例如: + * - 'deepseek/deepseek-r1' => 'deepseek-r1' + * - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1' + * @param {string} id 模型 ID + * @param {string} [delimiter='/'] 分隔符,默认为 '/' + * @returns {string} 基础名称 + */ +export const getBaseModelName = (id: string, delimiter: string = '/'): string => { + const parts = id.split(delimiter) + return parts[parts.length - 1] +} + +/** + * 从模型 ID 中提取基础名称并转换为小写。 + * 例如: + * - 'deepseek/DeepSeek-R1' => 'deepseek-r1' + * - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1' + * @param {string} id 模型 ID + * @param {string} [delimiter='/'] 分隔符,默认为 '/' + * @returns {string} 小写的基础名称 + */ +export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => { + const baseModelName = getBaseModelName(id, delimiter).toLowerCase() + // for openrouter + if (baseModelName.endsWith(':free')) { + return baseModelName.replace(':free', '') + } + + // for cherryin + if (baseModelName.endsWith('(free)')) { + return baseModelName.replace('(free)', '') + } + return baseModelName +} diff --git a/packages/shared/utils/url/index.ts b/packages/shared/utils/url/index.ts new file mode 100644 index 000000000..82a0f551f --- /dev/null +++ b/packages/shared/utils/url/index.ts @@ -0,0 +1,293 @@ +/** + * Shared API Utilities + * + * Common utilities for API URL formatting and validation. + * Used by both main process (API Server) and renderer. + */ + +import type { MinimalProvider } from '@shared/provider' +import { trim } from 'lodash' + +// Supported endpoints for routing +export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const +export const SUPPORTED_ENDPOINT_LIST = [ + 'chat/completions', + 'responses', + 'messages', + 'generateContent', + 'streamGenerateContent', + ...SUPPORTED_IMAGE_ENDPOINT_LIST +] as const + +/** + * Removes the trailing slash from a URL string if it exists. + */ +export function withoutTrailingSlash(url: T): T { + return url.replace(/\/$/, '') as T +} + +/** + * Removes the trailing '#' from a URL string if it exists. + * + * @template T - The string type to preserve type safety + * @param {T} url - The URL string to process + * @returns {T} The URL string without a trailing '#' + * + * @example + * ```ts + * withoutTrailingSharp('https://example.com#') // 'https://example.com' + * withoutTrailingSharp('https://example.com') // 'https://example.com' + * ``` + */ +export function withoutTrailingSharp(url: T): T { + return url.replace(/#$/, '') as T +} + +/** + * Checks if a URL string ends with a trailing '#' character. + * + * @template T - The string type to preserve type safety + * @param {T} url - The URL string to check + * @returns {boolean} True if the URL ends with '#', false otherwise + * + * @example + * ```ts + * isWithTrailingSharp('https://example.com#') // true + * isWithTrailingSharp('https://example.com') // false + * ``` + */ +export function isWithTrailingSharp(url: T): boolean { + return url.endsWith('#') +} + +/** + * Matches a version segment in a path that starts with `/v` and optionally + * continues with `alpha` or `beta`. The segment may be followed by `/` or the end + * of the string (useful for cases like `/v3alpha/resources`). + */ +const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)' + +/** + * Matches an API version at the end of a URL (with optional trailing slash). + * Used to detect and extract versions only from the trailing position. + */ +const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i + +/** + * 判断 host 的 path 中是否包含形如版本的字符串(例如 /v1、/v2beta 等), + * + * @param host - 要检查的 host 或 path 字符串 + * @returns 如果 path 中包含版本字符串则返回 true,否则 false + */ +export function hasAPIVersion(host?: string): boolean { + if (!host) return false + + const regex = new RegExp(VERSION_REGEX_PATTERN, 'i') + + try { + const url = new URL(host) + return regex.test(url.pathname) + } catch { + // 若无法作为完整 URL 解析,则当作路径直接检测 + return regex.test(host) + } +} + +/** + * 格式化 Azure OpenAI 的 API 主机地址。 + */ +export function formatAzureOpenAIApiHost(host: string): string { + const normalizedHost = withoutTrailingSlash(host) + ?.replace(/\/v1$/, '') + .replace(/\/openai$/, '') + // NOTE: AISDK会添加上`v1` + return formatApiHost(normalizedHost + '/openai', false) +} + +export function formatVertexApiHost( + provider: MinimalProvider, + project: string = 'test-project', + location: string = 'us-central1' +): string { + const { apiHost } = provider + const trimmedHost = withoutTrailingSlash(trim(apiHost)) + if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) { + const host = + location === 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com` + return `${formatApiHost(host)}/projects/${project}/locations/${location}` + } + return formatApiHost(trimmedHost) +} + +/** + * 格式化 Ollama 的 API 主机地址。 + */ +export function formatOllamaApiHost(host: string): string { + const normalizedHost = withoutTrailingSlash(host) + ?.replace(/\/v1$/, '') + ?.replace(/\/api$/, '') + ?.replace(/\/chat$/, '') + return formatApiHost(normalizedHost + '/api', false) +} + +/** + * Formats an API host URL by normalizing it and optionally appending an API version. + * + * @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed. + * @param supportApiVersion - Whether the API version is supported. Defaults to `true`. + * @param apiVersion - The API version to append if needed. Defaults to `'v1'`. + * + * @returns The formatted API host URL. If the host is empty after normalization, returns an empty string. + * If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host with trailing '#' removed. + * Otherwise, returns the host with the API version appended. + * + * @example + * formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1' + * formatApiHost('https://api.example.com#') // Returns 'https://api.example.com' + * formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2' + */ +export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string { + const normalizedHost = withoutTrailingSlash(trim(host)) + if (!normalizedHost) { + return '' + } + + const shouldAppendApiVersion = !(normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) + + if (shouldAppendApiVersion) { + return `${normalizedHost}/${apiVersion}` + } else { + return withoutTrailingSharp(normalizedHost) + } +} + +/** + * Converts an API host URL into separate base URL and endpoint components. + * + * This function extracts endpoint information from a composite API host string. + * If the host ends with '#', it attempts to match the preceding part against the supported endpoint list. + * + * @param apiHost - The API host string to parse + * @returns An object containing: + * - `baseURL`: The base URL without the endpoint suffix + * - `endpoint`: The matched endpoint identifier, or empty string if no match found + * + * @example + * routeToEndpoint('https://api.example.com/openai/chat/completions#') + * // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' } + * + * @example + * routeToEndpoint('https://api.example.com/v1') + * // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' } + */ +export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } { + const trimmedHost = (apiHost || '').trim() + if (!trimmedHost.endsWith('#')) { + return { baseURL: trimmedHost, endpoint: '' } + } + // Remove trailing # + const host = trimmedHost.slice(0, -1) + const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint)) + if (!endpointMatch) { + const baseURL = withoutTrailingSlash(host) + return { baseURL, endpoint: '' } + } + const baseSegment = host.slice(0, host.length - endpointMatch.length) + const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // Remove trailing colon (gemini special case) + return { baseURL, endpoint: endpointMatch } +} + +/** + * Gets the AI SDK compatible base URL from a provider's apiHost. + * + * AI SDK expects baseURL WITH version suffix (e.g., /v1). + * This function: + * 1. Handles '#' endpoint routing format + * 2. Ensures the URL has a version suffix (adds /v1 if missing) + * + * @param apiHost - The provider's apiHost value (may or may not have /v1) + * @param apiVersion - The API version to use if missing. Defaults to 'v1'. + * @returns The baseURL suitable for AI SDK (with version suffix) + * + * @example + * getAiSdkBaseUrl('https://api.openai.com') // 'https://api.openai.com/v1' + * getAiSdkBaseUrl('https://api.openai.com/v1') // 'https://api.openai.com/v1' + * getAiSdkBaseUrl('https://api.example.com/chat/completions#') // 'https://api.example.com' + */ +export function getAiSdkBaseUrl(apiHost: string, apiVersion: string = 'v1'): string { + // First handle '#' endpoint routing format + const { baseURL } = routeToEndpoint(apiHost) + + // If already has version, return as-is + if (hasAPIVersion(baseURL)) { + return withoutTrailingSlash(baseURL) + } + + // Add version suffix + return `${withoutTrailingSlash(baseURL)}/${apiVersion}` +} + +/** + * Validates an API host address. + * + * @param apiHost - The API host address to validate + * @returns true if valid URL with http/https protocol, false otherwise + */ +export function validateApiHost(apiHost: string): boolean { + if (!apiHost || !apiHost.trim()) { + return true // Allow empty + } + try { + const url = new URL(apiHost.trim()) + return url.protocol === 'http:' || url.protocol === 'https:' + } catch { + return false + } +} + +/** + * Extracts the trailing API version segment from a URL path. + * + * This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL. + * Only versions at the end of the path are extracted, not versions in the middle. + * The returned version string does not include leading or trailing slashes. + * + * @param {string} url - The URL string to parse. + * @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found. + * + * @example + * getTrailingApiVersion('https://api.example.com/v1') // 'v1' + * getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta' + * getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end) + * getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta' + * getTrailingApiVersion('https://api.example.com') // undefined + */ +export function getTrailingApiVersion(url: string): string | undefined { + const match = url.match(TRAILING_VERSION_REGEX) + + if (match) { + // Extract version without leading slash and trailing slash + return match[0].replace(/^\//, '').replace(/\/$/, '') + } + + return undefined +} + +/** + * Removes the trailing API version segment from a URL path. + * + * This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL. + * Only versions at the end of the path are removed, not versions in the middle. + * + * @param {string} url - The URL string to process. + * @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found. + * + * @example + * withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com' + * withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com' + * withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change) + * withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com' + */ +export function withoutTrailingApiVersion(url: string): string { + return url.replace(TRAILING_VERSION_REGEX, '') +} diff --git a/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts b/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts new file mode 100644 index 000000000..9ef19c0b9 --- /dev/null +++ b/src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts @@ -0,0 +1,637 @@ +/** + * AI SDK to Anthropic SSE Adapter + * + * Converts AI SDK's fullStream (TextStreamPart) events to Anthropic Messages API SSE format. + * This enables any AI provider supported by AI SDK to be exposed via Anthropic-compatible API. + * + * Anthropic SSE Event Flow: + * 1. message_start - Initial message with metadata + * 2. content_block_start - Begin a content block (text, tool_use, thinking) + * 3. content_block_delta - Incremental content updates + * 4. content_block_stop - End a content block + * 5. message_delta - Updates to overall message (stop_reason, usage) + * 6. message_stop - Stream complete + * + * @see https://docs.anthropic.com/en/api/messages-streaming + */ + +import type { + ContentBlock, + InputJSONDelta, + Message, + MessageDeltaUsage, + RawContentBlockDeltaEvent, + RawContentBlockStartEvent, + RawContentBlockStopEvent, + RawMessageDeltaEvent, + RawMessageStartEvent, + RawMessageStopEvent, + RawMessageStreamEvent, + StopReason, + TextBlock, + TextDelta, + ThinkingBlock, + ThinkingDelta, + ToolUseBlock, + Usage +} from '@anthropic-ai/sdk/resources/messages' +import { loggerService } from '@logger' +import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai' + +import { googleReasoningCache, openRouterReasoningCache } from '../services/reasoning-cache' + +const logger = loggerService.withContext('AiSdkToAnthropicSSE') + +interface ContentBlockState { + type: 'text' | 'tool_use' | 'thinking' + index: number + started: boolean + content: string + // For tool_use blocks + toolId?: string + toolName?: string + toolInput?: string +} + +interface AdapterState { + messageId: string + model: string + inputTokens: number + outputTokens: number + cacheInputTokens: number + currentBlockIndex: number + blocks: Map + textBlockIndex: number | null + // Track multiple thinking blocks by their reasoning ID + thinkingBlocks: Map // reasoningId -> blockIndex + currentThinkingId: string | null // Currently active thinking block ID + toolBlocks: Map // toolCallId -> blockIndex + stopReason: StopReason | null + hasEmittedMessageStart: boolean +} + +export type SSEEventCallback = (event: RawMessageStreamEvent) => void + +export interface AiSdkToAnthropicSSEOptions { + model: string + messageId?: string + inputTokens?: number + onEvent: SSEEventCallback +} + +/** + * Adapter that converts AI SDK fullStream events to Anthropic SSE events + */ +export class AiSdkToAnthropicSSE { + private state: AdapterState + private onEvent: SSEEventCallback + + constructor(options: AiSdkToAnthropicSSEOptions) { + this.onEvent = options.onEvent + this.state = { + messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`, + model: options.model, + inputTokens: options.inputTokens || 0, + outputTokens: 0, + cacheInputTokens: 0, + currentBlockIndex: 0, + blocks: new Map(), + textBlockIndex: null, + thinkingBlocks: new Map(), + currentThinkingId: null, + toolBlocks: new Map(), + stopReason: null, + hasEmittedMessageStart: false + } + } + + /** + * Process the AI SDK stream and emit Anthropic SSE events + */ + async processStream(fullStream: ReadableStream>): Promise { + const reader = fullStream.getReader() + + try { + // Emit message_start at the beginning + this.emitMessageStart() + + while (true) { + const { done, value } = await reader.read() + + if (done) { + break + } + + this.processChunk(value) + } + + // Ensure all blocks are closed and emit final events + this.finalize() + } catch (error) { + await reader.cancel() + throw error + } finally { + reader.releaseLock() + } + } + + /** + * Process a single AI SDK chunk and emit corresponding Anthropic events + */ + private processChunk(chunk: TextStreamPart): void { + logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) }) + switch (chunk.type) { + // === Text Events === + case 'text-start': + this.startTextBlock() + break + + case 'text-delta': + this.emitTextDelta(chunk.text || '') + break + + case 'text-end': + this.stopTextBlock() + break + + // === Reasoning/Thinking Events === + case 'reasoning-start': { + const reasoningId = chunk.id + this.startThinkingBlock(reasoningId) + break + } + + case 'reasoning-delta': { + const reasoningId = chunk.id + this.emitThinkingDelta(chunk.text || '', reasoningId) + break + } + + case 'reasoning-end': { + const reasoningId = chunk.id + this.stopThinkingBlock(reasoningId) + break + } + + // === Tool Events === + case 'tool-call': + if (googleReasoningCache && chunk.providerMetadata?.google?.thoughtSignature) { + googleReasoningCache.set( + `google-${chunk.toolName}`, + chunk.providerMetadata?.google?.thoughtSignature as string + ) + } + if ( + openRouterReasoningCache && + chunk.providerMetadata?.openrouter?.reasoning_details && + Array.isArray(chunk.providerMetadata.openrouter.reasoning_details) + ) { + openRouterReasoningCache.set( + `openrouter-${chunk.toolCallId}`, + JSON.parse(JSON.stringify(chunk.providerMetadata.openrouter.reasoning_details)) + ) + } + this.handleToolCall({ + type: 'tool-call', + toolCallId: chunk.toolCallId, + toolName: chunk.toolName, + args: chunk.input + }) + break + + case 'tool-result': + // this.handleToolResult({ + // type: 'tool-result', + // toolCallId: chunk.toolCallId, + // toolName: chunk.toolName, + // args: chunk.input, + // result: chunk.output + // }) + break + + case 'finish-step': + if (chunk.finishReason === 'tool-calls') { + this.state.stopReason = 'tool_use' + } + break + + case 'finish': + this.handleFinish(chunk) + break + + case 'error': + throw chunk.error + + // Ignore other event types + default: + break + } + } + + private emitMessageStart(): void { + if (this.state.hasEmittedMessageStart) return + + this.state.hasEmittedMessageStart = true + + const usage: Usage = { + input_tokens: this.state.inputTokens, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + server_tool_use: null + } + + const message: Message = { + id: this.state.messageId, + type: 'message', + role: 'assistant', + content: [], + model: this.state.model, + stop_reason: null, + stop_sequence: null, + usage + } + + const event: RawMessageStartEvent = { + type: 'message_start', + message + } + + this.onEvent(event) + } + + private startTextBlock(): void { + // If we already have a text block, don't create another + if (this.state.textBlockIndex !== null) return + + const index = this.state.currentBlockIndex++ + this.state.textBlockIndex = index + this.state.blocks.set(index, { + type: 'text', + index, + started: true, + content: '' + }) + + const contentBlock: TextBlock = { + type: 'text', + text: '', + citations: null + } + + const event: RawContentBlockStartEvent = { + type: 'content_block_start', + index, + content_block: contentBlock + } + + this.onEvent(event) + } + + private emitTextDelta(text: string): void { + if (!text) return + + // Auto-start text block if not started + if (this.state.textBlockIndex === null) { + this.startTextBlock() + } + + const index = this.state.textBlockIndex! + const block = this.state.blocks.get(index) + if (block) { + block.content += text + } + + const delta: TextDelta = { + type: 'text_delta', + text + } + + const event: RawContentBlockDeltaEvent = { + type: 'content_block_delta', + index, + delta + } + + this.onEvent(event) + } + + private stopTextBlock(): void { + if (this.state.textBlockIndex === null) return + + const index = this.state.textBlockIndex + + const event: RawContentBlockStopEvent = { + type: 'content_block_stop', + index + } + + this.onEvent(event) + this.state.textBlockIndex = null + } + + private startThinkingBlock(reasoningId: string): void { + // Check if this thinking block already exists + if (this.state.thinkingBlocks.has(reasoningId)) return + + const index = this.state.currentBlockIndex++ + this.state.thinkingBlocks.set(reasoningId, index) + this.state.currentThinkingId = reasoningId + this.state.blocks.set(index, { + type: 'thinking', + index, + started: true, + content: '' + }) + + const contentBlock: ThinkingBlock = { + type: 'thinking', + thinking: '', + signature: '' + } + + const event: RawContentBlockStartEvent = { + type: 'content_block_start', + index, + content_block: contentBlock + } + + this.onEvent(event) + } + + private emitThinkingDelta(text: string, reasoningId?: string): void { + if (!text) return + + // Determine which thinking block to use + const targetId = reasoningId || this.state.currentThinkingId + if (!targetId) { + // Auto-start thinking block if not started + const newId = `reasoning_${Date.now()}` + this.startThinkingBlock(newId) + return this.emitThinkingDelta(text, newId) + } + + const index = this.state.thinkingBlocks.get(targetId) + if (index === undefined) { + // If the block doesn't exist, create it + this.startThinkingBlock(targetId) + return this.emitThinkingDelta(text, targetId) + } + + const block = this.state.blocks.get(index) + if (block) { + block.content += text + } + + const delta: ThinkingDelta = { + type: 'thinking_delta', + thinking: text + } + + const event: RawContentBlockDeltaEvent = { + type: 'content_block_delta', + index, + delta + } + + this.onEvent(event) + } + + private stopThinkingBlock(reasoningId?: string): void { + const targetId = reasoningId || this.state.currentThinkingId + if (!targetId) return + + const index = this.state.thinkingBlocks.get(targetId) + if (index === undefined) return + + const event: RawContentBlockStopEvent = { + type: 'content_block_stop', + index + } + + this.onEvent(event) + this.state.thinkingBlocks.delete(targetId) + + // Update currentThinkingId if we just closed the current one + if (this.state.currentThinkingId === targetId) { + // Set to the most recent remaining thinking block, or null if none + const remaining = Array.from(this.state.thinkingBlocks.keys()) + this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null + } + } + + private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void { + const { toolCallId, toolName, args } = chunk + + // Check if we already have this tool call + if (this.state.toolBlocks.has(toolCallId)) { + return + } + + const index = this.state.currentBlockIndex++ + this.state.toolBlocks.set(toolCallId, index) + + const inputJson = JSON.stringify(args) + + this.state.blocks.set(index, { + type: 'tool_use', + index, + started: true, + content: inputJson, + toolId: toolCallId, + toolName, + toolInput: inputJson + }) + + // Emit content_block_start for tool_use + const contentBlock: ToolUseBlock = { + type: 'tool_use', + id: toolCallId, + name: toolName, + input: {} + } + + const startEvent: RawContentBlockStartEvent = { + type: 'content_block_start', + index, + content_block: contentBlock + } + + this.onEvent(startEvent) + + // Emit the full input as a delta (Anthropic streams JSON incrementally) + const delta: InputJSONDelta = { + type: 'input_json_delta', + partial_json: inputJson + } + + const deltaEvent: RawContentBlockDeltaEvent = { + type: 'content_block_delta', + index, + delta + } + + this.onEvent(deltaEvent) + + // Emit content_block_stop + const stopEvent: RawContentBlockStopEvent = { + type: 'content_block_stop', + index + } + + this.onEvent(stopEvent) + + // Mark that we have tool use + this.state.stopReason = 'tool_use' + } + + private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void { + // Update usage + if (chunk.totalUsage) { + this.state.inputTokens = chunk.totalUsage.inputTokens || 0 + this.state.outputTokens = chunk.totalUsage.outputTokens || 0 + this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0 + } + + // Determine finish reason + if (!this.state.stopReason) { + switch (chunk.finishReason) { + case 'stop': + this.state.stopReason = 'end_turn' + break + case 'length': + this.state.stopReason = 'max_tokens' + break + case 'tool-calls': + this.state.stopReason = 'tool_use' + break + case 'content-filter': + this.state.stopReason = 'refusal' + break + default: + this.state.stopReason = 'end_turn' + } + } + } + + private finalize(): void { + // Close any open blocks + if (this.state.textBlockIndex !== null) { + this.stopTextBlock() + } + // Close all open thinking blocks + for (const reasoningId of this.state.thinkingBlocks.keys()) { + this.stopThinkingBlock(reasoningId) + } + + // Emit message_delta with final stop reason and usage + const usage: MessageDeltaUsage = { + output_tokens: this.state.outputTokens, + input_tokens: this.state.inputTokens, + cache_creation_input_tokens: this.state.cacheInputTokens, + cache_read_input_tokens: null, + server_tool_use: null + } + + const messageDeltaEvent: RawMessageDeltaEvent = { + type: 'message_delta', + delta: { + stop_reason: this.state.stopReason || 'end_turn', + stop_sequence: null + }, + usage + } + + this.onEvent(messageDeltaEvent) + + // Emit message_stop + const messageStopEvent: RawMessageStopEvent = { + type: 'message_stop' + } + + this.onEvent(messageStopEvent) + } + + /** + * Set input token count (typically from prompt) + */ + setInputTokens(count: number): void { + this.state.inputTokens = count + } + + /** + * Get the current message ID + */ + getMessageId(): string { + return this.state.messageId + } + + /** + * Build a complete Message object for non-streaming responses + */ + buildNonStreamingResponse(): Message { + const content: ContentBlock[] = [] + + // Collect all content blocks in order + const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index) + + for (const block of sortedBlocks) { + switch (block.type) { + case 'text': + content.push({ + type: 'text', + text: block.content, + citations: null + } as TextBlock) + break + case 'thinking': + content.push({ + type: 'thinking', + thinking: block.content + } as ThinkingBlock) + break + case 'tool_use': + content.push({ + type: 'tool_use', + id: block.toolId!, + name: block.toolName!, + input: JSON.parse(block.toolInput || '{}') + } as ToolUseBlock) + break + } + } + + return { + id: this.state.messageId, + type: 'message', + role: 'assistant', + content, + model: this.state.model, + stop_reason: this.state.stopReason || 'end_turn', + stop_sequence: null, + usage: { + input_tokens: this.state.inputTokens, + output_tokens: this.state.outputTokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + server_tool_use: null + } + } + } +} + +/** + * Format an Anthropic SSE event for HTTP streaming + */ +export function formatSSEEvent(event: RawMessageStreamEvent): string { + return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n` +} + +/** + * Create a done marker for SSE stream + */ +export function formatSSEDone(): string { + return 'data: [DONE]\n\n' +} + +export default AiSdkToAnthropicSSE diff --git a/src/main/apiServer/adapters/__tests__/AiSdkToAnthropicSSE.test.ts b/src/main/apiServer/adapters/__tests__/AiSdkToAnthropicSSE.test.ts new file mode 100644 index 000000000..bbeed2563 --- /dev/null +++ b/src/main/apiServer/adapters/__tests__/AiSdkToAnthropicSSE.test.ts @@ -0,0 +1,536 @@ +import type { RawMessageStreamEvent } from '@anthropic-ai/sdk/resources/messages' +import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai' +import { describe, expect, it, vi } from 'vitest' + +import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '../AiSdkToAnthropicSSE' + +const createTextDelta = (text: string, id = 'text_0'): TextStreamPart => ({ + type: 'text-delta', + id, + text +}) + +const createTextStart = (id = 'text_0'): TextStreamPart => ({ + type: 'text-start', + id +}) + +const createTextEnd = (id = 'text_0'): TextStreamPart => ({ + type: 'text-end', + id +}) + +const createFinish = ( + finishReason: FinishReason | undefined = 'stop', + totalUsage?: Partial +): TextStreamPart => { + const defaultUsage: LanguageModelUsage = { + inputTokens: 0, + outputTokens: 0, + totalTokens: 0 + } + const event: TextStreamPart = { + type: 'finish', + finishReason: finishReason || 'stop', + totalUsage: { ...defaultUsage, ...totalUsage } + } + return event +} + +// Helper to create stream +function createMockStream(events: readonly TextStreamPart[]) { + return new ReadableStream>({ + start(controller) { + for (const event of events) { + controller.enqueue(event) + } + controller.close() + } + }) +} + +describe('AiSdkToAnthropicSSE', () => { + describe('Text Processing', () => { + it('should emit message_start and process text-delta events', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + // Create a mock stream with text events + const stream = createMockStream([createTextDelta('Hello'), createTextDelta(' world'), createFinish('stop')]) + + await adapter.processStream(stream) + + // Verify message_start + expect(events[0]).toMatchObject({ + type: 'message_start', + message: { + role: 'assistant', + model: 'test:model' + } + }) + + // Verify content_block_start for text + expect(events[1]).toMatchObject({ + type: 'content_block_start', + content_block: { type: 'text' } + }) + + // Verify text deltas + expect(events[2]).toMatchObject({ + type: 'content_block_delta', + delta: { type: 'text_delta', text: 'Hello' } + }) + expect(events[3]).toMatchObject({ + type: 'content_block_delta', + delta: { type: 'text_delta', text: ' world' } + }) + + // Verify content_block_stop + expect(events[4]).toMatchObject({ + type: 'content_block_stop' + }) + + // Verify message_delta with stop_reason + expect(events[5]).toMatchObject({ + type: 'message_delta', + delta: { stop_reason: 'end_turn' } + }) + + // Verify message_stop + expect(events[6]).toMatchObject({ + type: 'message_stop' + }) + }) + + it('should handle text-start and text-end events', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + const stream = createMockStream([ + createTextStart(), + createTextDelta('Test'), + createTextEnd(), + createFinish('stop') + ]) + + await adapter.processStream(stream) + + // Should have content_block_start, delta, and content_block_stop + const blockEvents = events.filter((e) => e.type.startsWith('content_block')) + expect(blockEvents.length).toBeGreaterThanOrEqual(3) + }) + + it('should auto-start text block if not explicitly started', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + const stream = createMockStream([createTextDelta('Auto-started'), createFinish('stop')]) + + await adapter.processStream(stream) + + // Should automatically emit content_block_start + expect(events.some((e) => e.type === 'content_block_start')).toBe(true) + }) + }) + + describe('Tool Call Processing', () => { + it('should emit tool_use block for tool-call events', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + const stream = createMockStream([ + { + type: 'tool-call', + toolCallId: 'call_123', + toolName: 'get_weather', + input: { location: 'SF' } + }, + createFinish('tool-calls') + ]) + + await adapter.processStream(stream) + + // Find tool_use block events + const blockStart = events.find((e) => { + if (e.type === 'content_block_start') { + return e.content_block.type === 'tool_use' + } + return false + }) + expect(blockStart).toBeDefined() + if (blockStart && blockStart.type === 'content_block_start') { + expect(blockStart.content_block).toMatchObject({ + type: 'tool_use', + id: 'call_123', + name: 'get_weather' + }) + } + + // Should emit input_json_delta + const delta = events.find((e) => { + if (e.type === 'content_block_delta') { + return e.delta.type === 'input_json_delta' + } + return false + }) + expect(delta).toBeDefined() + + // Should have stop_reason as tool_use + const messageDelta = events.find((e) => e.type === 'message_delta') + if (messageDelta && messageDelta.type === 'message_delta') { + expect(messageDelta.delta.stop_reason).toBe('tool_use') + } + }) + + it('should not create duplicate tool blocks', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + const toolCallEvent: TextStreamPart = { + type: 'tool-call', + toolCallId: 'call_123', + toolName: 'test_tool', + input: {} + } + const stream = createMockStream([toolCallEvent, toolCallEvent, createFinish()]) + + await adapter.processStream(stream) + + // Should only have one tool_use block + const toolBlocks = events.filter((e) => { + if (e.type === 'content_block_start') { + return e.content_block.type === 'tool_use' + } + return false + }) + expect(toolBlocks.length).toBe(1) + }) + }) + + describe('Reasoning/Thinking Processing', () => { + it('should emit thinking block for reasoning events', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + const stream = createMockStream([ + { type: 'reasoning-start', id: 'reason_1' }, + { type: 'reasoning-delta', id: 'reason_1', text: 'Thinking...' }, + { type: 'reasoning-end', id: 'reason_1' }, + createFinish() + ]) + + await adapter.processStream(stream) + + // Find thinking block events + const blockStart = events.find((e) => { + if (e.type === 'content_block_start') { + return e.content_block.type === 'thinking' + } + return false + }) + expect(blockStart).toBeDefined() + + // Should emit thinking_delta + const delta = events.find((e) => { + if (e.type === 'content_block_delta') { + return e.delta.type === 'thinking_delta' + } + return false + }) + expect(delta).toBeDefined() + if (delta && delta.type === 'content_block_delta' && delta.delta.type === 'thinking_delta') { + expect(delta.delta.thinking).toBe('Thinking...') + } + }) + + it('should handle multiple thinking blocks', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + const stream = createMockStream([ + { type: 'reasoning-start', id: 'reason_1' }, + { type: 'reasoning-delta', id: 'reason_1', text: 'First thought' }, + { type: 'reasoning-start', id: 'reason_2' }, + { type: 'reasoning-delta', id: 'reason_2', text: 'Second thought' }, + { type: 'reasoning-end', id: 'reason_1' }, + { type: 'reasoning-end', id: 'reason_2' }, + createFinish() + ]) + + await adapter.processStream(stream) + + // Should have two thinking blocks + const thinkingBlocks = events.filter((e) => { + if (e.type === 'content_block_start') { + return e.content_block.type === 'thinking' + } + return false + }) + expect(thinkingBlocks.length).toBe(2) + }) + }) + + describe('Finish Reasons', () => { + it('should map finish reasons correctly', async () => { + const testCases: Array<{ + aiSdkReason: FinishReason + expectedReason: string + }> = [ + { aiSdkReason: 'stop', expectedReason: 'end_turn' }, + { aiSdkReason: 'length', expectedReason: 'max_tokens' }, + { aiSdkReason: 'tool-calls', expectedReason: 'tool_use' }, + { aiSdkReason: 'content-filter', expectedReason: 'refusal' } + ] + + for (const { aiSdkReason, expectedReason } of testCases) { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + const stream = createMockStream([createFinish(aiSdkReason)]) + + await adapter.processStream(stream) + + const messageDelta = events.find((e) => e.type === 'message_delta') + if (messageDelta && messageDelta.type === 'message_delta') { + expect(messageDelta.delta.stop_reason).toBe(expectedReason) + } + } + }) + }) + + describe('Usage Tracking', () => { + it('should track token usage', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + inputTokens: 100, + onEvent: (event) => events.push(event) + }) + + const stream = createMockStream([ + createTextDelta('Hello'), + createFinish('stop', { + inputTokens: 100, + outputTokens: 50, + cachedInputTokens: 20 + }) + ]) + + await adapter.processStream(stream) + + const messageDelta = events.find((e) => e.type === 'message_delta') + if (messageDelta && messageDelta.type === 'message_delta') { + expect(messageDelta.usage).toMatchObject({ + input_tokens: 100, + output_tokens: 50, + cache_creation_input_tokens: 20 + }) + } + }) + }) + + describe('Non-Streaming Response', () => { + it('should build complete message for non-streaming', async () => { + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: vi.fn() + }) + + const stream = createMockStream([ + createTextDelta('Hello world'), + { + type: 'tool-call', + toolCallId: 'call_1', + toolName: 'test', + input: { arg: 'value' } + }, + createFinish('tool-calls', { inputTokens: 10, outputTokens: 20 }) + ]) + + await adapter.processStream(stream) + + const response = adapter.buildNonStreamingResponse() + + expect(response).toMatchObject({ + type: 'message', + role: 'assistant', + model: 'test:model', + stop_reason: 'tool_use' + }) + + expect(response.content).toHaveLength(2) + expect(response.content[0]).toMatchObject({ + type: 'text', + text: 'Hello world' + }) + expect(response.content[1]).toMatchObject({ + type: 'tool_use', + id: 'call_1', + name: 'test', + input: { arg: 'value' } + }) + + expect(response.usage).toMatchObject({ + input_tokens: 10, + output_tokens: 20 + }) + }) + }) + + describe('Error Handling', () => { + it('should throw on error events', async () => { + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: vi.fn() + }) + + const testError = new Error('Test error') + const stream = createMockStream([{ type: 'error', error: testError }]) + + await expect(adapter.processStream(stream)).rejects.toThrow('Test error') + }) + }) + + describe('Edge Cases', () => { + it('should handle empty stream', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + const stream = new ReadableStream>({ + start(controller) { + controller.close() + } + }) + + await adapter.processStream(stream) + + // Should still emit message_start, message_delta, and message_stop + expect(events.some((e) => e.type === 'message_start')).toBe(true) + expect(events.some((e) => e.type === 'message_delta')).toBe(true) + expect(events.some((e) => e.type === 'message_stop')).toBe(true) + }) + + it('should handle empty text deltas', async () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + const stream = createMockStream([createTextDelta(''), createTextDelta(''), createFinish()]) + + await adapter.processStream(stream) + + // Should not emit deltas for empty text + const deltas = events.filter((e) => e.type === 'content_block_delta') + expect(deltas.length).toBe(0) + }) + }) + + describe('Utility Functions', () => { + it('should format SSE events correctly', () => { + const event: RawMessageStreamEvent = { + type: 'message_start', + message: { + id: 'msg_123', + type: 'message', + role: 'assistant', + content: [], + model: 'test', + stop_reason: null, + stop_sequence: null, + usage: { + input_tokens: 10, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + server_tool_use: null + } + } + } + + const formatted = formatSSEEvent(event) + + expect(formatted).toContain('event: message_start') + expect(formatted).toContain('data: ') + expect(formatted).toContain('"type":"message_start"') + expect(formatted.endsWith('\n\n')).toBe(true) + }) + + it('should format SSE done marker correctly', () => { + const done = formatSSEDone() + + expect(done).toBe('data: [DONE]\n\n') + }) + }) + + describe('Message ID', () => { + it('should use provided message ID', () => { + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + messageId: 'custom_msg_123', + onEvent: vi.fn() + }) + + expect(adapter.getMessageId()).toBe('custom_msg_123') + }) + + it('should generate message ID if not provided', () => { + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: vi.fn() + }) + + const messageId = adapter.getMessageId() + expect(messageId).toMatch(/^msg_/) + }) + }) + + describe('Input Tokens', () => { + it('should allow setting input tokens', () => { + const events: RawMessageStreamEvent[] = [] + const adapter = new AiSdkToAnthropicSSE({ + model: 'test:model', + onEvent: (event) => events.push(event) + }) + + adapter.setInputTokens(500) + + const stream = createMockStream([createFinish()]) + + return adapter.processStream(stream).then(() => { + const messageStart = events.find((e) => e.type === 'message_start') + if (messageStart && messageStart.type === 'message_start') { + expect(messageStart.message.usage.input_tokens).toBe(500) + } + }) + }) + }) +}) diff --git a/src/main/apiServer/adapters/index.ts b/src/main/apiServer/adapters/index.ts new file mode 100644 index 000000000..a19db9594 --- /dev/null +++ b/src/main/apiServer/adapters/index.ts @@ -0,0 +1,13 @@ +/** + * Shared Adapters + * + * This module exports adapters for converting between different AI API formats. + */ + +export { + AiSdkToAnthropicSSE, + type AiSdkToAnthropicSSEOptions, + formatSSEDone, + formatSSEEvent, + type SSEEventCallback +} from './AiSdkToAnthropicSSE' diff --git a/src/main/apiServer/adapters/openrouter.ts b/src/main/apiServer/adapters/openrouter.ts new file mode 100644 index 000000000..3b6319178 --- /dev/null +++ b/src/main/apiServer/adapters/openrouter.ts @@ -0,0 +1,95 @@ +import * as z from 'zod/v4' + +enum ReasoningFormat { + Unknown = 'unknown', + OpenAIResponsesV1 = 'openai-responses-v1', + XAIResponsesV1 = 'xai-responses-v1', + AnthropicClaudeV1 = 'anthropic-claude-v1', + GoogleGeminiV1 = 'google-gemini-v1' +} + +// Anthropic Claude was the first reasoning that we're +// passing back and forth +export const DEFAULT_REASONING_FORMAT = ReasoningFormat.AnthropicClaudeV1 + +function isDefinedOrNotNull(value: T | null | undefined): value is T { + return value !== null && value !== undefined +} + +export enum ReasoningDetailType { + Summary = 'reasoning.summary', + Encrypted = 'reasoning.encrypted', + Text = 'reasoning.text' +} + +export const CommonReasoningDetailSchema = z + .object({ + id: z.string().nullish(), + format: z.enum(ReasoningFormat).nullish(), + index: z.number().optional() + }) + .loose() + +export const ReasoningDetailSummarySchema = z + .object({ + type: z.literal(ReasoningDetailType.Summary), + summary: z.string() + }) + .extend(CommonReasoningDetailSchema.shape) +export type ReasoningDetailSummary = z.infer + +export const ReasoningDetailEncryptedSchema = z + .object({ + type: z.literal(ReasoningDetailType.Encrypted), + data: z.string() + }) + .extend(CommonReasoningDetailSchema.shape) + +export type ReasoningDetailEncrypted = z.infer + +export const ReasoningDetailTextSchema = z + .object({ + type: z.literal(ReasoningDetailType.Text), + text: z.string().nullish(), + signature: z.string().nullish() + }) + .extend(CommonReasoningDetailSchema.shape) + +export type ReasoningDetailText = z.infer + +export const ReasoningDetailUnionSchema = z.union([ + ReasoningDetailSummarySchema, + ReasoningDetailEncryptedSchema, + ReasoningDetailTextSchema +]) + +export type ReasoningDetailUnion = z.infer + +const ReasoningDetailsWithUnknownSchema = z.union([ReasoningDetailUnionSchema, z.unknown().transform(() => null)]) + +export const ReasoningDetailArraySchema = z + .array(ReasoningDetailsWithUnknownSchema) + .transform((d) => d.filter((d): d is ReasoningDetailUnion => !!d)) + +export const OutputUnionToReasoningDetailsSchema = z.union([ + z + .object({ + delta: z.object({ + reasoning_details: z.array(ReasoningDetailsWithUnknownSchema) + }) + }) + .transform((data) => data.delta.reasoning_details.filter(isDefinedOrNotNull)), + z + .object({ + message: z.object({ + reasoning_details: z.array(ReasoningDetailsWithUnknownSchema) + }) + }) + .transform((data) => data.message.reasoning_details.filter(isDefinedOrNotNull)), + z + .object({ + text: z.string(), + reasoning_details: z.array(ReasoningDetailsWithUnknownSchema) + }) + .transform((data) => data.reasoning_details.filter(isDefinedOrNotNull)) +]) diff --git a/src/main/apiServer/routes/__tests__/messages.test.ts b/src/main/apiServer/routes/__tests__/messages.test.ts new file mode 100644 index 000000000..b07b82e31 --- /dev/null +++ b/src/main/apiServer/routes/__tests__/messages.test.ts @@ -0,0 +1,393 @@ +import { describe, expect, it } from 'vitest' + +import { estimateTokenCount } from '../messages' + +describe('estimateTokenCount', () => { + describe('Text Content', () => { + it('should estimate tokens for simple string content', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: 'Hello, world!' + } + ] + } + const tokens = estimateTokenCount(input) + // Should include text tokens + role overhead (3) + expect(tokens).toBeGreaterThan(3) + expect(tokens).toBeLessThan(20) + }) + + it('should estimate tokens for multiple messages', () => { + const input = { + messages: [ + { role: 'user' as const, content: 'First message' }, + { role: 'assistant' as const, content: 'Second message' }, + { role: 'user' as const, content: 'Third message' } + ] + } + const tokens = estimateTokenCount(input) + // Should include text tokens + role overhead (3 per message = 9) + expect(tokens).toBeGreaterThan(9) + }) + + it('should estimate tokens for text content blocks', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: [ + { type: 'text' as const, text: 'Hello' }, + { type: 'text' as const, text: 'World' } + ] + } + ] + } + const tokens = estimateTokenCount(input) + expect(tokens).toBeGreaterThan(3) + }) + + it('should handle empty messages array', () => { + const input = { + messages: [] + } + const tokens = estimateTokenCount(input) + expect(tokens).toBe(0) + }) + + it('should handle messages with empty content', () => { + const input = { + messages: [{ role: 'user' as const, content: '' }] + } + const tokens = estimateTokenCount(input) + // Should only have role overhead (3) + expect(tokens).toBe(3) + }) + }) + + describe('System Messages', () => { + it('should estimate tokens for string system message', () => { + const input = { + messages: [{ role: 'user' as const, content: 'Hello' }], + system: 'You are a helpful assistant.' + } + const tokens = estimateTokenCount(input) + // Should include system tokens + message tokens + role overhead + expect(tokens).toBeGreaterThan(3) + }) + + it('should estimate tokens for system content blocks', () => { + const input = { + messages: [{ role: 'user' as const, content: 'Hello' }], + system: [ + { type: 'text' as const, text: 'System instruction 1' }, + { type: 'text' as const, text: 'System instruction 2' } + ] + } + const tokens = estimateTokenCount(input) + expect(tokens).toBeGreaterThan(3) + }) + }) + + describe('Image Content', () => { + it('should estimate tokens for base64 images', () => { + // Create a fake base64 string (400 characters = ~300 bytes when decoded) + const fakeBase64 = 'A'.repeat(400) + const input = { + messages: [ + { + role: 'user' as const, + content: [ + { + type: 'image' as const, + source: { + type: 'base64' as const, + media_type: 'image/png' as const, + data: fakeBase64 + } + } + ] + } + ] + } + const tokens = estimateTokenCount(input) + // Should estimate based on data size: 400 * 0.75 / 100 = 3 tokens + role overhead (3) + expect(tokens).toBeGreaterThan(3) + expect(tokens).toBeLessThan(10) + }) + + it('should estimate tokens for URL images', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: [ + { + type: 'image' as const, + source: { + type: 'url' as const, + url: 'https://example.com/image.png' + } + } + ] + } + ] + } + const tokens = estimateTokenCount(input) + // Should use default estimate: 1000 + role overhead (3) + expect(tokens).toBe(1003) + }) + + it('should estimate tokens for mixed text and image content', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: [ + { type: 'text' as const, text: 'What is in this image?' }, + { + type: 'image' as const, + source: { + type: 'url' as const, + url: 'https://example.com/image.png' + } + } + ] + } + ] + } + const tokens = estimateTokenCount(input) + // Should include text tokens + 1000 (image) + role overhead (3) + expect(tokens).toBeGreaterThan(1003) + }) + }) + + describe('Tool Content', () => { + it('should estimate tokens for tool_use blocks', () => { + const input = { + messages: [ + { + role: 'assistant' as const, + content: [ + { + type: 'tool_use' as const, + id: 'tool_123', + name: 'get_weather', + input: { location: 'San Francisco', unit: 'celsius' } + } + ] + } + ] + } + const tokens = estimateTokenCount(input) + // Should include: tool name tokens + input JSON tokens + 10 (overhead) + 3 (role) + expect(tokens).toBeGreaterThan(13) + }) + + it('should estimate tokens for tool_result blocks with string content', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: [ + { + type: 'tool_result' as const, + tool_use_id: 'tool_123', + content: 'The weather in San Francisco is 18°C and sunny.' + } + ] + } + ] + } + const tokens = estimateTokenCount(input) + // Should include: content tokens + 10 (overhead) + 3 (role) + expect(tokens).toBeGreaterThan(13) + }) + + it('should estimate tokens for tool_result blocks with array content', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: [ + { + type: 'tool_result' as const, + tool_use_id: 'tool_123', + content: [ + { type: 'text' as const, text: 'Result 1' }, + { type: 'text' as const, text: 'Result 2' } + ] + } + ] + } + ] + } + const tokens = estimateTokenCount(input) + // Should include: text tokens + 10 (overhead) + 3 (role) + expect(tokens).toBeGreaterThan(13) + }) + + it('should handle tool_use without input', () => { + const input = { + messages: [ + { + role: 'assistant' as const, + content: [ + { + type: 'tool_use' as const, + id: 'tool_123', + name: 'no_input_tool', + input: {} + } + ] + } + ] + } + const tokens = estimateTokenCount(input) + // Should include: tool name tokens + 10 (overhead) + 3 (role) + expect(tokens).toBeGreaterThan(13) + }) + }) + + describe('Complex Scenarios', () => { + it('should estimate tokens for multi-turn conversation with various content types', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: [ + { type: 'text' as const, text: 'Analyze this image' }, + { + type: 'image' as const, + source: { + type: 'url' as const, + url: 'https://example.com/chart.png' + } + } + ] + }, + { + role: 'assistant' as const, + content: [ + { + type: 'tool_use' as const, + id: 'tool_1', + name: 'analyze_image', + input: { url: 'https://example.com/chart.png' } + } + ] + }, + { + role: 'user' as const, + content: [ + { + type: 'tool_result' as const, + tool_use_id: 'tool_1', + content: 'The chart shows sales data for Q4 2024.' + } + ] + }, + { + role: 'assistant' as const, + content: 'Based on the analysis, the sales trend is positive.' + } + ], + system: 'You are a data analyst assistant.' + } + const tokens = estimateTokenCount(input) + // Should include: + // - System message tokens + // - Message 1: text + image (1000) + 3 + // - Message 2: tool_use + 10 + 3 + // - Message 3: tool_result + 10 + 3 + // - Message 4: text + 3 + expect(tokens).toBeGreaterThan(1032) // At least 1000 (image) + 32 (overhead) + }) + + it('should handle very long text content', () => { + const longText = 'word '.repeat(1000) // ~5000 characters + const input = { + messages: [{ role: 'user' as const, content: longText }] + } + const tokens = estimateTokenCount(input) + // Should estimate based on text length using tokenx + expect(tokens).toBeGreaterThan(1000) + }) + + it('should handle multiple images in single message', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: [ + { + type: 'image' as const, + source: { type: 'url' as const, url: 'https://example.com/1.png' } + }, + { + type: 'image' as const, + source: { type: 'url' as const, url: 'https://example.com/2.png' } + }, + { + type: 'image' as const, + source: { type: 'url' as const, url: 'https://example.com/3.png' } + } + ] + } + ] + } + const tokens = estimateTokenCount(input) + // Should estimate: 3 * 1000 (images) + 3 (role) + expect(tokens).toBe(3003) + }) + }) + + describe('Edge Cases', () => { + it('should handle undefined system message', () => { + const input = { + messages: [{ role: 'user' as const, content: 'Hello' }], + system: undefined + } + const tokens = estimateTokenCount(input) + expect(tokens).toBeGreaterThan(0) + }) + + it('should handle empty system message', () => { + const input = { + messages: [{ role: 'user' as const, content: 'Hello' }], + system: '' + } + const tokens = estimateTokenCount(input) + expect(tokens).toBeGreaterThan(0) + }) + + it('should handle content blocks with missing text', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: [{ type: 'text' as const, text: undefined as any }] + } + ] + } + const tokens = estimateTokenCount(input) + // Should only have role overhead + expect(tokens).toBe(3) + }) + + it('should handle empty content array', () => { + const input = { + messages: [ + { + role: 'user' as const, + content: [] + } + ] + } + const tokens = estimateTokenCount(input) + // Should only have role overhead + expect(tokens).toBe(3) + }) + }) +}) diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts index 02ce0544e..dbd6b676c 100644 --- a/src/main/apiServer/routes/messages.ts +++ b/src/main/apiServer/routes/messages.ts @@ -1,17 +1,129 @@ import type { MessageCreateParams } from '@anthropic-ai/sdk/resources' import { loggerService } from '@logger' +import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/ai-sdk-middlewares' +import { getAiSdkProviderId } from '@shared/provider' import type { Provider } from '@types' import type { Request, Response } from 'express' import express from 'express' +import { approximateTokenSize } from 'tokenx' import { messagesService } from '../services/messages' -import { getProviderById, validateModelId } from '../utils' +import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages' +import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils' + +/** + * Check if a specific model on a provider should use direct Anthropic SDK + * + * A provider+model combination is considered "Anthropic-compatible" if: + * 1. It's a native Anthropic provider (type === 'anthropic'), OR + * 2. It has anthropicApiHost configured AND the specific model supports Anthropic API + * (for aggregated providers like Silicon, only certain models support Anthropic endpoint) + * + * @param provider - The provider to check + * @param modelId - The model ID to check (without provider prefix) + * @returns true if should use direct Anthropic SDK, false for unified SDK + */ +function shouldUseDirectAnthropic(provider: Provider, modelId: string): boolean { + // Native Anthropic provider - always use direct SDK + if (provider.type === 'anthropic') { + return true + } + + // No anthropicApiHost configured - use unified SDK + if (!provider.anthropicApiHost?.trim()) { + return false + } + + // Has anthropicApiHost - check model-level compatibility + // For aggregated providers, only specific models support Anthropic API + return isModelAnthropicCompatible(provider, modelId) +} const logger = loggerService.withContext('ApiServerMessagesRoutes') const router = express.Router() const providerRouter = express.Router({ mergeParams: true }) +/** + * Estimate token count from messages + * Uses tokenx library for accurate token estimation and supports images, tools + */ +export interface CountTokensInput { + messages: MessageCreateParams['messages'] + system?: MessageCreateParams['system'] +} + +export function estimateTokenCount(input: CountTokensInput): number { + const { messages, system } = input + let totalTokens = 0 + + // Count system message tokens using tokenx + if (system) { + if (typeof system === 'string') { + totalTokens += approximateTokenSize(system) + } else if (Array.isArray(system)) { + for (const block of system) { + if (block.type === 'text' && block.text) { + totalTokens += approximateTokenSize(block.text) + } + } + } + } + + // Count message tokens + for (const msg of messages) { + if (typeof msg.content === 'string') { + totalTokens += approximateTokenSize(msg.content) + } else if (Array.isArray(msg.content)) { + for (const block of msg.content) { + if (block.type === 'text' && block.text) { + totalTokens += approximateTokenSize(block.text) + } else if (block.type === 'image') { + // Image token estimation (consistent with TokenService) + if (block.source.type === 'base64') { + // Base64 images: estimate from data length + const dataSize = block.source.data.length * 0.75 // base64 to bytes + totalTokens += Math.floor(dataSize / 100) + } else { + // URL images: use default estimate + totalTokens += 1000 + } + } else if (block.type === 'tool_use') { + // Tool use token estimation: name + input JSON + if (block.name) { + totalTokens += approximateTokenSize(block.name) + } + if (block.input) { + const inputJson = JSON.stringify(block.input) + totalTokens += approximateTokenSize(inputJson) + } + // Add overhead for tool use structure + totalTokens += 10 + } else if (block.type === 'tool_result') { + // Tool result token estimation + if (typeof block.content === 'string') { + totalTokens += approximateTokenSize(block.content) + } else if (Array.isArray(block.content)) { + for (const item of block.content) { + if (typeof item === 'string') { + totalTokens += approximateTokenSize(item) + } else if (item.type === 'text' && item.text) { + totalTokens += approximateTokenSize(item.text) + } + } + } + // Add overhead for tool result structure + totalTokens += 10 + } + } + } + // Add role overhead + totalTokens += 3 + } + + return totalTokens +} + // Helper function for basic request validation async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> { const request: MessageCreateParams = req.body @@ -32,22 +144,101 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro return { valid: true } } +/** + * Shared handler for count_tokens endpoint + * Validates request and returns token count estimation + */ +async function handleCountTokens( + req: Request, + res: Response, + options: { + requireModel?: boolean + logContext?: Record + } = {} +): Promise { + try { + const { model, messages, system } = req.body + const { requireModel = false, logContext = {} } = options + + // Validate model parameter if required + if (requireModel && !model) { + return res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: 'model parameter is required' + } + }) + } + + // Validate messages parameter + if (!messages || !Array.isArray(messages)) { + return res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: 'messages parameter is required' + } + }) + } + + // Estimate token count + const estimatedTokens = estimateTokenCount({ messages, system }) + + // Log with context + logger.debug('Token count estimated', { + model, + messageCount: messages.length, + estimatedTokens, + ...logContext + }) + + return res.json({ + input_tokens: estimatedTokens + }) + } catch (error: any) { + logger.error('Token counting error', { error }) + return res.status(500).json({ + type: 'error', + error: { + type: 'api_error', + message: error.message || 'Internal server error' + } + }) + } +} + interface HandleMessageProcessingOptions { - req: Request res: Response provider: Provider request: MessageCreateParams modelId?: string } -async function handleMessageProcessing({ - req, +/** + * Handle message processing using direct Anthropic SDK + * Used for providers with anthropicApiHost or native Anthropic providers + * This bypasses AI SDK conversion and uses native Anthropic protocol + */ +async function handleDirectAnthropicProcessing({ res, provider, request, - modelId -}: HandleMessageProcessingOptions): Promise { + modelId, + extraHeaders +}: HandleMessageProcessingOptions & { extraHeaders?: Record }): Promise { + const actualModelId = modelId || request.model + + logger.info('Processing message via direct Anthropic SDK', { + providerId: provider.id, + providerType: provider.type, + modelId: actualModelId, + stream: !!request.stream, + anthropicApiHost: provider.anthropicApiHost + }) + try { + // Validate request const validation = messagesService.validateRequest(request) if (!validation.isValid) { res.status(400).json({ @@ -60,28 +251,126 @@ async function handleMessageProcessing({ return } - const extraHeaders = messagesService.prepareHeaders(req.headers) + // Process message using messagesService (native Anthropic SDK) const { client, anthropicRequest } = await messagesService.processMessage({ provider, request, extraHeaders, - modelId + modelId: actualModelId }) if (request.stream) { + // Use native Anthropic streaming await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider) - return + } else { + // Use native Anthropic non-streaming + const response = await client.messages.create(anthropicRequest) + res.json(response) } - - const response = await client.messages.create(anthropicRequest) - res.json(response) } catch (error: any) { - logger.error('Message processing error', { error }) + logger.error('Direct Anthropic processing error', { error }) const { statusCode, errorResponse } = messagesService.transformError(error) res.status(statusCode).json(errorResponse) } } +/** + * Handle message processing using unified AI SDK + * Used for non-Anthropic providers that need format conversion + * - Uses AI SDK adapters with output converted to Anthropic SSE format + */ +async function handleUnifiedProcessing({ + res, + provider, + request, + modelId +}: HandleMessageProcessingOptions): Promise { + const actualModelId = modelId || request.model + + logger.info('Processing message via unified AI SDK', { + providerId: provider.id, + providerType: provider.type, + modelId: actualModelId, + stream: !!request.stream + }) + + try { + // Validate request + const validation = messagesService.validateRequest(request) + if (!validation.isValid) { + res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: validation.errors.join('; ') + } + }) + return + } + + const middlewareConfig: SharedMiddlewareConfig = { + modelId: actualModelId, + providerId: provider.id, + aiSdkProviderId: getAiSdkProviderId(provider) + } + const middlewares = buildSharedMiddlewares(middlewareConfig) + + logger.debug('Built middlewares for unified processing', { + middlewareCount: middlewares.length, + modelId: actualModelId, + providerId: provider.id + }) + + if (request.stream) { + await streamUnifiedMessages({ + response: res, + provider, + modelId: actualModelId, + params: request, + middlewares, + onError: (error) => { + logger.error('Stream error', error as Error) + }, + onComplete: () => { + logger.debug('Stream completed') + } + }) + } else { + const response = await generateUnifiedMessage({ + provider, + modelId: actualModelId, + params: request, + middlewares + }) + res.json(response) + } + } catch (error: any) { + const { statusCode, errorResponse } = messagesService.transformError(error) + res.status(statusCode).json(errorResponse) + } +} + +/** + * Handle message processing - routes to appropriate handler based on provider and model + * + * Routing logic: + * - Native Anthropic providers (type === 'anthropic'): Direct Anthropic SDK + * - Providers with anthropicApiHost AND model supports Anthropic API: Direct Anthropic SDK + * - Other providers/models: Unified AI SDK with Anthropic SSE conversion + */ +async function handleMessageProcessing({ + res, + provider, + request, + modelId +}: HandleMessageProcessingOptions): Promise { + const actualModelId = modelId || request.model + if (shouldUseDirectAnthropic(provider, actualModelId)) { + return handleDirectAnthropicProcessing({ res, provider, request, modelId }) + } + return handleUnifiedProcessing({ res, provider, request, modelId }) +} + /** * @swagger * /v1/messages: @@ -235,7 +524,7 @@ router.post('/', async (req: Request, res: Response) => { const provider = modelValidation.provider! const modelId = modelValidation.modelId! - return handleMessageProcessing({ req, res, provider, request, modelId }) + return handleMessageProcessing({ res, provider, request, modelId }) } catch (error: any) { logger.error('Message processing error', { error }) const { statusCode, errorResponse } = messagesService.transformError(error) @@ -393,7 +682,7 @@ providerRouter.post('/', async (req: Request, res: Response) => { const request: MessageCreateParams = req.body - return handleMessageProcessing({ req, res, provider, request }) + return handleMessageProcessing({ res, provider, request }) } catch (error: any) { logger.error('Message processing error', { error }) const { statusCode, errorResponse } = messagesService.transformError(error) @@ -401,4 +690,58 @@ providerRouter.post('/', async (req: Request, res: Response) => { } }) +/** + * @swagger + * /v1/messages/count_tokens: + * post: + * summary: Count tokens for messages + * description: Count tokens for Anthropic Messages API format (required by Claude Code SDK) + * tags: [Messages] + * requestBody: + * required: true + * content: + * application/json: + * schema: + * type: object + * required: + * - model + * - messages + * properties: + * model: + * type: string + * description: Model ID + * messages: + * type: array + * items: + * type: object + * system: + * type: string + * description: System message + * responses: + * 200: + * description: Token count response + * content: + * application/json: + * schema: + * type: object + * properties: + * input_tokens: + * type: integer + * 400: + * description: Bad request + */ +router.post('/count_tokens', async (req: Request, res: Response) => { + return handleCountTokens(req, res, { requireModel: true }) +}) + +/** + * Provider-specific count_tokens endpoint + */ +providerRouter.post('/count_tokens', async (req: Request, res: Response) => { + return handleCountTokens(req, res, { + requireModel: false, + logContext: { providerId: req.params.provider } + }) +}) + export { providerRouter as messagesProviderRoutes, router as messagesRoutes } diff --git a/src/main/apiServer/services/__tests__/jsonSchemaToZod.test.ts b/src/main/apiServer/services/__tests__/jsonSchemaToZod.test.ts new file mode 100644 index 000000000..804db0d35 --- /dev/null +++ b/src/main/apiServer/services/__tests__/jsonSchemaToZod.test.ts @@ -0,0 +1,340 @@ +import { describe, expect, it } from 'vitest' +import * as z from 'zod' + +import { type JsonSchemaLike, jsonSchemaToZod } from '../unified-messages' + +describe('jsonSchemaToZod', () => { + describe('Basic Types', () => { + it('should convert string type', () => { + const schema: JsonSchemaLike = { type: 'string' } + const result = jsonSchemaToZod(schema) + expect(result).toBeInstanceOf(z.ZodString) + expect(result.safeParse('hello').success).toBe(true) + expect(result.safeParse(123).success).toBe(false) + }) + + it('should convert string with minLength', () => { + const schema: JsonSchemaLike = { type: 'string', minLength: 3 } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('ab').success).toBe(false) + expect(result.safeParse('abc').success).toBe(true) + }) + + it('should convert string with maxLength', () => { + const schema: JsonSchemaLike = { type: 'string', maxLength: 5 } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('hello').success).toBe(true) + expect(result.safeParse('hello world').success).toBe(false) + }) + + it('should convert string with pattern', () => { + const schema: JsonSchemaLike = { type: 'string', pattern: '^[0-9]+$' } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('123').success).toBe(true) + expect(result.safeParse('abc').success).toBe(false) + }) + + it('should convert number type', () => { + const schema: JsonSchemaLike = { type: 'number' } + const result = jsonSchemaToZod(schema) + expect(result).toBeInstanceOf(z.ZodNumber) + expect(result.safeParse(42).success).toBe(true) + expect(result.safeParse(3.14).success).toBe(true) + expect(result.safeParse('42').success).toBe(false) + }) + + it('should convert integer type', () => { + const schema: JsonSchemaLike = { type: 'integer' } + const result = jsonSchemaToZod(schema) + expect(result.safeParse(42).success).toBe(true) + expect(result.safeParse(3.14).success).toBe(false) + }) + + it('should convert number with minimum', () => { + const schema: JsonSchemaLike = { type: 'number', minimum: 10 } + const result = jsonSchemaToZod(schema) + expect(result.safeParse(5).success).toBe(false) + expect(result.safeParse(10).success).toBe(true) + expect(result.safeParse(15).success).toBe(true) + }) + + it('should convert number with maximum', () => { + const schema: JsonSchemaLike = { type: 'number', maximum: 100 } + const result = jsonSchemaToZod(schema) + expect(result.safeParse(50).success).toBe(true) + expect(result.safeParse(100).success).toBe(true) + expect(result.safeParse(150).success).toBe(false) + }) + + it('should convert boolean type', () => { + const schema: JsonSchemaLike = { type: 'boolean' } + const result = jsonSchemaToZod(schema) + expect(result).toBeInstanceOf(z.ZodBoolean) + expect(result.safeParse(true).success).toBe(true) + expect(result.safeParse(false).success).toBe(true) + expect(result.safeParse('true').success).toBe(false) + }) + + it('should convert null type', () => { + const schema: JsonSchemaLike = { type: 'null' } + const result = jsonSchemaToZod(schema) + expect(result).toBeInstanceOf(z.ZodNull) + expect(result.safeParse(null).success).toBe(true) + expect(result.safeParse(undefined).success).toBe(false) + }) + }) + + describe('Enum Types', () => { + it('should convert string enum', () => { + const schema: JsonSchemaLike = { enum: ['red', 'green', 'blue'] } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('red').success).toBe(true) + expect(result.safeParse('green').success).toBe(true) + expect(result.safeParse('yellow').success).toBe(false) + }) + + it('should convert non-string enum with literals', () => { + const schema: JsonSchemaLike = { enum: [1, 2, 3] } + const result = jsonSchemaToZod(schema) + expect(result.safeParse(1).success).toBe(true) + expect(result.safeParse(2).success).toBe(true) + expect(result.safeParse(4).success).toBe(false) + }) + + it('should convert single value enum', () => { + const schema: JsonSchemaLike = { enum: ['only'] } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('only').success).toBe(true) + expect(result.safeParse('other').success).toBe(false) + }) + + it('should convert mixed enum', () => { + const schema: JsonSchemaLike = { enum: ['text', 1, true] } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('text').success).toBe(true) + expect(result.safeParse(1).success).toBe(true) + expect(result.safeParse(true).success).toBe(true) + expect(result.safeParse(false).success).toBe(false) + }) + }) + + describe('Array Types', () => { + it('should convert array of strings', () => { + const schema: JsonSchemaLike = { + type: 'array', + items: { type: 'string' } + } + const result = jsonSchemaToZod(schema) + expect(result.safeParse(['a', 'b']).success).toBe(true) + expect(result.safeParse([1, 2]).success).toBe(false) + }) + + it('should convert array without items (unknown)', () => { + const schema: JsonSchemaLike = { type: 'array' } + const result = jsonSchemaToZod(schema) + expect(result.safeParse([]).success).toBe(true) + expect(result.safeParse(['a', 1, true]).success).toBe(true) + }) + + it('should convert array with minItems', () => { + const schema: JsonSchemaLike = { + type: 'array', + items: { type: 'number' }, + minItems: 2 + } + const result = jsonSchemaToZod(schema) + expect(result.safeParse([1]).success).toBe(false) + expect(result.safeParse([1, 2]).success).toBe(true) + }) + + it('should convert array with maxItems', () => { + const schema: JsonSchemaLike = { + type: 'array', + items: { type: 'number' }, + maxItems: 3 + } + const result = jsonSchemaToZod(schema) + expect(result.safeParse([1, 2, 3]).success).toBe(true) + expect(result.safeParse([1, 2, 3, 4]).success).toBe(false) + }) + }) + + describe('Object Types', () => { + it('should convert simple object', () => { + const schema: JsonSchemaLike = { + type: 'object', + properties: { + name: { type: 'string' }, + age: { type: 'number' } + } + } + const result = jsonSchemaToZod(schema) + expect(result.safeParse({ name: 'John', age: 30 }).success).toBe(true) + expect(result.safeParse({ name: 'John', age: '30' }).success).toBe(false) + }) + + it('should handle required fields', () => { + const schema: JsonSchemaLike = { + type: 'object', + properties: { + name: { type: 'string' }, + age: { type: 'number' } + }, + required: ['name'] + } + const result = jsonSchemaToZod(schema) + expect(result.safeParse({ name: 'John', age: 30 }).success).toBe(true) + expect(result.safeParse({ age: 30 }).success).toBe(false) + expect(result.safeParse({ name: 'John' }).success).toBe(true) + }) + + it('should convert empty object', () => { + const schema: JsonSchemaLike = { type: 'object' } + const result = jsonSchemaToZod(schema) + expect(result.safeParse({}).success).toBe(true) + }) + + it('should convert nested objects', () => { + const schema: JsonSchemaLike = { + type: 'object', + properties: { + user: { + type: 'object', + properties: { + name: { type: 'string' }, + email: { type: 'string' } + } + } + } + } + const result = jsonSchemaToZod(schema) + expect(result.safeParse({ user: { name: 'John', email: 'john@example.com' } }).success).toBe(true) + expect(result.safeParse({ user: { name: 'John' } }).success).toBe(true) + }) + }) + + describe('Union Types', () => { + it('should convert union type (type array)', () => { + const schema: JsonSchemaLike = { type: ['string', 'null'] } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('hello').success).toBe(true) + expect(result.safeParse(null).success).toBe(true) + expect(result.safeParse(123).success).toBe(false) + }) + + it('should convert single type array', () => { + const schema: JsonSchemaLike = { type: ['string'] } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('hello').success).toBe(true) + expect(result.safeParse(123).success).toBe(false) + }) + + it('should convert multiple union types', () => { + const schema: JsonSchemaLike = { type: ['string', 'number', 'boolean'] } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('text').success).toBe(true) + expect(result.safeParse(42).success).toBe(true) + expect(result.safeParse(true).success).toBe(true) + expect(result.safeParse(null).success).toBe(false) + }) + }) + + describe('Description Handling', () => { + it('should preserve description for string', () => { + const schema: JsonSchemaLike = { + type: 'string', + description: 'A user name' + } + const result = jsonSchemaToZod(schema) + expect(result.description).toBe('A user name') + }) + + it('should preserve description for enum', () => { + const schema: JsonSchemaLike = { + enum: ['red', 'green', 'blue'], + description: 'Available colors' + } + const result = jsonSchemaToZod(schema) + expect(result.description).toBe('Available colors') + }) + + it('should preserve description for object', () => { + const schema: JsonSchemaLike = { + type: 'object', + description: 'User object', + properties: { + name: { type: 'string' } + } + } + const result = jsonSchemaToZod(schema) + expect(result.description).toBe('User object') + }) + }) + + describe('Edge Cases', () => { + it('should handle unknown type', () => { + const schema: JsonSchemaLike = { type: 'unknown-type' as any } + const result = jsonSchemaToZod(schema) + expect(result).toBeInstanceOf(z.ZodType) + expect(result.safeParse(anything).success).toBe(true) + }) + + it('should handle schema without type', () => { + const schema: JsonSchemaLike = {} + const result = jsonSchemaToZod(schema) + expect(result).toBeInstanceOf(z.ZodType) + expect(result.safeParse(anything).success).toBe(true) + }) + + it('should handle complex nested schema', () => { + const schema: JsonSchemaLike = { + type: 'object', + properties: { + items: { + type: 'array', + items: { + type: 'object', + properties: { + id: { type: 'integer' }, + name: { type: 'string' }, + tags: { + type: 'array', + items: { type: 'string' } + } + }, + required: ['id'] + } + } + } + } + const result = jsonSchemaToZod(schema) + const validData = { + items: [ + { id: 1, name: 'Item 1', tags: ['tag1', 'tag2'] }, + { id: 2, tags: [] } + ] + } + expect(result.safeParse(validData).success).toBe(true) + + const invalidData = { + items: [{ name: 'No ID' }] + } + expect(result.safeParse(invalidData).success).toBe(false) + }) + }) + + describe('OpenRouter Model IDs', () => { + it('should handle model identifier format with colons', () => { + const schema: JsonSchemaLike = { + type: 'string', + enum: ['openrouter:anthropic/claude-3.5-sonnet:free', 'openrouter:gpt-4:paid'] + } + const result = jsonSchemaToZod(schema) + expect(result.safeParse('openrouter:anthropic/claude-3.5-sonnet:free').success).toBe(true) + expect(result.safeParse('openrouter:gpt-4:paid').success).toBe(true) + expect(result.safeParse('other').success).toBe(false) + }) + }) +}) + +const anything = Math.random() > 0.5 ? 'string' : Math.random() > 0.5 ? 123 : { a: true } diff --git a/src/main/apiServer/services/__tests__/unified-messages.test.ts b/src/main/apiServer/services/__tests__/unified-messages.test.ts new file mode 100644 index 000000000..f8ee1a495 --- /dev/null +++ b/src/main/apiServer/services/__tests__/unified-messages.test.ts @@ -0,0 +1,795 @@ +import type { MessageCreateParams } from '@anthropic-ai/sdk/resources/messages' +import { describe, expect, it } from 'vitest' + +import { convertAnthropicToAiMessages, convertAnthropicToolsToAiSdk } from '../unified-messages' + +describe('unified-messages', () => { + describe('convertAnthropicToolsToAiSdk', () => { + it('should return undefined for empty tools array', () => { + const result = convertAnthropicToolsToAiSdk([]) + expect(result).toBeUndefined() + }) + + it('should return undefined for undefined tools', () => { + const result = convertAnthropicToolsToAiSdk(undefined) + expect(result).toBeUndefined() + }) + + it('should convert simple tool with string schema', () => { + const anthropicTools: MessageCreateParams['tools'] = [ + { + type: 'custom', + name: 'get_weather', + description: 'Get current weather', + input_schema: { + type: 'object', + properties: { + location: { type: 'string' } + }, + required: ['location'] + } + } + ] + + const result = convertAnthropicToolsToAiSdk(anthropicTools) + expect(result).toBeDefined() + expect(result).toHaveProperty('get_weather') + expect(result!.get_weather).toHaveProperty('description', 'Get current weather') + }) + + it('should convert multiple tools', () => { + const anthropicTools: MessageCreateParams['tools'] = [ + { + type: 'custom', + name: 'tool1', + description: 'First tool', + input_schema: { + type: 'object', + properties: {} + } + }, + { + type: 'custom', + name: 'tool2', + description: 'Second tool', + input_schema: { + type: 'object', + properties: {} + } + } + ] + + const result = convertAnthropicToolsToAiSdk(anthropicTools) + expect(result).toBeDefined() + expect(Object.keys(result!)).toHaveLength(2) + expect(result).toHaveProperty('tool1') + expect(result).toHaveProperty('tool2') + }) + + it('should convert tool with complex schema', () => { + const anthropicTools: MessageCreateParams['tools'] = [ + { + type: 'custom', + name: 'search', + description: 'Search for information', + input_schema: { + type: 'object', + properties: { + query: { type: 'string', minLength: 1 }, + limit: { type: 'integer', minimum: 1, maximum: 100 }, + filters: { + type: 'array', + items: { type: 'string' } + } + }, + required: ['query'] + } + } + ] + + const result = convertAnthropicToolsToAiSdk(anthropicTools) + expect(result).toBeDefined() + expect(result).toHaveProperty('search') + }) + + it('should skip bash_20250124 tool type', () => { + const anthropicTools: MessageCreateParams['tools'] = [ + { + type: 'bash_20250124', + name: 'bash' + }, + { + type: 'custom', + name: 'regular_tool', + description: 'A regular tool', + input_schema: { + type: 'object', + properties: {} + } + } + ] + + const result = convertAnthropicToolsToAiSdk(anthropicTools) + expect(result).toBeDefined() + expect(Object.keys(result!)).toHaveLength(1) + expect(result).toHaveProperty('regular_tool') + expect(result).not.toHaveProperty('bash') + }) + + it('should handle tool with no description', () => { + const anthropicTools: MessageCreateParams['tools'] = [ + { + type: 'custom', + name: 'no_desc_tool', + input_schema: { + type: 'object', + properties: {} + } + } + ] + + const result = convertAnthropicToolsToAiSdk(anthropicTools) + expect(result).toBeDefined() + expect(result).toHaveProperty('no_desc_tool') + expect(result!.no_desc_tool).toHaveProperty('description', '') + }) + }) + + describe('convertAnthropicToAiMessages', () => { + describe('System Messages', () => { + it('should convert string system message', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + system: 'You are a helpful assistant.', + messages: [ + { + role: 'user', + content: 'Hello' + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ + role: 'system', + content: 'You are a helpful assistant.' + }) + }) + + it('should convert array system message', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + system: [ + { type: 'text', text: 'Instruction 1' }, + { type: 'text', text: 'Instruction 2' } + ], + messages: [ + { + role: 'user', + content: 'Hello' + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result[0]).toEqual({ + role: 'system', + content: 'Instruction 1\nInstruction 2' + }) + }) + + it('should handle no system message', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: 'Hello' + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result[0].role).toBe('user') + }) + }) + + describe('Text Messages', () => { + it('should convert simple string message', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: 'Hello, world!' + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + role: 'user', + content: 'Hello, world!' + }) + }) + + it('should convert text block array', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'First part' }, + { type: 'text', text: 'Second part' } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(1) + expect(result[0].role).toBe('user') + expect(Array.isArray(result[0].content)).toBe(true) + if (Array.isArray(result[0].content)) { + expect(result[0].content).toHaveLength(2) + expect(result[0].content[0]).toEqual({ type: 'text', text: 'First part' }) + expect(result[0].content[1]).toEqual({ type: 'text', text: 'Second part' }) + } + }) + + it('should convert assistant message', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: 'Hello' + }, + { + role: 'assistant', + content: 'Hi there!' + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(2) + expect(result[1]).toEqual({ + role: 'assistant', + content: 'Hi there!' + }) + }) + }) + + describe('Image Messages', () => { + it('should convert base64 image', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: [ + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/png', + data: 'iVBORw0KGgo=' + } + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(1) + expect(Array.isArray(result[0].content)).toBe(true) + if (Array.isArray(result[0].content)) { + expect(result[0].content).toHaveLength(1) + const imagePart = result[0].content[0] + if (imagePart.type === 'image') { + expect(imagePart.image).toBe('data:image/png;base64,iVBORw0KGgo=') + } + } + }) + + it('should convert URL image', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: [ + { + type: 'image', + source: { + type: 'url', + url: 'https://example.com/image.png' + } + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + if (Array.isArray(result[0].content)) { + const imagePart = result[0].content[0] + if (imagePart.type === 'image') { + expect(imagePart.image).toBe('https://example.com/image.png') + } + } + }) + + it('should convert mixed text and image content', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'Look at this:' }, + { + type: 'image', + source: { + type: 'url', + url: 'https://example.com/pic.jpg' + } + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + if (Array.isArray(result[0].content)) { + expect(result[0].content).toHaveLength(2) + expect(result[0].content[0].type).toBe('text') + expect(result[0].content[1].type).toBe('image') + } + }) + }) + + describe('Tool Messages', () => { + it('should convert tool_use block', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: 'What is the weather?' + }, + { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'call_123', + name: 'get_weather', + input: { location: 'San Francisco' } + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(2) + const assistantMsg = result[1] + expect(assistantMsg.role).toBe('assistant') + if (Array.isArray(assistantMsg.content)) { + expect(assistantMsg.content).toHaveLength(1) + const toolCall = assistantMsg.content[0] + if (toolCall.type === 'tool-call') { + expect(toolCall.toolName).toBe('get_weather') + expect(toolCall.toolCallId).toBe('call_123') + expect(toolCall.input).toEqual({ location: 'San Francisco' }) + } + } + }) + + it('should convert tool_result with string content', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'call_123', + name: 'get_weather', + input: {} + } + ] + }, + { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'call_123', + content: 'Temperature is 72°F' + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + const toolMsg = result[1] + expect(toolMsg.role).toBe('tool') + if (Array.isArray(toolMsg.content)) { + expect(toolMsg.content).toHaveLength(1) + const toolResult = toolMsg.content[0] + if (toolResult.type === 'tool-result') { + expect(toolResult.toolCallId).toBe('call_123') + expect(toolResult.toolName).toBe('get_weather') + if (toolResult.output.type === 'text') { + expect(toolResult.output.value).toBe('Temperature is 72°F') + } + } + } + }) + + it('should convert tool_result with array content', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'call_456', + name: 'analyze', + input: {} + } + ] + }, + { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'call_456', + content: [ + { type: 'text', text: 'Result part 1' }, + { type: 'text', text: 'Result part 2' } + ] + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + const toolMsg = result[1] + if (Array.isArray(toolMsg.content)) { + const toolResult = toolMsg.content[0] + if (toolResult.type === 'tool-result' && toolResult.output.type === 'content') { + expect(toolResult.output.value).toHaveLength(2) + expect(toolResult.output.value[0]).toEqual({ type: 'text', text: 'Result part 1' }) + expect(toolResult.output.value[1]).toEqual({ type: 'text', text: 'Result part 2' }) + } + } + }) + + it('should convert tool_result with image content', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'call_789', + name: 'screenshot', + input: {} + } + ] + }, + { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'call_789', + content: [ + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/png', + data: 'abc123' + } + } + ] + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + const toolMsg = result[1] + if (Array.isArray(toolMsg.content)) { + const toolResult = toolMsg.content[0] + if (toolResult.type === 'tool-result' && toolResult.output.type === 'content') { + expect(toolResult.output.value).toHaveLength(1) + const media = toolResult.output.value[0] + if (media.type === 'media') { + expect(media.data).toBe('abc123') + expect(media.mediaType).toBe('image/png') + } + } + } + }) + + it('should handle multiple tool calls', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'call_1', + name: 'tool1', + input: {} + }, + { + type: 'tool_use', + id: 'call_2', + name: 'tool2', + input: {} + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + if (Array.isArray(result[0].content)) { + expect(result[0].content).toHaveLength(2) + expect(result[0].content[0].type).toBe('tool-call') + expect(result[0].content[1].type).toBe('tool-call') + } + }) + }) + + describe('Thinking Content', () => { + it('should convert thinking block to reasoning', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'assistant', + content: [ + { + type: 'thinking', + thinking: 'Let me analyze this...', + signature: 'sig123' + }, + { + type: 'text', + text: 'Here is my answer' + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + if (Array.isArray(result[0].content)) { + expect(result[0].content).toHaveLength(2) + const reasoning = result[0].content[0] + if (reasoning.type === 'reasoning') { + expect(reasoning.text).toBe('Let me analyze this...') + } + const text = result[0].content[1] + if (text.type === 'text') { + expect(text.text).toBe('Here is my answer') + } + } + }) + + it('should convert redacted_thinking to reasoning', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'assistant', + content: [ + { + type: 'redacted_thinking', + data: '[Redacted]' + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + if (Array.isArray(result[0].content)) { + expect(result[0].content).toHaveLength(1) + const reasoning = result[0].content[0] + if (reasoning.type === 'reasoning') { + expect(reasoning.text).toBe('[Redacted]') + } + } + }) + }) + + describe('Multi-turn Conversations', () => { + it('should handle complete conversation flow', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + system: 'You are a helpful assistant.', + messages: [ + { + role: 'user', + content: 'What is the weather in SF?' + }, + { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'weather_call', + name: 'get_weather', + input: { location: 'SF' } + } + ] + }, + { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'weather_call', + content: '72°F and sunny' + } + ] + }, + { + role: 'assistant', + content: 'The weather in San Francisco is 72°F and sunny.' + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(5) + expect(result[0].role).toBe('system') + expect(result[1].role).toBe('user') + expect(result[2].role).toBe('assistant') + expect(result[3].role).toBe('tool') + expect(result[4].role).toBe('assistant') + }) + }) + + describe('Edge Cases', () => { + it('should handle empty content array for user', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: [] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(0) + }) + + it('should handle empty content array for assistant', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'assistant', + content: [] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(0) + }) + + it('should handle tool_result without matching tool_use', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'unknown_call', + content: 'Some result' + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + expect(result).toHaveLength(1) + if (Array.isArray(result[0].content)) { + const toolResult = result[0].content[0] + if (toolResult.type === 'tool-result') { + expect(toolResult.toolName).toBe('unknown') + } + } + }) + + it('should handle tool_result with empty content', () => { + const params: MessageCreateParams = { + model: 'claude-3-5-sonnet-20241022', + max_tokens: 1024, + messages: [ + { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'call_empty', + name: 'empty_tool', + input: {} + } + ] + }, + { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'call_empty' + } + ] + } + ] + } + + const result = convertAnthropicToAiMessages(params) + const toolMsg = result[1] + if (Array.isArray(toolMsg.content)) { + const toolResult = toolMsg.content[0] + if (toolResult.type === 'tool-result' && toolResult.output.type === 'text') { + expect(toolResult.output.value).toBe('') + } + } + }) + }) + }) +}) diff --git a/src/main/apiServer/services/messages.ts b/src/main/apiServer/services/messages.ts index 8b46deaa8..957c06652 100644 --- a/src/main/apiServer/services/messages.ts +++ b/src/main/apiServer/services/messages.ts @@ -2,8 +2,10 @@ import type Anthropic from '@anthropic-ai/sdk' import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources' import { loggerService } from '@logger' import anthropicService from '@main/services/AnthropicService' -import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic' +import { buildClaudeCodeSystemMessage, getSdkClient, sanitizeToolsForAnthropic } from '@shared/anthropic' import type { Provider } from '@types' +import { APICallError, RetryError } from 'ai' +import { net } from 'electron' import type { Response } from 'express' const logger = loggerService.withContext('MessagesService') @@ -98,11 +100,30 @@ export class MessagesService { async getClient(provider: Provider, extraHeaders?: Record): Promise { // Create Anthropic client for the provider + // Wrap net.fetch to handle compatibility issues: + // 1. net.fetch expects string URLs, not Request objects + // 2. net.fetch doesn't support 'agent' option from Node.js http module + const electronFetch: typeof globalThis.fetch = async (input: URL | RequestInfo, init?: RequestInit) => { + const url = typeof input === 'string' ? input : input instanceof URL ? input.toString() : input.url + // Remove unsupported options for Electron's net.fetch + if (init) { + const initWithAgent = init as RequestInit & { agent?: unknown } + delete initWithAgent.agent + const headers = new Headers(initWithAgent.headers) + if (headers.has('content-length')) { + headers.delete('content-length') + } + initWithAgent.headers = headers + return net.fetch(url, initWithAgent) + } + return net.fetch(url) + } + const context = { fetch: electronFetch } if (provider.authType === 'oauth') { const oauthToken = await anthropicService.getValidAccessToken() - return getSdkClient(provider, oauthToken, extraHeaders) + return getSdkClient(provider, oauthToken, extraHeaders, context) } - return getSdkClient(provider, null, extraHeaders) + return getSdkClient(provider, null, extraHeaders, context) } prepareHeaders(headers: Record): Record { @@ -127,7 +148,8 @@ export class MessagesService { createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams { const anthropicRequest: MessageCreateParams = { ...request, - stream: !!request.stream + stream: !!request.stream, + tools: sanitizeToolsForAnthropic(request.tools) } // Override model if provided @@ -233,9 +255,71 @@ export class MessagesService { } transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } { - let statusCode = 500 - let errorType = 'api_error' - let errorMessage = 'Internal server error' + let statusCode: number | undefined = undefined + let errorType: string | undefined = undefined + let errorMessage: string | undefined = undefined + + const errorMap: Record = { + 400: 'invalid_request_error', + 401: 'authentication_error', + 403: 'forbidden_error', + 404: 'not_found_error', + 429: 'rate_limit_error', + 500: 'internal_server_error' + } + + // Handle AI SDK RetryError - extract the last error for better error messages + if (RetryError.isInstance(error)) { + const lastError = error.lastError + // If the last error is an APICallError, extract its details + if (APICallError.isInstance(lastError)) { + statusCode = lastError.statusCode || 502 + errorMessage = lastError.message + return { + statusCode, + errorResponse: { + type: 'error', + error: { + type: errorMap[statusCode] || 'api_error', + message: `${error.reason}: ${errorMessage}`, + requestId: lastError.name + } + } + } + } + // Fallback for other retry errors + errorMessage = error.message + statusCode = 502 + return { + statusCode, + errorResponse: { + type: 'error', + error: { + type: 'api_error', + message: errorMessage, + requestId: error.name + } + } + } + } + + if (APICallError.isInstance(error)) { + statusCode = error.statusCode + errorMessage = error.message + if (statusCode) { + return { + statusCode, + errorResponse: { + type: 'error', + error: { + type: errorMap[statusCode] || 'api_error', + message: errorMessage, + requestId: error.name + } + } + } + } + } const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined const anthropicError = error?.error @@ -277,11 +361,11 @@ export class MessagesService { typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error' return { - statusCode, + statusCode: statusCode ?? 500, errorResponse: { type: 'error', error: { - type: errorType, + type: errorType || 'api_error', message: safeErrorMessage, requestId: error?.request_id } diff --git a/src/main/apiServer/services/models.ts b/src/main/apiServer/services/models.ts index 52f0db857..b72c21b1e 100644 --- a/src/main/apiServer/services/models.ts +++ b/src/main/apiServer/services/models.ts @@ -1,13 +1,6 @@ -import { isEmpty } from 'lodash' - import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels' import { loggerService } from '../../services/LoggerService' -import { - getAvailableProviders, - getProviderAnthropicModelChecker, - listAllAvailableModels, - transformModelToOpenAI -} from '../utils' +import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils' const logger = loggerService.withContext('ModelsService') @@ -20,11 +13,12 @@ export class ModelsService { try { logger.debug('Getting available models from providers', { filter }) - let providers = await getAvailableProviders() + const providers = await getAvailableProviders() - if (filter.providerType === 'anthropic') { - providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim())) - } + // Note: When providerType === 'anthropic', we now return ALL available models + // because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert + // any provider's response to Anthropic SSE format. This enables Claude Code Agent + // to work with OpenAI, Gemini, and other providers transparently. const models = await listAllAvailableModels(providers) // Use Map to deduplicate models by their full ID (provider:model_id) @@ -32,20 +26,11 @@ export class ModelsService { for (const model of models) { const provider = providers.find((p) => p.id === model.provider) - // logger.debug(`Processing model ${model.id}`) if (!provider) { logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`) continue } - if (filter.providerType === 'anthropic') { - const checker = getProviderAnthropicModelChecker(provider.id) - if (!checker(model)) { - logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`) - continue - } - } - const openAIModel = transformModelToOpenAI(model, provider) const fullModelId = openAIModel.id // This is already in format "provider:model_id" diff --git a/src/main/apiServer/services/reasoning-cache.ts b/src/main/apiServer/services/reasoning-cache.ts new file mode 100644 index 000000000..eb39e691d --- /dev/null +++ b/src/main/apiServer/services/reasoning-cache.ts @@ -0,0 +1,45 @@ +/** + * Reasoning Cache Service + * + * Manages reasoning-related caching for AI providers that support thinking/reasoning modes. + * This includes Google Gemini's thought signatures and OpenRouter's reasoning details. + */ + +import type { ReasoningDetailUnion } from '@main/apiServer/adapters/openrouter' +import { CacheService } from '@main/services/CacheService' + +/** + * Interface for reasoning cache + */ +export interface IReasoningCache { + set(key: string, value: T): void + get(key: string): T | undefined +} + +/** + * Cache duration: 30 minutes + * Reasoning data is typically only needed within a short conversation context + */ +const REASONING_CACHE_DURATION = 30 * 60 * 1000 + +/** + * Google Gemini reasoning cache + * + * Stores thought signatures for Gemini 3 models to handle multi-turn conversations + * where the model needs to maintain thinking context across tool calls. + */ +export const googleReasoningCache: IReasoningCache = { + set: (key, value) => CacheService.set(`google-reasoning:${key}`, value, REASONING_CACHE_DURATION), + get: (key) => CacheService.get(`google-reasoning:${key}`) || undefined +} + +/** + * OpenRouter reasoning cache + * + * Stores reasoning details from OpenRouter responses to preserve thinking tokens + * and reasoning metadata across the conversation flow. + */ +export const openRouterReasoningCache: IReasoningCache = { + set: (key, value) => CacheService.set(`openrouter-reasoning:${key}`, value, REASONING_CACHE_DURATION), + get: (key) => CacheService.get(`openrouter-reasoning:${key}`) || undefined +} diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts new file mode 100644 index 000000000..b9c306b2f --- /dev/null +++ b/src/main/apiServer/services/unified-messages.ts @@ -0,0 +1,764 @@ +import type { AnthropicProviderOptions } from '@ai-sdk/anthropic' +import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google' +import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai' +import type { JSONSchema7, LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider' +import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils' +import type { + ImageBlockParam, + MessageCreateParams, + TextBlockParam, + Tool as AnthropicTool +} from '@anthropic-ai/sdk/resources/messages' +import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core' +import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' +import { loggerService } from '@logger' +import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@main/apiServer/adapters' +import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai' +import anthropicService from '@main/services/AnthropicService' +import copilotService from '@main/services/CopilotService' +import { reduxService } from '@main/services/ReduxService' +import type { OpenRouterProviderOptions } from '@openrouter/ai-sdk-provider' +import { isGemini3ModelId } from '@shared/ai-sdk-middlewares' +import { + type AiSdkConfig, + type AiSdkConfigContext, + formatProviderApiHost, + initializeSharedProviders, + isAnthropicProvider, + isGeminiProvider, + isOpenAIProvider, + type MinimalProvider, + type ProviderFormatContext, + providerToAiSdkConfig as sharedProviderToAiSdkConfig, + resolveActualProvider, + SystemProviderIds +} from '@shared/provider' +import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant' +import { defaultAppHeaders } from '@shared/utils' +import type { Provider } from '@types' +import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai' +import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai' +import { net } from 'electron' +import type { Response } from 'express' +import * as z from 'zod' + +import { googleReasoningCache, openRouterReasoningCache } from './reasoning-cache' + +const logger = loggerService.withContext('UnifiedMessagesService') + +const MAGIC_STRING = 'skip_thought_signature_validator' + +function sanitizeJson(value: unknown): JSONValue { + return JSON.parse(JSON.stringify(value)) +} + +initializeSharedProviders({ + warn: (message) => logger.warn(message), + error: (message, error) => logger.error(message, error) +}) + +/** + * Configuration for unified message streaming + */ +export interface UnifiedStreamConfig { + response: Response + provider: Provider + modelId: string + params: MessageCreateParams + onError?: (error: unknown) => void + onComplete?: () => void + /** + * Optional AI SDK middlewares to apply + */ + middlewares?: LanguageModelV2Middleware[] + /** + * Optional AI Core plugins to use with the executor + */ + plugins?: AiPlugin[] +} + +/** + * Configuration for non-streaming message generation + */ +export interface GenerateUnifiedMessageConfig { + provider: Provider + modelId: string + params: MessageCreateParams + middlewares?: LanguageModelV2Middleware[] + plugins?: AiPlugin[] +} + +function getMainProcessFormatContext(): ProviderFormatContext { + const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai') + return { + vertex: { + project: vertexSettings?.projectId || 'default-project', + location: vertexSettings?.location || 'us-central1' + } + } +} + +function isSupportStreamOptionsProvider(provider: MinimalProvider): boolean { + const NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS = ['mistral'] as const + return !NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS.some((pid) => pid === provider.id) +} + +const mainProcessSdkContext: AiSdkConfigContext = { + isSupportStreamOptionsProvider, + getIncludeUsageSetting: () => + reduxService.selectSync('state.settings.openAI?.streamOptions?.includeUsage'), + fetch: net.fetch as typeof globalThis.fetch +} + +function getActualProvider(provider: Provider, modelId: string): Provider { + const model = provider.models?.find((m) => m.id === modelId) + if (!model) return provider + return resolveActualProvider(provider, model) +} + +function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig { + const actualProvider = getActualProvider(provider, modelId) + const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext()) + return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext) +} + +function convertAnthropicToolResultToAiSdk( + content: string | Array +): LanguageModelV2ToolResultOutput { + if (typeof content === 'string') { + return { type: 'text', value: content } + } + const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = [] + for (const block of content) { + if (block.type === 'text') { + values.push({ type: 'text', text: block.text }) + } else if (block.type === 'image') { + values.push({ + type: 'media', + data: block.source.type === 'base64' ? block.source.data : block.source.url, + mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png' + }) + } + } + return { type: 'content', value: values } +} + +/** + * JSON Schema type for tool input schemas + */ +export type JsonSchemaLike = JSONSchema7 + +/** + * Convert JSON Schema to Zod schema + * This avoids non-standard fields like input_examples that Anthropic doesn't support + * TODO: Anthropic/beta support input_examples + */ +export function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny { + const schemaType = schema.type + const enumValues = schema.enum + const description = schema.description + + // Handle enum first + if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) { + if (enumValues.every((v) => typeof v === 'string')) { + const zodEnum = z.enum(enumValues as [string, ...string[]]) + return description ? zodEnum.describe(description) : zodEnum + } + // For non-string enums, use union of literals + const literals = enumValues.map((v) => z.literal(v as string | number | boolean)) + if (literals.length === 1) { + return description ? literals[0].describe(description) : literals[0] + } + const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]]) + return description ? zodUnion.describe(description) : zodUnion + } + + // Handle union types (type: ["string", "null"]) + if (Array.isArray(schemaType)) { + const schemas = schemaType.map((t) => + jsonSchemaToZod({ + ...schema, + type: t, + enum: undefined + }) + ) + if (schemas.length === 1) { + return schemas[0] + } + return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]]) + } + + // Handle by type + switch (schemaType) { + case 'string': { + let zodString = z.string() + if (typeof schema.minLength === 'number') zodString = zodString.min(schema.minLength) + if (typeof schema.maxLength === 'number') zodString = zodString.max(schema.maxLength) + if (typeof schema.pattern === 'string') zodString = zodString.regex(new RegExp(schema.pattern)) + return description ? zodString.describe(description) : zodString + } + + case 'number': + case 'integer': { + let zodNumber = schemaType === 'integer' ? z.number().int() : z.number() + if (typeof schema.minimum === 'number') zodNumber = zodNumber.min(schema.minimum) + if (typeof schema.maximum === 'number') zodNumber = zodNumber.max(schema.maximum) + return description ? zodNumber.describe(description) : zodNumber + } + + case 'boolean': { + const zodBoolean = z.boolean() + return description ? zodBoolean.describe(description) : zodBoolean + } + + case 'null': + return z.null() + + case 'array': { + const items = schema.items + let zodArray: z.ZodArray + if (items && typeof items === 'object' && !Array.isArray(items)) { + zodArray = z.array(jsonSchemaToZod(items as JsonSchemaLike)) + } else { + zodArray = z.array(z.unknown()) + } + if (typeof schema.minItems === 'number') zodArray = zodArray.min(schema.minItems) + if (typeof schema.maxItems === 'number') zodArray = zodArray.max(schema.maxItems) + return description ? zodArray.describe(description) : zodArray + } + + case 'object': { + const properties = schema.properties + const required = schema.required || [] + + // Always use z.object() to ensure "properties" field is present in output schema + // OpenAI requires explicit properties field even for empty objects + const shape: Record = {} + if (properties && typeof properties === 'object') { + for (const [key, propSchema] of Object.entries(properties)) { + if (typeof propSchema === 'boolean') { + shape[key] = propSchema ? z.unknown() : z.never() + } else { + const zodProp = jsonSchemaToZod(propSchema as JsonSchemaLike) + shape[key] = required.includes(key) ? zodProp : zodProp.optional() + } + } + } + + const zodObject = z.object(shape) + return description ? zodObject.describe(description) : zodObject + } + + default: + // Unknown type, use z.unknown() + return z.unknown() + } +} + +export function convertAnthropicToolsToAiSdk( + tools: MessageCreateParams['tools'] +): Record | undefined { + if (!tools || tools.length === 0) return undefined + + const aiSdkTools: Record = {} + for (const anthropicTool of tools) { + if (anthropicTool.type === 'bash_20250124') continue + const toolDef = anthropicTool as AnthropicTool + const rawSchema = toolDef.input_schema + // Convert Anthropic's InputSchema to JSONSchema7-compatible format + const schema = jsonSchemaToZod(rawSchema as JsonSchemaLike) + + // Use tool() with inputSchema (AI SDK v5 API) + const aiTool = tool({ + description: toolDef.description || '', + inputSchema: zodSchema(schema) + }) + + aiSdkTools[toolDef.name] = aiTool + } + return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined +} + +export function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] { + const messages: ModelMessage[] = [] + + // System message + if (params.system) { + if (typeof params.system === 'string') { + messages.push({ role: 'system', content: params.system }) + } else if (Array.isArray(params.system)) { + const systemText = params.system + .filter((block) => block.type === 'text') + .map((block) => block.text) + .join('\n') + if (systemText) { + messages.push({ role: 'system', content: systemText }) + } + } + } + + const toolCallIdToName = new Map() + for (const msg of params.messages) { + if (Array.isArray(msg.content)) { + for (const block of msg.content) { + if (block.type === 'tool_use') { + toolCallIdToName.set(block.id, block.name) + } + } + } + } + + // User/assistant messages + for (const msg of params.messages) { + if (typeof msg.content === 'string') { + messages.push({ + role: msg.role === 'user' ? 'user' : 'assistant', + content: msg.content + }) + } else if (Array.isArray(msg.content)) { + const textParts: TextPart[] = [] + const imageParts: ImagePart[] = [] + const reasoningParts: ReasoningPart[] = [] + const toolCallParts: ToolCallPart[] = [] + const toolResultParts: ToolResultPart[] = [] + + for (const block of msg.content) { + if (block.type === 'text') { + textParts.push({ type: 'text', text: block.text }) + } else if (block.type === 'thinking') { + reasoningParts.push({ type: 'reasoning', text: block.thinking }) + } else if (block.type === 'redacted_thinking') { + reasoningParts.push({ type: 'reasoning', text: block.data }) + } else if (block.type === 'image') { + const source = block.source + if (source.type === 'base64') { + imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` }) + } else if (source.type === 'url') { + imageParts.push({ type: 'image', image: source.url }) + } + } else if (block.type === 'tool_use') { + const options: ProviderOptions = {} + logger.debug('Processing tool call block', { block, msgRole: msg.role, model: params.model }) + if (isGemini3ModelId(params.model)) { + if (googleReasoningCache.get(`google-${block.name}`)) { + options.google = { + thoughtSignature: MAGIC_STRING + } + } + } + if (openRouterReasoningCache.get(`openrouter-${block.id}`)) { + options.openrouter = { + reasoning_details: + (sanitizeJson(openRouterReasoningCache.get(`openrouter-${block.id}`)) as JSONValue[]) || [] + } + } + toolCallParts.push({ + type: 'tool-call', + toolName: block.name, + toolCallId: block.id, + input: block.input, + providerOptions: options + }) + } else if (block.type === 'tool_result') { + // Look up toolName from the pre-built map (covers cross-message references) + const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown' + toolResultParts.push({ + type: 'tool-result', + toolCallId: block.tool_use_id, + toolName, + output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' } + }) + } + } + + if (toolResultParts.length > 0) { + messages.push({ role: 'tool', content: [...toolResultParts] }) + } + + if (msg.role === 'user') { + const userContent = [...textParts, ...imageParts] + if (userContent.length > 0) { + messages.push({ role: 'user', content: userContent }) + } + } else { + const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts] + if (assistantContent.length > 0) { + let providerOptions: ProviderOptions | undefined = undefined + if (openRouterReasoningCache.get('openrouter')) { + providerOptions = { + openrouter: { + reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || [] + } + } + } else if (isGemini3ModelId(params.model)) { + providerOptions = { + google: { + thoughtSignature: MAGIC_STRING + } + } + } + messages.push({ role: 'assistant', content: assistantContent, providerOptions }) + } + } + } + } + + return messages +} + +interface ExecuteStreamConfig { + provider: Provider + modelId: string + params: MessageCreateParams + middlewares?: LanguageModelV2Middleware[] + plugins?: AiPlugin[] + onEvent?: (event: Parameters[0]) => void +} + +/** + * Create AI SDK provider instance from config + * Similar to renderer's createAiSdkProvider + */ +async function createAiSdkProvider(config: AiSdkConfig): Promise { + let providerId = config.providerId + + // Handle special provider modes (same as renderer) + if (providerId === 'openai' && config.options?.mode === 'chat') { + providerId = 'openai-chat' + } else if (providerId === 'azure' && config.options?.mode === 'responses') { + providerId = 'azure-responses' + } else if (providerId === 'cherryin' && config.options?.mode === 'chat') { + providerId = 'cherryin-chat' + } + + const provider = await createProviderCore(providerId, config.options) + + return provider +} + +/** + * Prepare special provider configuration for providers that need dynamic tokens + * Similar to renderer's prepareSpecialProviderConfig + */ +async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise { + switch (provider.id) { + case 'copilot': { + const storedHeaders = + ((await reduxService.select('state.copilot.defaultHeaders')) as Record | null) ?? {} + const headers: Record = { + ...COPILOT_DEFAULT_HEADERS, + ...storedHeaders + } + + try { + const { token } = await copilotService.getToken(null as any, headers) + config.options.apiKey = token + const existingHeaders = (config.options.headers as Record | undefined) ?? {} + config.options.headers = { + ...headers, + ...existingHeaders + } + } catch (error) { + logger.error('Failed to get Copilot token', error as Error) + throw new Error('Failed to get Copilot token. Please re-authorize Copilot.') + } + break + } + case 'anthropic': { + if (provider.authType === 'oauth') { + try { + const oauthToken = await anthropicService.getValidAccessToken() + if (!oauthToken) { + throw new Error('Anthropic OAuth token not available. Please re-authorize.') + } + config.options = { + ...config.options, + headers: { + ...(config.options.headers ? config.options.headers : {}), + 'Content-Type': 'application/json', + 'anthropic-version': '2023-06-01', + 'anthropic-beta': 'oauth-2025-04-20', + Authorization: `Bearer ${oauthToken}` + }, + baseURL: 'https://api.anthropic.com/v1', + apiKey: '' + } + } catch (error) { + logger.error('Failed to get Anthropic OAuth token', error as Error) + throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.') + } + } + break + } + case 'cherryai': { + // Create a signed fetch wrapper for cherryai + const baseFetch = net.fetch as typeof globalThis.fetch + config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => { + if (!options?.body) { + return baseFetch(url, options) + } + const signature = cherryaiGenerateSignature({ + method: 'POST', + path: '/chat/completions', + query: '', + body: JSON.parse(options.body as string) + }) + return baseFetch(url, { + ...options, + headers: { + ...(options.headers as Record), + ...signature + } + }) + } + break + } + } + return config +} + +function mapAnthropicThinkToAISdkProviderOptions( + provider: Provider, + config: MessageCreateParams['thinking'] +): ProviderOptions | undefined { + if (!config) return undefined + if (isAnthropicProvider(provider)) { + return { + anthropic: { + ...mapToAnthropicProviderOptions(config) + } + } + } + if (isGeminiProvider(provider)) { + return { + google: { + ...mapToGeminiProviderOptions(config) + } + } + } + if (isOpenAIProvider(provider)) { + return { + openai: { + ...mapToOpenAIProviderOptions(config) + } + } + } + if (provider.id === SystemProviderIds.openrouter) { + return { + openrouter: { + ...mapToOpenRouterProviderOptions(config) + } + } + } + return undefined +} + +function mapToAnthropicProviderOptions(config: NonNullable): AnthropicProviderOptions { + return { + thinking: { + type: config.type, + budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined + } + } +} + +function mapToGeminiProviderOptions( + config: NonNullable +): GoogleGenerativeAIProviderOptions { + return { + thinkingConfig: { + thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1, + includeThoughts: config.type === 'enabled' + } + } +} + +function mapToOpenAIProviderOptions( + config: NonNullable +): OpenAIResponsesProviderOptions { + return { + reasoningEffort: config.type === 'enabled' ? 'high' : 'none' + } +} + +function mapToOpenRouterProviderOptions( + config: NonNullable +): OpenRouterProviderOptions { + return { + reasoning: { + enabled: config.type === 'enabled', + effort: 'high' + } + } +} + +/** + * Core stream execution function - single source of truth for AI SDK calls + */ +async function executeStream(config: ExecuteStreamConfig): Promise { + const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config + + // Convert provider config to AI SDK config + let sdkConfig = providerToAiSdkConfig(provider, modelId) + + // Prepare special provider config (Copilot, Anthropic OAuth, etc.) + sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig) + + // Create provider instance and get language model + const aiSdkProvider = await createAiSdkProvider(sdkConfig) + const baseModel = aiSdkProvider.languageModel(modelId) + + // Apply middlewares if present + const model = + middlewares.length > 0 && typeof baseModel === 'object' + ? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel) + : baseModel + + // Create executor with plugins + const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins) + + // Convert messages and tools + const coreMessages = convertAnthropicToAiMessages(params) + const tools = convertAnthropicToolsToAiSdk(params.tools) + + // Create the adapter + const adapter = new AiSdkToAnthropicSSE({ + model: `${provider.id}:${modelId}`, + onEvent: onEvent || (() => {}) + }) + + const result = await executor.streamText({ + model, + messages: coreMessages, + // FIXME: Claude Code传入的maxToken会超出有些模型限制,需做特殊处理,可能在v2好修复一点,现在维护的成本有点高 + // 已知: 豆包 + maxOutputTokens: params.max_tokens, + temperature: params.temperature, + topP: params.top_p, + topK: params.top_k, + stopSequences: params.stop_sequences, + stopWhen: stepCountIs(100), + headers: defaultAppHeaders(), + tools, + providerOptions: mapAnthropicThinkToAISdkProviderOptions(provider, params.thinking) + }) + + // Process the stream through the adapter + await adapter.processStream(result.fullStream) + + return adapter +} + +/** + * Stream a message request using AI SDK executor and convert to Anthropic SSE format + */ +export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise { + const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config + + logger.info('Starting unified message stream', { + providerId: provider.id, + providerType: provider.type, + modelId, + stream: params.stream, + middlewareCount: middlewares.length, + pluginCount: plugins.length + }) + + try { + response.setHeader('Content-Type', 'text/event-stream') + response.setHeader('Cache-Control', 'no-cache') + response.setHeader('Connection', 'keep-alive') + response.setHeader('X-Accel-Buffering', 'no') + + await executeStream({ + provider, + modelId, + params, + middlewares, + plugins, + onEvent: (event) => { + logger.silly('Streaming event', { eventType: event.type }) + const sseData = formatSSEEvent(event) + response.write(sseData) + } + }) + + // Send done marker + response.write(formatSSEDone()) + response.end() + + logger.info('Unified message stream completed', { providerId: provider.id, modelId }) + onComplete?.() + } catch (error) { + logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId }) + onError?.(error) + throw error + } +} + +/** + * Generate a non-streaming message response + * + * Uses simulateStreamingMiddleware to reuse the same streaming logic, + * similar to renderer's ModernAiProvider pattern. + */ +export async function generateUnifiedMessage( + providerOrConfig: Provider | GenerateUnifiedMessageConfig, + modelId?: string, + params?: MessageCreateParams +): Promise> { + // Support both old signature and new config-based signature + let config: GenerateUnifiedMessageConfig + if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) { + config = providerOrConfig + } else { + config = { + provider: providerOrConfig as Provider, + modelId: modelId!, + params: params! + } + } + + const { provider, middlewares = [], plugins = [] } = config + + logger.info('Starting unified message generation', { + providerId: provider.id, + providerType: provider.type, + modelId: config.modelId, + middlewareCount: middlewares.length, + pluginCount: plugins.length + }) + + try { + // Add simulateStreamingMiddleware to reuse streaming logic for non-streaming + const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares] + + const adapter = await executeStream({ + provider, + modelId: config.modelId, + params: config.params, + middlewares: allMiddlewares, + plugins + }) + + const finalResponse = adapter.buildNonStreamingResponse() + + logger.info('Unified message generation completed', { + providerId: provider.id, + modelId: config.modelId + }) + + return finalResponse + } catch (error) { + logger.error('Error in unified message generation', error as Error, { + providerId: provider.id, + modelId: config.modelId + }) + throw error + } +} + +export default { + streamUnifiedMessages, + generateUnifiedMessage +} diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index e25b49e75..17d3f9f08 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -1,7 +1,7 @@ import { CacheService } from '@main/services/CacheService' import { loggerService } from '@main/services/LoggerService' import { reduxService } from '@main/services/ReduxService' -import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers' +import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers' import type { ApiModel, Model, Provider } from '@types' const logger = loggerService.withContext('ApiServerUtils') @@ -28,10 +28,9 @@ export async function getAvailableProviders(): Promise { return [] } - // Support OpenAI and Anthropic type providers for API server - const supportedProviders = providers.filter( - (p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic') - ) + // Support all provider types that AI SDK can handle + // The unified-messages service uses AI SDK which supports many providers + const supportedProviders = providers.filter((p: Provider) => p.enabled) // Cache the filtered results CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL) @@ -160,7 +159,7 @@ export async function validateModelId(model: string): Promise<{ valid: false, error: { type: 'provider_not_found', - message: `Provider '${providerId}' not found, not enabled, or not supported. Only OpenAI providers are currently supported.`, + message: `Provider '${providerId}' not found or not enabled.`, code: 'provider_not_found' } } @@ -262,14 +261,8 @@ export function validateProvider(provider: Provider): boolean { return false } - // Support OpenAI and Anthropic type providers - if (provider.type !== 'openai' && provider.type !== 'anthropic') { - logger.debug('Provider type not supported', { - providerId: provider.id, - providerType: provider.type - }) - return false - } + // AI SDK supports many provider types, no longer need to filter by type + // The unified-messages service handles all supported types return true } catch (error: any) { @@ -290,8 +283,39 @@ export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model return (m: Model) => m.id.includes('claude') case 'silicon': return (m: Model) => isSiliconAnthropicCompatibleModel(m.id) + case 'ppio': + return (m: Model) => isPpioAnthropicCompatibleModel(m.id) default: // allow all models when checker not configured return () => true } } + +/** + * Check if a specific model is compatible with Anthropic API for a given provider. + * + * This is used for fine-grained routing decisions at the model level. + * For aggregated providers (like Silicon), only certain models support the Anthropic API endpoint. + * + * @param provider - The provider to check + * @param modelId - The model ID to check (without provider prefix) + * @returns true if the model supports Anthropic API endpoint + */ +export function isModelAnthropicCompatible(provider: Provider, modelId: string): boolean { + const checker = getProviderAnthropicModelChecker(provider.id) + + const model = provider.models?.find((m) => m.id === modelId) + + if (model) { + return checker(model) + } + + const minimalModel: Model = { + id: modelId, + name: modelId, + provider: provider.id, + group: '' + } + + return checker(minimalModel) +} diff --git a/src/main/services/agents/services/claudecode/claude-stream-state.ts b/src/main/services/agents/services/claudecode/claude-stream-state.ts index 30b5790c8..5266fda99 100644 --- a/src/main/services/agents/services/claudecode/claude-stream-state.ts +++ b/src/main/services/agents/services/claudecode/claude-stream-state.ts @@ -87,6 +87,7 @@ export class ClaudeStreamState { private pendingUsage: PendingUsageState = {} private pendingToolCalls = new Map() private stepActive = false + private _streamFinished = false constructor(options: ClaudeStreamStateOptions) { this.logger = loggerService.withContext('ClaudeStreamState') @@ -289,6 +290,16 @@ export class ClaudeStreamState { getNamespacedToolCallId(rawToolCallId: string): string { return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId) } + + /** Marks the stream as finished (either completed or errored). */ + markFinished(): void { + this._streamFinished = true + } + + /** Returns true if the stream has already emitted a terminal event. */ + isFinished(): boolean { + return this._streamFinished + } } export type { PendingToolCall } diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index 45cecb049..50dd5a6d3 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -87,18 +87,14 @@ class ClaudeCodeService implements AgentServiceInterface { }) return aiStream } - if ( - (modelInfo.provider?.type !== 'anthropic' && - (modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) || - modelInfo.provider.apiKey === '' - ) { - logger.error('Anthropic provider configuration is missing', { - modelInfo - }) - + // Validate provider has required configuration + // Note: We no longer restrict to anthropic type only - the API Server's unified adapter + // handles format conversion for any provider type (OpenAI, Gemini, etc.) + if (!modelInfo.provider?.apiKey) { + logger.error('Provider API key is missing', { modelInfo }) aiStream.emit('data', { type: 'error', - error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`) + error: new Error(`Provider '${modelInfo.provider?.id}' is missing API key configuration.`) }) return aiStream } @@ -112,15 +108,14 @@ class ClaudeCodeService implements AgentServiceInterface { // Auto-discover Git Bash path on Windows (already logs internally) const customGitBashPath = isWin ? autoDiscoverGitBash() : null + // Route through local API Server which handles format conversion via unified adapter + // This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.) + // The API Server converts AI SDK responses to Anthropic SSE format transparently const env = { ...loginShellEnvWithoutProxies, - // TODO: fix the proxy api server - // ANTHROPIC_API_KEY: apiConfig.apiKey, - // ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey, - // ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`, - ANTHROPIC_API_KEY: modelInfo.provider.apiKey, - ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey, - ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost, + ANTHROPIC_API_KEY: apiConfig.apiKey, + ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey, + ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`, ANTHROPIC_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId, @@ -545,6 +540,19 @@ class ClaudeCodeService implements AgentServiceInterface { return } + // Skip emitting error if stream already finished (error was handled via result message) + if (streamState.isFinished()) { + logger.debug('SDK process exited after stream finished, skipping duplicate error event', { + duration, + error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj) + }) + // Still emit complete to signal stream end + stream.emit('data', { + type: 'complete' + }) + return + } + errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj)) const errorMessage = errorChunks.join('\n\n') logger.error('SDK query failed', { diff --git a/src/main/services/agents/services/claudecode/transform.ts b/src/main/services/agents/services/claudecode/transform.ts index 00be683ba..094076e50 100644 --- a/src/main/services/agents/services/claudecode/transform.ts +++ b/src/main/services/agents/services/claudecode/transform.ts @@ -121,7 +121,7 @@ export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: case 'system': return handleSystemMessage(sdkMessage) case 'result': - return handleResultMessage(sdkMessage) + return handleResultMessage(sdkMessage, state) default: logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type }) return [] @@ -193,6 +193,30 @@ function handleAssistantMessage( } break } + case 'thinking': + case 'redacted_thinking': { + const thinkingText = block.type === 'thinking' ? block.thinking : block.data + if (thinkingText) { + const id = generateMessageId() + chunks.push({ + type: 'reasoning-start', + id, + providerMetadata + }) + chunks.push({ + type: 'reasoning-delta', + id, + text: thinkingText, + providerMetadata + }) + chunks.push({ + type: 'reasoning-end', + id, + providerMetadata + }) + } + break + } case 'tool_use': handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks) break @@ -445,7 +469,11 @@ function handleStreamEvent( case 'content_block_stop': { const block = state.closeBlock(event.index) if (!block) { - logger.warn('Received content_block_stop for unknown index', { index: event.index }) + // Some providers (e.g., Gemini) send content via assistant message before stream events, + // so the block may not exist in state. This is expected behavior, not an error. + logger.debug('Received content_block_stop for unknown index (may be from non-streaming content)', { + index: event.index + }) break } @@ -679,7 +707,13 @@ function handleSystemMessage(message: Extract): * Successful runs yield a `finish` frame with aggregated usage metrics, while * failures are surfaced as `error` frames. */ -function handleResultMessage(message: Extract): AgentStreamPart[] { +function handleResultMessage( + message: Extract, + state: ClaudeStreamState +): AgentStreamPart[] { + // Mark stream as finished to prevent duplicate error events when SDK process exits + state.markFinished() + const chunks: AgentStreamPart[] = [] let usage: LanguageModelUsage | undefined @@ -691,26 +725,33 @@ function handleResultMessage(message: Extract): } } - if (message.subtype === 'success') { - chunks.push({ - type: 'finish', - totalUsage: usage ?? emptyUsage, - finishReason: mapClaudeCodeFinishReason(message.subtype), - providerMetadata: { - ...sdkMessageToProviderMetadata(message), - usage: message.usage, - durationMs: message.duration_ms, - costUsd: message.total_cost_usd, - raw: message - } - } as AgentStreamPart) - } else { + chunks.push({ + type: 'finish', + totalUsage: usage ?? emptyUsage, + finishReason: mapClaudeCodeFinishReason(message.subtype), + providerMetadata: { + ...sdkMessageToProviderMetadata(message), + usage: message.usage, + durationMs: message.duration_ms, + costUsd: message.total_cost_usd, + raw: message + } + } as AgentStreamPart) + if (message.subtype !== 'success') { chunks.push({ type: 'error', error: { message: `${message.subtype}: Process failed after ${message.num_turns} turns` } } as AgentStreamPart) + } else { + if (message.is_error) { + const errorMatch = message.result.match(/\{.*\}/) + if (errorMatch) { + const errorDetail = JSON.parse(errorMatch[0]) + chunks.push(errorDetail) + } + } } return chunks } diff --git a/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts index fb371d9ae..3eaf2f0fb 100644 --- a/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/gemini/VertexAPIClient.ts @@ -24,7 +24,7 @@ export class VertexAPIClient extends GeminiAPIClient { this.anthropicVertexClient = new AnthropicVertexClient(provider) // 如果传入的是普通 Provider,转换为 VertexProvider if (isVertexProvider(provider)) { - this.vertexProvider = provider + this.vertexProvider = provider as VertexProvider } else { this.vertexProvider = createVertexProvider(provider) } diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts index b2a796bd3..b1e2c0773 100644 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -5,6 +5,7 @@ import type { MCPTool } from '@renderer/types' import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types' import type { Chunk } from '@renderer/types/chunk' import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider' +import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/ai-sdk-middlewares' import type { LanguageModelMiddleware } from 'ai' import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' @@ -12,9 +13,7 @@ import { getAiSdkProviderId } from '../provider/factory' import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' import { noThinkMiddleware } from './noThinkMiddleware' import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware' -import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware' import { qwenThinkingMiddleware } from './qwenThinkingMiddleware' -import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware' const logger = loggerService.withContext('AiSdkMiddlewareBuilder') diff --git a/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts b/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts deleted file mode 100644 index 9ef3df61e..000000000 --- a/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts +++ /dev/null @@ -1,50 +0,0 @@ -import type { LanguageModelV2StreamPart } from '@ai-sdk/provider' -import type { LanguageModelMiddleware } from 'ai' - -/** - * https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude - * - * @returns LanguageModelMiddleware - a middleware filter redacted block - */ -export function openrouterReasoningMiddleware(): LanguageModelMiddleware { - const REDACTED_BLOCK = '[REDACTED]' - return { - middlewareVersion: 'v2', - wrapGenerate: async ({ doGenerate }) => { - const { content, ...rest } = await doGenerate() - const modifiedContent = content.map((part) => { - if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) { - return { - ...part, - text: part.text.replace(REDACTED_BLOCK, '') - } - } - return part - }) - return { content: modifiedContent, ...rest } - }, - wrapStream: async ({ doStream }) => { - const { stream, ...rest } = await doStream() - return { - stream: stream.pipeThrough( - new TransformStream({ - transform( - chunk: LanguageModelV2StreamPart, - controller: TransformStreamDefaultController - ) { - if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) { - controller.enqueue({ - ...chunk, - delta: chunk.delta.replace(REDACTED_BLOCK, '') - }) - } else { - controller.enqueue(chunk) - } - } - }) - ), - ...rest - } - } - } -} diff --git a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts b/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts deleted file mode 100644 index da318ea60..000000000 --- a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts +++ /dev/null @@ -1,36 +0,0 @@ -import type { LanguageModelMiddleware } from 'ai' - -/** - * skip Gemini Thought Signature Middleware - * 由于多模型客户端请求的复杂性(可以中途切换其他模型),这里选择通过中间件方式添加跳过所有 Gemini3 思考签名 - * Due to the complexity of multi-model client requests (which can switch to other models mid-process), - * it was decided to add a skip for all Gemini3 thinking signatures via middleware. - * @param aiSdkId AI SDK Provider ID - * @returns LanguageModelMiddleware - */ -export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware { - const MAGIC_STRING = 'skip_thought_signature_validator' - return { - middlewareVersion: 'v2', - - transformParams: async ({ params }) => { - const transformedParams = { ...params } - // Process messages in prompt - if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) { - transformedParams.prompt = transformedParams.prompt.map((message) => { - if (typeof message.content !== 'string') { - for (const part of message.content) { - const googleOptions = part?.providerOptions?.[aiSdkId] - if (googleOptions?.thoughtSignature) { - googleOptions.thoughtSignature = MAGIC_STRING - } - } - } - return message - }) - } - - return transformedParams - } - } -} diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts index b1d8e34fc..3d1bc2ac3 100644 --- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts +++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts @@ -37,7 +37,7 @@ vi.mock('@renderer/utils/api', () => ({ if (isSupportedAPIVersion === false) { return host // Return host as-is when isSupportedAPIVersion is false } - return `${host}/v1` // Default behavior when isSupportedAPIVersion is true + return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true }), routeToEndpoint: vi.fn((host) => ({ baseURL: host, @@ -46,6 +46,20 @@ vi.mock('@renderer/utils/api', () => ({ isWithTrailingSharp: vi.fn((host) => host?.endsWith('#') || false) })) +// Also mock @shared/utils/url since formatProviderApiHost uses it directly +vi.mock('@shared/utils/url', async (importOriginal) => { + const actual = (await importOriginal()) as any + return { + ...actual, + formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => { + if (isSupportedAPIVersion === false) { + return host || '' // Return host as-is when isSupportedAPIVersion is false + } + return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true + }) + } +}) + vi.mock('@renderer/utils/provider', async (importOriginal) => { const actual = (await importOriginal()) as any return { @@ -78,8 +92,8 @@ vi.mock('@renderer/services/AssistantService', () => ({ import { getProviderByModel } from '@renderer/services/AssistantService' import type { Model, Provider } from '@renderer/types' -import { formatApiHost } from '@renderer/utils/api' import { isAzureOpenAIProvider, isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' +import { formatApiHost } from '@shared/utils/url' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' import { getActualProvider, providerToAiSdkConfig } from '../providerConfig' @@ -96,6 +110,31 @@ const createWindowKeyv = () => { } } +/** + * 创建默认的 mock state,包含所有必需的字段 + */ +const createDefaultMockState = (overrides?: { + includeUsage?: boolean | undefined + copilotHeaders?: Record +}) => ({ + copilot: { defaultHeaders: overrides?.copilotHeaders ?? {} }, + settings: { + openAI: { + streamOptions: { + includeUsage: overrides?.includeUsage + } + } + }, + llm: { + settings: { + vertexai: { + projectId: '', + location: '' + } + } + } +}) + const createCopilotProvider = (): Provider => ({ id: 'copilot', type: 'openai', @@ -150,16 +189,7 @@ describe('Copilot responses routing', () => { ...(globalThis as any).window, keyv: createWindowKeyv() } - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState()) }) it('detects official GPT-5 Codex identifiers case-insensitively', () => { @@ -195,16 +225,7 @@ describe('CherryAI provider configuration', () => { ...(globalThis as any).window, keyv: createWindowKeyv() } - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState()) vi.clearAllMocks() }) @@ -276,16 +297,7 @@ describe('Perplexity provider configuration', () => { ...(globalThis as any).window, keyv: createWindowKeyv() } - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState()) vi.clearAllMocks() }) @@ -360,6 +372,7 @@ describe('Stream options includeUsage configuration', () => { ...(globalThis as any).window, keyv: createWindowKeyv() } + mockGetState.mockReturnValue(createDefaultMockState()) vi.clearAllMocks() }) @@ -374,16 +387,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings when undefined', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: undefined })) const provider = createOpenAIProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) @@ -392,16 +396,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings when set to true', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: true - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: true })) const provider = createOpenAIProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) @@ -410,16 +405,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings when set to false', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: false - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: false })) const provider = createOpenAIProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'openai')) @@ -428,16 +414,7 @@ describe('Stream options includeUsage configuration', () => { }) it('respects includeUsage setting for non-supporting providers', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: true - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: true })) const testProvider: Provider = { id: 'test', @@ -459,16 +436,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings for Copilot provider when set to false', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: false - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: false })) const provider = createCopilotProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) @@ -478,16 +446,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings for Copilot provider when set to true', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: true - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: true })) const provider = createCopilotProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) @@ -497,16 +456,7 @@ describe('Stream options includeUsage configuration', () => { }) it('uses includeUsage from settings for Copilot provider when undefined', () => { - mockGetState.mockReturnValue({ - copilot: { defaultHeaders: {} }, - settings: { - openAI: { - streamOptions: { - includeUsage: undefined - } - } - } - }) + mockGetState.mockReturnValue(createDefaultMockState({ includeUsage: undefined })) const provider = createCopilotProvider() const config = providerToAiSdkConfig(provider, createModel('gpt-4', 'GPT-4', 'copilot')) diff --git a/src/renderer/src/aiCore/provider/config/azure-anthropic.ts b/src/renderer/src/aiCore/provider/config/azure-anthropic.ts deleted file mode 100644 index c6cb52138..000000000 --- a/src/renderer/src/aiCore/provider/config/azure-anthropic.ts +++ /dev/null @@ -1,22 +0,0 @@ -import type { Provider } from '@renderer/types' - -import { provider2Provider, startsWith } from './helper' -import type { RuleSet } from './types' - -// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry -const AZURE_ANTHROPIC_RULES: RuleSet = { - rules: [ - { - match: startsWith('claude'), - provider: (provider: Provider) => ({ - ...provider, - type: 'anthropic', - apiHost: provider.apiHost + 'anthropic/v1', - id: 'azure-anthropic' - }) - } - ], - fallbackRule: (provider: Provider) => provider -} - -export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES) diff --git a/src/renderer/src/aiCore/provider/config/helper.ts b/src/renderer/src/aiCore/provider/config/helper.ts deleted file mode 100644 index 656911fc7..000000000 --- a/src/renderer/src/aiCore/provider/config/helper.ts +++ /dev/null @@ -1,22 +0,0 @@ -import type { Model, Provider } from '@renderer/types' - -import type { RuleSet } from './types' - -export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase()) -export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type - -/** - * 解析模型对应的Provider - * @param ruleSet 规则集对象 - * @param model 模型对象 - * @param provider 原始provider对象 - * @returns 解析出的provider对象 - */ -export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider { - for (const rule of ruleSet.rules) { - if (rule.match(model)) { - return rule.provider(provider) - } - } - return ruleSet.fallbackRule(provider) -} diff --git a/src/renderer/src/aiCore/provider/config/index.ts b/src/renderer/src/aiCore/provider/config/index.ts index 2f51234ce..b1d57d5a1 100644 --- a/src/renderer/src/aiCore/provider/config/index.ts +++ b/src/renderer/src/aiCore/provider/config/index.ts @@ -1,3 +1,7 @@ -export { aihubmixProviderCreator } from './aihubmix' -export { newApiResolverCreator } from './newApi' -export { vertexAnthropicProviderCreator } from './vertext-anthropic' +// Re-export from shared config +export { + aihubmixProviderCreator, + azureAnthropicProviderCreator, + newApiResolverCreator, + vertexAnthropicProviderCreator +} from '@shared/provider/config' diff --git a/src/renderer/src/aiCore/provider/config/types.ts b/src/renderer/src/aiCore/provider/config/types.ts deleted file mode 100644 index f3938b84d..000000000 --- a/src/renderer/src/aiCore/provider/config/types.ts +++ /dev/null @@ -1,9 +0,0 @@ -import type { Model, Provider } from '@renderer/types' - -export interface RuleSet { - rules: Array<{ - match: (model: Model) => boolean - provider: (provider: Provider) => Provider - }> - fallbackRule: (provider: Provider) => Provider -} diff --git a/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts b/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts deleted file mode 100644 index 23c8b5185..000000000 --- a/src/renderer/src/aiCore/provider/config/vertext-anthropic.ts +++ /dev/null @@ -1,19 +0,0 @@ -import type { Provider } from '@renderer/types' - -import { provider2Provider, startsWith } from './helper' -import type { RuleSet } from './types' - -const VERTEX_ANTHROPIC_RULES: RuleSet = { - rules: [ - { - match: startsWith('claude'), - provider: (provider: Provider) => ({ - ...provider, - id: 'google-vertex-anthropic' - }) - } - ], - fallbackRule: (provider: Provider) => provider -} - -export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES) diff --git a/src/renderer/src/aiCore/provider/constants.ts b/src/renderer/src/aiCore/provider/constants.ts index c7cd90bd9..57dad9fbc 100644 --- a/src/renderer/src/aiCore/provider/constants.ts +++ b/src/renderer/src/aiCore/provider/constants.ts @@ -1,25 +1 @@ -import type { Model } from '@renderer/types' - -export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1' -export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7' -export const COPILOT_INTEGRATION_ID = 'vscode-chat' -export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7' - -export const COPILOT_DEFAULT_HEADERS = { - 'Copilot-Integration-Id': COPILOT_INTEGRATION_ID, - 'User-Agent': COPILOT_USER_AGENT, - 'Editor-Version': COPILOT_EDITOR_VERSION, - 'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION, - 'editor-version': COPILOT_EDITOR_VERSION, - 'editor-plugin-version': COPILOT_PLUGIN_VERSION, - 'copilot-vision-request': 'true' -} as const - -// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560) -const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex'] - -export function isCopilotResponsesModel(model: Model): boolean { - const normalizedId = model.id?.trim().toLowerCase() - const normalizedName = model.name?.trim().toLowerCase() - return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target || normalizedName === target) -} +export { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '@shared/provider/constant' diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts index ff100051b..97ab29db8 100644 --- a/src/renderer/src/aiCore/provider/factory.ts +++ b/src/renderer/src/aiCore/provider/factory.ts @@ -1,8 +1,7 @@ -import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider' import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' import type { Provider } from '@renderer/types' -import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider' +import { getAiSdkProviderId as sharedGetAiSdkProviderId } from '@shared/provider' import type { Provider as AiSdkProvider } from 'ai' import type { AiSdkConfig } from '../types' @@ -22,69 +21,12 @@ const logger = loggerService.withContext('ProviderFactory') } })() -/** - * 静态Provider映射表 - * 处理Cherry Studio特有的provider ID到AI SDK标准ID的映射 - */ -const STATIC_PROVIDER_MAPPING: Record = { - gemini: 'google', // Google Gemini -> google - 'azure-openai': 'azure', // Azure OpenAI -> azure - 'openai-response': 'openai', // OpenAI Responses -> openai - grok: 'xai', // Grok -> xai - copilot: 'github-copilot-openai-compatible' -} - -/** - * 尝试解析provider标识符(支持静态映射和别名) - */ -function tryResolveProviderId(identifier: string): ProviderId | null { - // 1. 检查静态映射 - const staticMapping = STATIC_PROVIDER_MAPPING[identifier] - if (staticMapping) { - return staticMapping - } - - // 2. 检查AiCore是否支持(包括别名支持) - if (hasProviderConfigByAlias(identifier)) { - // 解析为真实的Provider ID - return resolveProviderConfigId(identifier) as ProviderId - } - - return null -} - /** * 获取AI SDK Provider ID - * 简化版:减少重复逻辑,利用通用解析函数 - * TODO: 整理函数逻辑 + * Uses shared implementation with renderer-specific config checker */ export function getAiSdkProviderId(provider: Provider): string { - // 1. 尝试解析provider.id - const resolvedFromId = tryResolveProviderId(provider.id) - if (isAzureOpenAIProvider(provider)) { - if (isAzureResponsesEndpoint(provider)) { - return 'azure-responses' - } else { - return 'azure' - } - } - if (resolvedFromId) { - return resolvedFromId - } - - // 2. 尝试解析provider.type - // 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上 - if (provider.type !== 'openai') { - const resolvedFromType = tryResolveProviderId(provider.type) - if (resolvedFromType) { - return resolvedFromType - } - } - if (provider.apiHost.includes('api.openai.com')) { - return 'openai-chat' - } - // 3. 最后的fallback(使用provider本身的id) - return provider.id + return sharedGetAiSdkProviderId(provider) } export async function createAiSdkProvider(config: AiSdkConfig): Promise { diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index 0ad15ea89..e4724997e 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -1,4 +1,4 @@ -import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider' +import { hasProviderConfig } from '@cherrystudio/ai-core/provider' import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { getAwsBedrockAccessKeyId, @@ -9,58 +9,65 @@ import { } from '@renderer/hooks/useAwsBedrock' import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI' import { getProviderByModel } from '@renderer/services/AssistantService' -import { getProviderById } from '@renderer/services/ProviderService' import store from '@renderer/store' -import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types' -import type { OpenAICompletionsStreamOptions } from '@renderer/types/aiCoreTypes' +import { isSystemProvider, type Model, type Provider } from '@renderer/types' +import { isSupportStreamOptionsProvider } from '@renderer/utils/provider' import { - formatApiHost, - formatAzureOpenAIApiHost, - formatOllamaApiHost, - formatVertexApiHost, - isWithTrailingSharp, - routeToEndpoint -} from '@renderer/utils/api' -import { - isAnthropicProvider, - isAzureOpenAIProvider, - isCherryAIProvider, - isGeminiProvider, - isNewApiProvider, - isOllamaProvider, - isPerplexityProvider, - isSupportStreamOptionsProvider, - isVertexProvider -} from '@renderer/utils/provider' -import { defaultAppHeaders } from '@shared/utils' -import { cloneDeep, isEmpty } from 'lodash' + type AiSdkConfigContext, + formatProviderApiHost as sharedFormatProviderApiHost, + type ProviderFormatContext, + providerToAiSdkConfig as sharedProviderToAiSdkConfig, + resolveActualProvider +} from '@shared/provider' +import { cloneDeep } from 'lodash' import type { AiSdkConfig } from '../types' -import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config' -import { azureAnthropicProviderCreator } from './config/azure-anthropic' import { COPILOT_DEFAULT_HEADERS } from './constants' import { getAiSdkProviderId } from './factory' /** - * 处理特殊provider的转换逻辑 + * Renderer-specific context for providerToAiSdkConfig + * Provides implementations using browser APIs, store, and hooks */ -function handleSpecialProviders(model: Model, provider: Provider): Provider { - if (isNewApiProvider(provider)) { - return newApiResolverCreator(model, provider) +function createRendererSdkContext(model: Model): AiSdkConfigContext { + return { + isOpenAIChatCompletionOnlyModel: () => isOpenAIChatCompletionOnlyModel(model), + isSupportStreamOptionsProvider: (provider) => isSupportStreamOptionsProvider(provider as Provider), + getIncludeUsageSetting: () => store.getState().settings.openAI?.streamOptions?.includeUsage, + getCopilotDefaultHeaders: () => COPILOT_DEFAULT_HEADERS, + getCopilotStoredHeaders: () => store.getState().copilot.defaultHeaders ?? {}, + getAwsBedrockConfig: () => { + const authType = getAwsBedrockAuthType() + return { + authType, + region: getAwsBedrockRegion(), + apiKey: authType === 'apiKey' ? getAwsBedrockApiKey() : undefined, + accessKeyId: authType === 'iam' ? getAwsBedrockAccessKeyId() : undefined, + secretAccessKey: authType === 'iam' ? getAwsBedrockSecretAccessKey() : undefined + } + }, + getVertexConfig: (provider) => { + if (!isVertexAIConfigured()) { + return undefined + } + return createVertexProvider(provider as Provider) + }, + getEndpointType: () => model.endpoint_type } +} - if (isSystemProvider(provider)) { - if (provider.id === 'aihubmix') { - return aihubmixProviderCreator(model, provider) - } - if (provider.id === 'vertexai') { - return vertexAnthropicProviderCreator(model, provider) +/** + * 主要用来对齐AISdk的BaseURL格式 + * Uses shared implementation with renderer-specific context + */ +function getRendererFormatContext(): ProviderFormatContext { + const vertexSettings = store.getState().llm.settings.vertexai + return { + vertex: { + project: vertexSettings.projectId || 'default-project', + location: vertexSettings.location || 'us-central1' } } - if (isAzureOpenAIProvider(provider)) { - return azureAnthropicProviderCreator(model, provider) - } - return provider } /** @@ -70,38 +77,8 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider { * @param provider - The provider whose API host is to be formatted. * @returns A new provider instance with the formatted API host. */ -export function formatProviderApiHost(provider: Provider): Provider { - const formatted = { ...provider } - const appendApiVersion = !isWithTrailingSharp(provider.apiHost) - if (formatted.anthropicApiHost) { - formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion) - } - - if (isAnthropicProvider(provider)) { - const baseHost = formatted.anthropicApiHost || formatted.apiHost - // AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient - formatted.apiHost = formatApiHost(baseHost, appendApiVersion) - if (!formatted.anthropicApiHost) { - formatted.anthropicApiHost = formatted.apiHost - } - } else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) { - formatted.apiHost = formatApiHost(formatted.apiHost, false) - } else if (isOllamaProvider(formatted)) { - formatted.apiHost = formatOllamaApiHost(formatted.apiHost) - } else if (isGeminiProvider(formatted)) { - formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion, 'v1beta') - } else if (isAzureOpenAIProvider(formatted)) { - formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost) - } else if (isVertexProvider(formatted)) { - formatted.apiHost = formatVertexApiHost(formatted) - } else if (isCherryAIProvider(formatted)) { - formatted.apiHost = formatApiHost(formatted.apiHost, false) - } else if (isPerplexityProvider(formatted)) { - formatted.apiHost = formatApiHost(formatted.apiHost, false) - } else { - formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion) - } - return formatted +function formatProviderApiHost(provider: Provider): Provider { + return sharedFormatProviderApiHost(provider, getRendererFormatContext()) } /** @@ -132,7 +109,9 @@ export function adaptProvider({ provider, model }: { provider: Provider; model?: // Apply transformations in order if (model) { - adaptedProvider = handleSpecialProviders(model, adaptedProvider) + adaptedProvider = resolveActualProvider(adaptedProvider, model, { + isSystemProvider + }) } adaptedProvider = formatProviderApiHost(adaptedProvider) @@ -141,148 +120,11 @@ export function adaptProvider({ provider, model }: { provider: Provider; model?: /** * 将 Provider 配置转换为新 AI SDK 格式 - * 简化版:利用新的别名映射系统 + * Uses shared implementation with renderer-specific context */ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig { - const aiSdkProviderId = getAiSdkProviderId(actualProvider) - - // 构建基础配置 - const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost) - const baseConfig = { - baseURL: baseURL, - apiKey: actualProvider.apiKey - } - let includeUsage: OpenAICompletionsStreamOptions['include_usage'] = undefined - if (isSupportStreamOptionsProvider(actualProvider)) { - includeUsage = store.getState().settings.openAI?.streamOptions?.includeUsage - } - - const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot - if (isCopilotProvider) { - const storedHeaders = store.getState().copilot.defaultHeaders ?? {} - const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, { - headers: { - ...COPILOT_DEFAULT_HEADERS, - ...storedHeaders, - ...actualProvider.extra_headers - }, - name: actualProvider.id, - includeUsage - }) - - return { - providerId: 'github-copilot-openai-compatible', - options - } - } - - if (isOllamaProvider(actualProvider)) { - return { - providerId: 'ollama', - options: { - ...baseConfig, - headers: { - ...actualProvider.extra_headers, - Authorization: !isEmpty(baseConfig.apiKey) ? `Bearer ${baseConfig.apiKey}` : undefined - } - } - } - } - - // 处理OpenAI模式 - const extraOptions: any = {} - extraOptions.endpoint = endpoint - if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) { - extraOptions.mode = 'responses' - } else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) { - extraOptions.mode = 'chat' - } - - extraOptions.headers = { - ...defaultAppHeaders(), - ...actualProvider.extra_headers - } - - if (aiSdkProviderId === 'openai') { - extraOptions.headers['X-Api-Key'] = baseConfig.apiKey - } - // azure - // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest - // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api - if (aiSdkProviderId === 'azure-responses') { - extraOptions.mode = 'responses' - } else if (aiSdkProviderId === 'azure') { - extraOptions.mode = 'chat' - } - if (isAzureOpenAIProvider(actualProvider)) { - const apiVersion = actualProvider.apiVersion?.trim() - if (apiVersion) { - extraOptions.apiVersion = apiVersion - if (!['preview', 'v1'].includes(apiVersion)) { - extraOptions.useDeploymentBasedUrls = true - } - } - } - - // bedrock - if (aiSdkProviderId === 'bedrock') { - const authType = getAwsBedrockAuthType() - extraOptions.region = getAwsBedrockRegion() - - if (authType === 'apiKey') { - extraOptions.apiKey = getAwsBedrockApiKey() - } else { - extraOptions.accessKeyId = getAwsBedrockAccessKeyId() - extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey() - } - } - // google-vertex - if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') { - if (!isVertexAIConfigured()) { - throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') - } - const { project, location, googleCredentials } = createVertexProvider(actualProvider) - extraOptions.project = project - extraOptions.location = location - extraOptions.googleCredentials = { - ...googleCredentials, - privateKey: formatPrivateKey(googleCredentials.privateKey) - } - baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models' - } - - // cherryin - if (aiSdkProviderId === 'cherryin') { - if (model.endpoint_type) { - extraOptions.endpointType = model.endpoint_type - } - // CherryIN API Host - const cherryinProvider = getProviderById(SystemProviderIds.cherryin) - if (cherryinProvider) { - extraOptions.anthropicBaseURL = cherryinProvider.anthropicApiHost + '/v1' - extraOptions.geminiBaseURL = cherryinProvider.apiHost + '/v1beta/models' - } - } - - if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { - const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions) - return { - providerId: aiSdkProviderId, - options - } - } - - // 否则fallback到openai-compatible - const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey) - return { - providerId: 'openai-compatible', - options: { - ...options, - name: actualProvider.id, - ...extraOptions, - includeUsage - } - } + const context = createRendererSdkContext(model) + return sharedProviderToAiSdkConfig(actualProvider, model.id, context) as AiSdkConfig } /** @@ -325,13 +167,13 @@ export async function prepareSpecialProviderConfig( break } case 'cherryai': { - config.options.fetch = async (url, options) => { + config.options.fetch = async (url: RequestInfo | URL, options: RequestInit) => { // 在这里对最终参数进行签名 const signature = await window.api.cherryai.generateSignature({ method: 'POST', path: '/chat/completions', query: '', - body: JSON.parse(options.body) + body: JSON.parse(options.body as string) }) return fetch(url, { ...options, diff --git a/src/renderer/src/aiCore/provider/providerInitialization.ts b/src/renderer/src/aiCore/provider/providerInitialization.ts index 51176c1e6..5254e7885 100644 --- a/src/renderer/src/aiCore/provider/providerInitialization.ts +++ b/src/renderer/src/aiCore/provider/providerInitialization.ts @@ -1,124 +1,13 @@ -import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' -import * as z from 'zod' +import { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from '@shared/provider' const logger = loggerService.withContext('ProviderConfigs') -/** - * 新Provider配置定义 - * 定义了需要动态注册的AI Providers - */ -export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [ - { - id: 'openrouter', - name: 'OpenRouter', - import: () => import('@openrouter/ai-sdk-provider'), - creatorFunctionName: 'createOpenRouter', - supportsImageGeneration: true, - aliases: ['openrouter'] - }, - { - id: 'google-vertex', - name: 'Google Vertex AI', - import: () => import('@ai-sdk/google-vertex/edge'), - creatorFunctionName: 'createVertex', - supportsImageGeneration: true, - aliases: ['vertexai'] - }, - { - id: 'google-vertex-anthropic', - name: 'Google Vertex AI Anthropic', - import: () => import('@ai-sdk/google-vertex/anthropic/edge'), - creatorFunctionName: 'createVertexAnthropic', - supportsImageGeneration: true, - aliases: ['vertexai-anthropic'] - }, - { - id: 'azure-anthropic', - name: 'Azure AI Anthropic', - import: () => import('@ai-sdk/anthropic'), - creatorFunctionName: 'createAnthropic', - supportsImageGeneration: false, - aliases: ['azure-anthropic'] - }, - { - id: 'github-copilot-openai-compatible', - name: 'GitHub Copilot OpenAI Compatible', - import: () => import('@opeoginni/github-copilot-openai-compatible'), - creatorFunctionName: 'createGitHubCopilotOpenAICompatible', - supportsImageGeneration: false, - aliases: ['copilot', 'github-copilot'] - }, - { - id: 'bedrock', - name: 'Amazon Bedrock', - import: () => import('@ai-sdk/amazon-bedrock'), - creatorFunctionName: 'createAmazonBedrock', - supportsImageGeneration: true, - aliases: ['aws-bedrock'] - }, - { - id: 'perplexity', - name: 'Perplexity', - import: () => import('@ai-sdk/perplexity'), - creatorFunctionName: 'createPerplexity', - supportsImageGeneration: false, - aliases: ['perplexity'] - }, - { - id: 'mistral', - name: 'Mistral', - import: () => import('@ai-sdk/mistral'), - creatorFunctionName: 'createMistral', - supportsImageGeneration: false, - aliases: ['mistral'] - }, - { - id: 'huggingface', - name: 'HuggingFace', - import: () => import('@ai-sdk/huggingface'), - creatorFunctionName: 'createHuggingFace', - supportsImageGeneration: true, - aliases: ['hf', 'hugging-face'] - }, - { - id: 'gateway', - name: 'Vercel AI Gateway', - import: () => import('@ai-sdk/gateway'), - creatorFunctionName: 'createGateway', - supportsImageGeneration: true, - aliases: ['ai-gateway'] - }, - { - id: 'cerebras', - name: 'Cerebras', - import: () => import('@ai-sdk/cerebras'), - creatorFunctionName: 'createCerebras', - supportsImageGeneration: false - }, - { - id: 'ollama', - name: 'Ollama', - import: () => import('ollama-ai-provider-v2'), - creatorFunctionName: 'createOllama', - supportsImageGeneration: false - } -] as const +export const NEW_PROVIDER_CONFIGS = SHARED_PROVIDER_CONFIGS -export const registeredNewProviderIds = NEW_PROVIDER_CONFIGS.map((config) => config.id) -export const registeredNewProviderIdSchema = z.enum(registeredNewProviderIds) - -/** - * 初始化新的Providers - * 使用aiCore的动态注册功能 - */ export async function initializeNewProviders(): Promise { - try { - const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS) - if (successCount < NEW_PROVIDER_CONFIGS.length) { - logger.warn('Some providers failed to register. Check previous error logs.') - } - } catch (error) { - logger.error('Failed to initialize new providers:', error as Error) - } + initializeSharedProviders({ + warn: (message) => logger.warn(message), + error: (message, error) => logger.error(message, error) + }) } diff --git a/src/renderer/src/config/models/__tests__/tooluse.test.ts b/src/renderer/src/config/models/__tests__/tooluse.test.ts index 6ba3c7f1f..60d3657af 100644 --- a/src/renderer/src/config/models/__tests__/tooluse.test.ts +++ b/src/renderer/src/config/models/__tests__/tooluse.test.ts @@ -6,6 +6,29 @@ import { isDeepSeekHybridInferenceModel } from '../reasoning' import { isFunctionCallingModel } from '../tooluse' import { isPureGenerateImageModel, isTextToImageModel } from '../vision' +vi.mock('@renderer/i18n', () => ({ + __esModule: true, + default: { + t: vi.fn((key: string) => key) + } +})) + +vi.mock('@renderer/services/AssistantService', () => ({ + getProviderByModel: vi.fn().mockReturnValue({ + id: 'openai', + type: 'openai', + name: 'OpenAI', + models: [] + }), + getAssistantSettings: vi.fn(), + getDefaultAssistant: vi.fn().mockReturnValue({ + id: 'default', + name: 'Default Assistant', + prompt: '', + settings: {} + }) +})) + vi.mock('@renderer/hooks/useStore', () => ({ getStoreProviders: vi.fn(() => []) })) diff --git a/src/renderer/src/config/models/__tests__/utils.test.ts b/src/renderer/src/config/models/__tests__/utils.test.ts index 602b0737a..0b2294fd4 100644 --- a/src/renderer/src/config/models/__tests__/utils.test.ts +++ b/src/renderer/src/config/models/__tests__/utils.test.ts @@ -15,6 +15,7 @@ import { isSupportVerbosityModel } from '../openai' import { isQwenMTModel } from '../qwen' +import { isFunctionCallingModel } from '../tooluse' import { agentModelFilter, getModelSupportedVerbosity, @@ -71,6 +72,29 @@ vi.mock('@renderer/store/settings', () => { ) }) +vi.mock('@renderer/i18n', () => ({ + __esModule: true, + default: { + t: vi.fn((key: string) => key) + } +})) + +vi.mock('@renderer/services/AssistantService', () => ({ + getProviderByModel: vi.fn().mockReturnValue({ + id: 'openai', + type: 'openai', + name: 'OpenAI', + models: [] + }), + getAssistantSettings: vi.fn(), + getDefaultAssistant: vi.fn().mockReturnValue({ + id: 'default', + name: 'Default Assistant', + prompt: '', + settings: {} + }) +})) + vi.mock('@renderer/hooks/useSettings', () => ({ useSettings: vi.fn(() => ({})), useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })), @@ -101,6 +125,10 @@ vi.mock('../websearch', () => ({ isOpenAIWebSearchChatCompletionOnlyModel: vi.fn() })) +vi.mock('../tooluse', () => ({ + isFunctionCallingModel: vi.fn() +})) + const createModel = (overrides: Partial = {}): Model => ({ id: 'gpt-4o', name: 'gpt-4o', @@ -116,6 +144,7 @@ const textToImageMock = vi.mocked(isTextToImageModel) const generateImageMock = vi.mocked(isGenerateImageModel) const reasoningMock = vi.mocked(isOpenAIReasoningModel) const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel) +const isFunctionCallingModelMock = vi.mocked(isFunctionCallingModel) describe('model utils', () => { beforeEach(() => { @@ -124,9 +153,10 @@ describe('model utils', () => { rerankMock.mockReturnValue(false) visionMock.mockReturnValue(true) textToImageMock.mockReturnValue(false) - generateImageMock.mockReturnValue(true) + generateImageMock.mockReturnValue(false) reasoningMock.mockReturnValue(false) openAIWebSearchOnlyMock.mockReturnValue(false) + isFunctionCallingModelMock.mockReturnValue(true) }) describe('OpenAI model detection', () => { @@ -598,6 +628,7 @@ describe('model utils', () => { describe('isGenerateImageModels', () => { it('returns true when all models support image generation', () => { const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })] + generateImageMock.mockReturnValue(true) expect(isGenerateImageModels(models)).toBe(true) }) @@ -636,12 +667,22 @@ describe('model utils', () => { expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false) }) + it('filters out non-function-call models', () => { + rerankMock.mockReturnValue(false) + isFunctionCallingModelMock.mockReturnValueOnce(false) + expect(agentModelFilter(createModel({ id: 'DeepSeek R1' }))).toBe(false) + }) + it('filters out text-to-image models', () => { rerankMock.mockReturnValue(false) textToImageMock.mockReturnValueOnce(true) expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false) }) }) + + textToImageMock.mockReturnValue(false) + generateImageMock.mockReturnValueOnce(true) + expect(agentModelFilter(createModel({ id: 'dall-e-3' }))).toBe(false) }) describe('Temperature limits', () => { diff --git a/src/renderer/src/config/models/tooluse.ts b/src/renderer/src/config/models/tooluse.ts index 54d371dfd..4eaf66752 100644 --- a/src/renderer/src/config/models/tooluse.ts +++ b/src/renderer/src/config/models/tooluse.ts @@ -1,6 +1,8 @@ +import { getProviderByModel } from '@renderer/services/AssistantService' import type { Model } from '@renderer/types' import { isSystemProviderId } from '@renderer/types' import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils' +import { isAzureOpenAIProvider } from '@shared/provider' import { isEmbeddingModel, isRerankModel } from './embedding' import { isDeepSeekHybridInferenceModel } from './reasoning' @@ -55,6 +57,13 @@ export const FUNCTION_CALLING_REGEX = new RegExp( 'i' ) +const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [ + '(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+', + 'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?', + 'DeepSeek-(?:R1|V3)', + 'Codestral-2501' +] + export function isFunctionCallingModel(model?: Model): boolean { if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) { return false @@ -70,6 +79,15 @@ export function isFunctionCallingModel(model?: Model): boolean { return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name) } + const provider = getProviderByModel(model) + + if (isAzureOpenAIProvider(provider)) { + const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i') + if (azureExcludedRegex.test(modelId)) { + return false + } + } + // 2025/08/26 百炼与火山引擎均不支持 v3.1 函数调用 // 先默认支持 if (isDeepSeekHybridInferenceModel(model)) { diff --git a/src/renderer/src/config/models/utils.ts b/src/renderer/src/config/models/utils.ts index 12e85326c..0a60244e1 100644 --- a/src/renderer/src/config/models/utils.ts +++ b/src/renderer/src/config/models/utils.ts @@ -1,5 +1,6 @@ import type OpenAI from '@cherrystudio/openai' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding' +import { getProviderByModel } from '@renderer/services/AssistantService' import type { Assistant } from '@renderer/types' import { type Model, SystemProviderIds } from '@renderer/types' import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes' @@ -17,6 +18,7 @@ import { } from './openai' import { isQwenMTModel } from './qwen' import { isClaude45ReasoningModel } from './reasoning' +import { isFunctionCallingModel } from './tooluse' import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision' export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i') @@ -247,8 +249,21 @@ export const isGrokModel = (model: Model) => { // zhipu 视觉推理模型用这组 special token 标记推理结果 export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const +// TODO: 支持提示词模式的工具调用 export const agentModelFilter = (model: Model): boolean => { - return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model) + const provider = getProviderByModel(model) + + // 需要适配,且容易超出限额 + if (provider.id === SystemProviderIds.copilot) { + return false + } + return ( + !isEmbeddingModel(model) && + !isRerankModel(model) && + !isTextToImageModel(model) && + !isGenerateImageModel(model) && + isFunctionCallingModel(model) + ) } export const isMaxTemperatureOneModel = (model: Model): boolean => { diff --git a/src/renderer/src/pages/code/CodeToolsPage.tsx b/src/renderer/src/pages/code/CodeToolsPage.tsx index fcb2dbf48..a4314dfef 100644 --- a/src/renderer/src/pages/code/CodeToolsPage.tsx +++ b/src/renderer/src/pages/code/CodeToolsPage.tsx @@ -17,7 +17,7 @@ import type { EndpointType, Model } from '@renderer/types' import { getClaudeSupportedProviders } from '@renderer/utils/provider' import type { TerminalConfig } from '@shared/config/constant' import { codeTools, terminalApps } from '@shared/config/constant' -import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers' +import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers' import { Alert, Avatar, Button, Checkbox, Input, Popover, Select, Space, Tooltip } from 'antd' import { ArrowUpRight, Download, FolderOpen, HelpCircle, Terminal, X } from 'lucide-react' import type { FC } from 'react' @@ -82,10 +82,12 @@ const CodeToolsPage: FC = () => { if (m.supported_endpoint_types) { return m.supported_endpoint_types.includes('anthropic') } - // Special handling for silicon provider: only specific models support Anthropic API if (m.provider === 'silicon') { return isSiliconAnthropicCompatibleModel(m.id) } + if (m.provider === 'ppio') { + return isPpioAnthropicCompatibleModel(m.id) + } return m.id.includes('claude') || CLAUDE_OFFICIAL_SUPPORTED_PROVIDERS.includes(m.provider) } diff --git a/src/renderer/src/pages/home/Inputbar/AgentSessionInputbar.tsx b/src/renderer/src/pages/home/Inputbar/AgentSessionInputbar.tsx index 850be7f72..2aedab0b2 100644 --- a/src/renderer/src/pages/home/Inputbar/AgentSessionInputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/AgentSessionInputbar.tsx @@ -23,6 +23,7 @@ import { abortCompletion } from '@renderer/utils/abortController' import { buildAgentSessionTopicId } from '@renderer/utils/agentSession' import { getSendMessageShortcutLabel } from '@renderer/utils/input' import { createMainTextBlock, createMessage } from '@renderer/utils/messageUtils/create' +import { parseModelId } from '@renderer/utils/model' import { documentExts, imageExts, textExts } from '@shared/config/constant' import type { FC } from 'react' import React, { useCallback, useEffect, useMemo, useRef } from 'react' @@ -67,8 +68,9 @@ const AgentSessionInputbar: FC = ({ agentId, sessionId }) => { if (!session) return null // Extract model info - const [providerId, actualModelId] = session.model?.split(':') ?? [undefined, undefined] - const actualModel = actualModelId ? getModel(actualModelId, providerId) : undefined + // Use parseModelId to handle model IDs with colons (e.g., "openrouter:anthropic/claude:free") + const parsed = parseModelId(session.model) + const actualModel = parsed ? getModel(parsed.modelId, parsed.providerId) : undefined const model: Model | undefined = actualModel ? { diff --git a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx index 049c14c0d..e1e01883f 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx @@ -81,7 +81,8 @@ const ANTHROPIC_COMPATIBLE_PROVIDER_IDS = [ SystemProviderIds.silicon, SystemProviderIds.qiniu, SystemProviderIds.dmxapi, - SystemProviderIds.mimo + SystemProviderIds.mimo, + SystemProviderIds.ppio ] as const type AnthropicCompatibleProviderId = (typeof ANTHROPIC_COMPATIBLE_PROVIDER_IDS)[number] diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 5fe1bc090..a3dd1cbaa 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -2945,6 +2945,11 @@ const migrateConfig = { includeUsage: DEFAULT_STREAM_OPTIONS_INCLUDE_USAGE } } + state.llm.providers.forEach((provider) => { + if (provider.id === SystemProviderIds.ppio) { + provider.anthropicApiHost = 'https://api.ppinfra.com/anthropic' + } + }) logger.info('migrate 182 success') return state } catch (error) { diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index eefa380a6..5893a31ba 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -7,6 +7,7 @@ import type { CSSProperties } from 'react' export * from './file' export * from './note' +import type { MinimalModel } from '@shared/provider/types' import * as z from 'zod' import type { StreamTextParams } from './aiCoreTypes' @@ -274,7 +275,7 @@ export type ModelCapability = { isUserSelected?: boolean } -export type Model = { +export type Model = MinimalModel & { id: string provider: string name: string diff --git a/src/renderer/src/types/provider.ts b/src/renderer/src/types/provider.ts index edab3a730..573e1e100 100644 --- a/src/renderer/src/types/provider.ts +++ b/src/renderer/src/types/provider.ts @@ -1,25 +1,14 @@ import type OpenAI from '@cherrystudio/openai' +import type { MinimalProvider } from '@shared/provider' +import type { ProviderType, SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types' +import { isSystemProviderId, SystemProviderIds } from '@shared/provider/types' import type { Model } from '@types' -import * as z from 'zod' import type { OpenAIVerbosity } from './aiCoreTypes' -export const ProviderTypeSchema = z.enum([ - 'openai', - 'openai-response', - 'anthropic', - 'gemini', - 'azure-openai', - 'vertexai', - 'mistral', - 'aws-bedrock', - 'vertex-anthropic', - 'new-api', - 'gateway', - 'ollama' -]) - -export type ProviderType = z.infer +export type { ProviderType } from '@shared/provider' +export type { SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types' +export { isSystemProviderId, ProviderTypeSchema, SystemProviderIds } from '@shared/provider/types' // undefined is treated as supported, enabled by default export type ProviderApiOptions = { @@ -94,7 +83,7 @@ export function isAwsBedrockAuthType(type: string): type is AwsBedrockAuthType { return Object.hasOwn(AwsBedrockAuthTypes, type) } -export type Provider = { +export type Provider = MinimalProvider & { id: string type: ProviderType name: string @@ -129,142 +118,6 @@ export type Provider = { extra_headers?: Record } -export const SystemProviderIdSchema = z.enum([ - 'cherryin', - 'silicon', - 'aihubmix', - 'ocoolai', - 'deepseek', - 'ppio', - 'alayanew', - 'qiniu', - 'dmxapi', - 'burncloud', - 'tokenflux', - '302ai', - 'cephalon', - 'lanyun', - 'ph8', - 'openrouter', - 'ollama', - 'ovms', - 'new-api', - 'lmstudio', - 'anthropic', - 'openai', - 'azure-openai', - 'gemini', - 'vertexai', - 'github', - 'copilot', - 'zhipu', - 'yi', - 'moonshot', - 'baichuan', - 'dashscope', - 'stepfun', - 'doubao', - 'infini', - 'minimax', - 'groq', - 'together', - 'fireworks', - 'nvidia', - 'grok', - 'hyperbolic', - 'mistral', - 'jina', - 'perplexity', - 'modelscope', - 'xirang', - 'hunyuan', - 'tencent-cloud-ti', - 'baidu-cloud', - 'gpustack', - 'voyageai', - 'aws-bedrock', - 'poe', - 'aionly', - 'longcat', - 'huggingface', - 'sophnet', - 'gateway', - 'cerebras', - 'mimo' -]) - -export type SystemProviderId = z.infer - -export const isSystemProviderId = (id: string): id is SystemProviderId => { - return SystemProviderIdSchema.safeParse(id).success -} - -export const SystemProviderIds = { - cherryin: 'cherryin', - silicon: 'silicon', - aihubmix: 'aihubmix', - ocoolai: 'ocoolai', - deepseek: 'deepseek', - ppio: 'ppio', - alayanew: 'alayanew', - qiniu: 'qiniu', - dmxapi: 'dmxapi', - burncloud: 'burncloud', - tokenflux: 'tokenflux', - '302ai': '302ai', - cephalon: 'cephalon', - lanyun: 'lanyun', - ph8: 'ph8', - sophnet: 'sophnet', - openrouter: 'openrouter', - ollama: 'ollama', - ovms: 'ovms', - 'new-api': 'new-api', - lmstudio: 'lmstudio', - anthropic: 'anthropic', - openai: 'openai', - 'azure-openai': 'azure-openai', - gemini: 'gemini', - vertexai: 'vertexai', - github: 'github', - copilot: 'copilot', - zhipu: 'zhipu', - yi: 'yi', - moonshot: 'moonshot', - baichuan: 'baichuan', - dashscope: 'dashscope', - stepfun: 'stepfun', - doubao: 'doubao', - infini: 'infini', - minimax: 'minimax', - groq: 'groq', - together: 'together', - fireworks: 'fireworks', - nvidia: 'nvidia', - grok: 'grok', - hyperbolic: 'hyperbolic', - mistral: 'mistral', - jina: 'jina', - perplexity: 'perplexity', - modelscope: 'modelscope', - xirang: 'xirang', - hunyuan: 'hunyuan', - 'tencent-cloud-ti': 'tencent-cloud-ti', - 'baidu-cloud': 'baidu-cloud', - gpustack: 'gpustack', - voyageai: 'voyageai', - 'aws-bedrock': 'aws-bedrock', - poe: 'poe', - aionly: 'aionly', - longcat: 'longcat', - huggingface: 'huggingface', - gateway: 'gateway', - cerebras: 'cerebras', - mimo: 'mimo' -} as const satisfies Record - -type SystemProviderIdTypeMap = typeof SystemProviderIds - export type SystemProvider = Provider & { id: SystemProviderId isSystem: true diff --git a/src/renderer/src/utils/__tests__/api.test.ts b/src/renderer/src/utils/__tests__/api.test.ts index f5251b839..4705e0b7f 100644 --- a/src/renderer/src/utils/__tests__/api.test.ts +++ b/src/renderer/src/utils/__tests__/api.test.ts @@ -326,18 +326,7 @@ describe('api', () => { }) it('uses global endpoint when location equals global', () => { - getStateMock.mockReturnValueOnce({ - llm: { - settings: { - vertexai: { - projectId: 'global-project', - location: 'global' - } - } - } - }) - - expect(formatVertexApiHost(createVertexProvider(''))).toBe( + expect(formatVertexApiHost(createVertexProvider(''), 'global-project', 'global')).toBe( 'https://aiplatform.googleapis.com/v1/projects/global-project/locations/global' ) }) diff --git a/src/renderer/src/utils/__tests__/model.test.ts b/src/renderer/src/utils/__tests__/model.test.ts index fe1697e3e..d06a7a2f4 100644 --- a/src/renderer/src/utils/__tests__/model.test.ts +++ b/src/renderer/src/utils/__tests__/model.test.ts @@ -1,7 +1,7 @@ import type { Model, ModelTag } from '@renderer/types' import { describe, expect, it, vi } from 'vitest' -import { getModelTags, isFreeModel } from '../model' +import { getModelTags, isFreeModel, parseModelId } from '../model' // Mock the model checking functions from @renderer/config/models vi.mock('@renderer/config/models', () => ({ @@ -92,4 +92,85 @@ describe('model', () => { expect(getModelTags(models_2)).toStrictEqual(expected_2) }) }) + + describe('parseModelId', () => { + it('should parse model identifiers with single colon', () => { + expect(parseModelId('anthropic:claude-3-sonnet')).toEqual({ + providerId: 'anthropic', + modelId: 'claude-3-sonnet' + }) + + expect(parseModelId('openai:gpt-4')).toEqual({ + providerId: 'openai', + modelId: 'gpt-4' + }) + }) + + it('should parse model identifiers with multiple colons', () => { + expect(parseModelId('openrouter:anthropic/claude-3.5-sonnet:free')).toEqual({ + providerId: 'openrouter', + modelId: 'anthropic/claude-3.5-sonnet:free' + }) + + expect(parseModelId('provider:model:suffix:extra')).toEqual({ + providerId: 'provider', + modelId: 'model:suffix:extra' + }) + }) + + it('should handle model identifiers without provider prefix', () => { + expect(parseModelId('claude-3-sonnet')).toEqual({ + providerId: undefined, + modelId: 'claude-3-sonnet' + }) + + expect(parseModelId('gpt-4')).toEqual({ + providerId: undefined, + modelId: 'gpt-4' + }) + }) + + it('should return undefined for invalid inputs', () => { + expect(parseModelId(undefined)).toBeUndefined() + expect(parseModelId('')).toBeUndefined() + expect(parseModelId(' ')).toBeUndefined() + }) + + it('should handle edge cases with colons', () => { + // Colon at start - treat as modelId without provider + expect(parseModelId(':missing-provider')).toEqual({ + providerId: undefined, + modelId: ':missing-provider' + }) + + // Colon at end - treat everything before as modelId + expect(parseModelId('missing-model:')).toEqual({ + providerId: undefined, + modelId: 'missing-model' + }) + + // Only colon - treat as modelId without provider + expect(parseModelId(':')).toEqual({ + providerId: undefined, + modelId: ':' + }) + }) + + it('should handle edge cases', () => { + expect(parseModelId('a:b')).toEqual({ + providerId: 'a', + modelId: 'b' + }) + + expect(parseModelId('provider:model-with-dashes')).toEqual({ + providerId: 'provider', + modelId: 'model-with-dashes' + }) + + expect(parseModelId('provider:model/with/slashes')).toEqual({ + providerId: 'provider', + modelId: 'model/with/slashes' + }) + }) + }) }) diff --git a/src/renderer/src/utils/api.ts b/src/renderer/src/utils/api.ts index 25a73dcb1..216a8c51f 100644 --- a/src/renderer/src/utils/api.ts +++ b/src/renderer/src/utils/api.ts @@ -1,6 +1,20 @@ -import store from '@renderer/store' -import type { VertexProvider } from '@renderer/types' -import { trim } from 'lodash' +export { + formatApiHost, + formatAzureOpenAIApiHost, + formatOllamaApiHost, + formatVertexApiHost, + getAiSdkBaseUrl, + getTrailingApiVersion, + hasAPIVersion, + isWithTrailingSharp, + routeToEndpoint, + SUPPORTED_ENDPOINT_LIST, + SUPPORTED_IMAGE_ENDPOINT_LIST, + validateApiHost, + withoutTrailingApiVersion, + withoutTrailingSharp, + withoutTrailingSlash +} from '@shared/utils/url' /** * 格式化 API key 字符串。 @@ -12,228 +26,6 @@ export function formatApiKeys(value: string): string { return value.replaceAll(',', ',').replaceAll('\n', ',') } -/** - * Matches a version segment in a path that starts with `/v` and optionally - * continues with `alpha` or `beta`. The segment may be followed by `/` or the end - * of the string (useful for cases like `/v3alpha/resources`). - */ -const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)' - -/** - * Matches an API version at the end of a URL (with optional trailing slash). - * Used to detect and extract versions only from the trailing position. - */ -const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i - -/** - * 判断 host 的 path 中是否包含形如版本的字符串(例如 /v1、/v2beta 等), - * - * @param host - 要检查的 host 或 path 字符串 - * @returns 如果 path 中包含版本字符串则返回 true,否则 false - */ -export function hasAPIVersion(host?: string): boolean { - if (!host) return false - - const regex = new RegExp(VERSION_REGEX_PATTERN, 'i') - - try { - const url = new URL(host) - return regex.test(url.pathname) - } catch { - // 若无法作为完整 URL 解析,则当作路径直接检测 - return regex.test(host) - } -} - -/** - * Removes the trailing slash from a URL string if it exists. - * - * @template T - The string type to preserve type safety - * @param {T} url - The URL string to process - * @returns {T} The URL string without a trailing slash - * - * @example - * ```ts - * withoutTrailingSlash('https://example.com/') // 'https://example.com' - * withoutTrailingSlash('https://example.com') // 'https://example.com' - * ``` - */ -export function withoutTrailingSlash(url: T): T { - return url.replace(/\/$/, '') as T -} - -/** - * Checks if a URL string ends with a trailing '#' character. - * - * @template T - The string type to preserve type safety - * @param {T} url - The URL string to check - * @returns {boolean} True if the URL ends with '#', false otherwise - * - * @example - * ```ts - * isWithTrailingSharp('https://example.com#') // true - * isWithTrailingSharp('https://example.com') // false - * ``` - */ -export function isWithTrailingSharp(url: T): boolean { - return url.endsWith('#') -} - -/** - * Removes the trailing '#' from a URL string if it exists. - * - * @template T - The string type to preserve type safety - * @param {T} url - The URL string to process - * @returns {T} The URL string without a trailing '#' - * - * @example - * ```ts - * withoutTrailingSharp('https://example.com#') // 'https://example.com' - * withoutTrailingSharp('https://example.com') // 'https://example.com' - * ``` - */ -export function withoutTrailingSharp(url: T): T { - return url.replace(/#$/, '') as T -} - -/** - * Formats an API host URL by normalizing it and optionally appending an API version. - * - * @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed. - * @param supportApiVersion - Whether the API version is supported. Defaults to `true`. - * @param apiVersion - The API version to append if needed. Defaults to `'v1'`. - * - * @returns The formatted API host URL. If the host is empty after normalization, returns an empty string. - * If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host with trailing '#' removed. - * Otherwise, returns the host with the API version appended. - * - * @example - * formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1' - * formatApiHost('https://api.example.com#') // Returns 'https://api.example.com' - * formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2' - */ -export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string { - const normalizedHost = withoutTrailingSlash(trim(host)) - if (!normalizedHost) { - return '' - } - - const shouldAppendApiVersion = !(normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) - - if (shouldAppendApiVersion) { - return `${normalizedHost}/${apiVersion}` - } else { - return withoutTrailingSharp(normalizedHost) - } -} - -/** - * 格式化 Ollama 的 API 主机地址。 - */ -export function formatOllamaApiHost(host: string): string { - const normalizedHost = withoutTrailingSlash(host) - ?.replace(/\/v1$/, '') - ?.replace(/\/api$/, '') - ?.replace(/\/chat$/, '') - return formatApiHost(normalizedHost + '/api', false) -} - -/** - * 格式化 Azure OpenAI 的 API 主机地址。 - */ -export function formatAzureOpenAIApiHost(host: string): string { - const normalizedHost = withoutTrailingSlash(host) - ?.replace(/\/v1$/, '') - .replace(/\/openai$/, '') - // NOTE: AISDK会添加上`v1` - return formatApiHost(normalizedHost + '/openai', false) -} - -export function formatVertexApiHost(provider: VertexProvider): string { - const { apiHost } = provider - const { projectId: project, location } = store.getState().llm.settings.vertexai - const trimmedHost = withoutTrailingSlash(trim(apiHost)) - if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) { - const host = - location == 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com` - return `${formatApiHost(host)}/projects/${project}/locations/${location}` - } - return formatApiHost(trimmedHost) -} - -// 目前对话界面只支持这些端点 -export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const -export const SUPPORTED_ENDPOINT_LIST = [ - 'chat/completions', - 'responses', - 'messages', - 'generateContent', - 'streamGenerateContent', - ...SUPPORTED_IMAGE_ENDPOINT_LIST -] as const - -/** - * Converts an API host URL into separate base URL and endpoint components. - * - * @param apiHost - The API host string to parse. Expected to be a trimmed URL that may end with '#' followed by an endpoint identifier. - * @returns An object containing: - * - `baseURL`: The base URL without the endpoint suffix - * - `endpoint`: The matched endpoint identifier, or empty string if no match found - * - * @description - * This function extracts endpoint information from a composite API host string. - * If the host ends with '#', it attempts to match the preceding part against the supported endpoint list. - * The '#' delimiter is removed before processing. - * - * @example - * routeToEndpoint('https://api.example.com/openai/chat/completions#') - * // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' } - * - * @example - * routeToEndpoint('https://api.example.com/v1') - * // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' } - */ -export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } { - const trimmedHost = trim(apiHost) - // 前面已经确保apiHost合法 - if (!trimmedHost.endsWith('#')) { - return { baseURL: trimmedHost, endpoint: '' } - } - // 去掉结尾的 # - const host = trimmedHost.slice(0, -1) - const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint)) - if (!endpointMatch) { - const baseURL = withoutTrailingSlash(host) - return { baseURL, endpoint: '' } - } - const baseSegment = host.slice(0, host.length - endpointMatch.length) - const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // 去掉结尾可能存在的冒号(gemini的特殊情况) - return { baseURL, endpoint: endpointMatch } -} - -/** - * 验证 API 主机地址是否合法。 - * - * @param {string} apiHost - 需要验证的 API 主机地址。 - * @returns {boolean} 如果是合法的 URL 则返回 true,否则返回 false。 - */ -export function validateApiHost(apiHost: string): boolean { - // 允许apiHost为空 - if (!apiHost || !trim(apiHost)) { - return true - } - try { - const url = new URL(trim(apiHost)) - // 验证协议是否为 http 或 https - if (url.protocol !== 'http:' && url.protocol !== 'https:') { - return false - } - return true - } catch { - return false - } -} - /** * API key 脱敏函数。仅保留部分前后字符,中间用星号代替。 * @@ -272,50 +64,3 @@ export function splitApiKeyString(keyStr: string): string[] { .map((k) => k.replace(/\\,/g, ',')) .filter((k) => k) } - -/** - * Extracts the trailing API version segment from a URL path. - * - * This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL. - * Only versions at the end of the path are extracted, not versions in the middle. - * The returned version string does not include leading or trailing slashes. - * - * @param {string} url - The URL string to parse. - * @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found. - * - * @example - * getTrailingApiVersion('https://api.example.com/v1') // 'v1' - * getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta' - * getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end) - * getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta' - * getTrailingApiVersion('https://api.example.com') // undefined - */ -export function getTrailingApiVersion(url: string): string | undefined { - const match = url.match(TRAILING_VERSION_REGEX) - - if (match) { - // Extract version without leading slash and trailing slash - return match[0].replace(/^\//, '').replace(/\/$/, '') - } - - return undefined -} - -/** - * Removes the trailing API version segment from a URL path. - * - * This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL. - * Only versions at the end of the path are removed, not versions in the middle. - * - * @param {string} url - The URL string to process. - * @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found. - * - * @example - * withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com' - * withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com' - * withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change) - * withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com' - */ -export function withoutTrailingApiVersion(url: string): string { - return url.replace(TRAILING_VERSION_REGEX, '') -} diff --git a/src/renderer/src/utils/model.ts b/src/renderer/src/utils/model.ts index a74ffab25..c3efda2cb 100644 --- a/src/renderer/src/utils/model.ts +++ b/src/renderer/src/utils/model.ts @@ -81,3 +81,57 @@ export const apiModelAdapter = (model: ApiModel): AdaptedApiModel => { origin: model } } + +/** + * Parse a model identifier in the format "provider:modelId" + * where modelId may contain additional colons (e.g., "openrouter:anthropic/claude-3.5-sonnet:free") + * + * @param modelIdentifier - The full model identifier string + * @returns Object with providerId and modelId. If no provider prefix found, providerId will be undefined + * + * @example + * parseModelId("openrouter:anthropic/claude-3.5-sonnet:free") + * // => { providerId: "openrouter", modelId: "anthropic/claude-3.5-sonnet:free" } + * + * @example + * parseModelId("anthropic:claude-3-sonnet") + * // => { providerId: "anthropic", modelId: "claude-3-sonnet" } + * + * @example + * parseModelId("claude-3-sonnet") + * // => { providerId: undefined, modelId: "claude-3-sonnet" } + * + * @example + * parseModelId("") // => undefined + */ +export function parseModelId( + modelIdentifier: string | undefined +): { providerId: string | undefined; modelId: string } | undefined { + if (!modelIdentifier || typeof modelIdentifier !== 'string' || modelIdentifier.trim() === '') { + return undefined + } + + const colonIndex = modelIdentifier.indexOf(':') + + // No colon found or colon at the start - treat entire string as modelId + if (colonIndex <= 0) { + return { + providerId: undefined, + modelId: modelIdentifier + } + } + + // Colon at the end - treat everything before as modelId + if (colonIndex >= modelIdentifier.length - 1) { + return { + providerId: undefined, + modelId: modelIdentifier.substring(0, colonIndex) + } + } + + // Standard format: "provider:modelId" + return { + providerId: modelIdentifier.substring(0, colonIndex), + modelId: modelIdentifier.substring(colonIndex + 1) + } +} diff --git a/src/renderer/src/utils/naming.ts b/src/renderer/src/utils/naming.ts index 2ebd9c1fc..d258cee81 100644 --- a/src/renderer/src/utils/naming.ts +++ b/src/renderer/src/utils/naming.ts @@ -2,6 +2,8 @@ import { getProviderLabel } from '@renderer/i18n/label' import type { Provider } from '@renderer/types' import { isSystemProvider } from '@renderer/types' +export { getBaseModelName, getLowerBaseModelName } from '@shared/utils/naming' + /** * 从模型 ID 中提取默认组名。 * 规则如下: @@ -50,42 +52,6 @@ export const getDefaultGroupName = (id: string, provider?: string): string => { return str } -/** - * 从模型 ID 中提取基础名称。 - * 例如: - * - 'deepseek/deepseek-r1' => 'deepseek-r1' - * - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1' - * @param {string} id 模型 ID - * @param {string} [delimiter='/'] 分隔符,默认为 '/' - * @returns {string} 基础名称 - */ -export const getBaseModelName = (id: string, delimiter: string = '/'): string => { - const parts = id.split(delimiter) - return parts[parts.length - 1] -} - -/** - * 从模型 ID 中提取基础名称并转换为小写。 - * 例如: - * - 'deepseek/DeepSeek-R1' => 'deepseek-r1' - * - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1' - * @param {string} id 模型 ID - * @param {string} [delimiter='/'] 分隔符,默认为 '/' - * @returns {string} 小写的基础名称 - */ -export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => { - const baseModelName = getBaseModelName(id, delimiter).toLowerCase() - // for openrouter - if (baseModelName.endsWith(':free')) { - return baseModelName.replace(':free', '') - } - // for cherryin - if (baseModelName.endsWith('(free)')) { - return baseModelName.replace('(free)', '') - } - return baseModelName -} - /** * 获取模型服务商名称,根据是否内置服务商来决定要不要翻译 * @param provider 服务商 diff --git a/src/renderer/src/utils/provider.ts b/src/renderer/src/utils/provider.ts index 86544de99..b66d9098e 100644 --- a/src/renderer/src/utils/provider.ts +++ b/src/renderer/src/utils/provider.ts @@ -1,10 +1,21 @@ import { CLAUDE_SUPPORTED_PROVIDERS } from '@renderer/pages/code' -import type { AzureOpenAIProvider, ProviderType, VertexProvider } from '@renderer/types' +import type { ProviderType } from '@renderer/types' import { isSystemProvider, type Provider, type SystemProviderId, SystemProviderIds } from '@renderer/types' - -export const isAzureResponsesEndpoint = (provider: AzureOpenAIProvider) => { - return provider.apiVersion === 'preview' || provider.apiVersion === 'v1' -} +export { + isAIGatewayProvider, + isAnthropicProvider, + isAwsBedrockProvider, + isAzureOpenAIProvider, + isAzureResponsesEndpoint, + isCherryAIProvider, + isGeminiProvider, + isNewApiProvider, + isOllamaProvider, + isOpenAICompatibleProvider, + isOpenAIProvider, + isPerplexityProvider, + isVertexProvider +} from '@shared/provider' export const getClaudeSupportedProviders = (providers: Provider[]) => { return providers.filter( @@ -127,59 +138,6 @@ export const isGeminiWebSearchProvider = (provider: Provider) => { return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id) } -export const isNewApiProvider = (provider: Provider) => { - return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api' -} - -export function isCherryAIProvider(provider: Provider): boolean { - return provider.id === 'cherryai' -} - -export function isPerplexityProvider(provider: Provider): boolean { - return provider.id === 'perplexity' -} - -/** - * 判断是否为 OpenAI 兼容的提供商 - * @param {Provider} provider 提供商对象 - * @returns {boolean} 是否为 OpenAI 兼容提供商 - */ -export function isOpenAICompatibleProvider(provider: Provider): boolean { - return ['openai', 'new-api', 'mistral'].includes(provider.type) -} - -export function isAzureOpenAIProvider(provider: Provider): provider is AzureOpenAIProvider { - return provider.type === 'azure-openai' -} - -export function isOpenAIProvider(provider: Provider): boolean { - return provider.type === 'openai-response' -} - -export function isVertexProvider(provider: Provider): provider is VertexProvider { - return provider.type === 'vertexai' -} - -export function isAwsBedrockProvider(provider: Provider): boolean { - return provider.type === 'aws-bedrock' -} - -export function isAnthropicProvider(provider: Provider): boolean { - return provider.type === 'anthropic' -} - -export function isGeminiProvider(provider: Provider): boolean { - return provider.type === 'gemini' -} - -export function isAIGatewayProvider(provider: Provider): boolean { - return provider.type === 'gateway' -} - -export function isOllamaProvider(provider: Provider): boolean { - return provider.type === 'ollama' -} - const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[] export const isSupportAPIVersionProvider = (provider: Provider) => { diff --git a/tests/main.setup.ts b/tests/main.setup.ts index 5cadb89d0..237078156 100644 --- a/tests/main.setup.ts +++ b/tests/main.setup.ts @@ -61,7 +61,19 @@ vi.mock('electron', () => ({ getPrimaryDisplay: vi.fn(), getAllDisplays: vi.fn() }, - Notification: vi.fn() + Notification: vi.fn(), + net: { + fetch: vi.fn(() => + Promise.resolve({ + ok: true, + status: 200, + statusText: 'OK', + json: vi.fn(() => Promise.resolve({})), + text: vi.fn(() => Promise.resolve('')), + headers: new Headers() + }) + ) + } })) // Mock Winston for LoggerService dependencies @@ -97,15 +109,40 @@ vi.mock('winston-daily-rotate-file', () => { })) }) -// Mock Node.js modules -vi.mock('node:os', () => ({ - platform: vi.fn(() => 'darwin'), - arch: vi.fn(() => 'x64'), - version: vi.fn(() => '20.0.0'), - cpus: vi.fn(() => [{ model: 'Mock CPU' }]), - totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024) // 8GB +// Mock main process services +vi.mock('@main/services/AnthropicService', () => ({ + default: {} })) +vi.mock('@main/services/CopilotService', () => ({ + default: {} +})) + +vi.mock('@main/services/ReduxService', () => ({ + reduxService: { + selectSync: vi.fn() + } +})) + +vi.mock('@main/integration/cherryai', () => ({ + generateSignature: vi.fn() +})) + +// Mock Node.js modules +vi.mock('node:os', async () => { + const actual = await vi.importActual('node:os') + return { + ...actual, + default: actual, + platform: vi.fn(() => 'darwin'), + arch: vi.fn(() => 'x64'), + version: vi.fn(() => '20.0.0'), + cpus: vi.fn(() => [{ model: 'Mock CPU' }]), + totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024), // 8GB + homedir: vi.fn(() => '/tmp') + } +}) + vi.mock('node:path', async () => { const actual = await vi.importActual('node:path') return { diff --git a/tsconfig.node.json b/tsconfig.node.json index 6953fa7b3..61d1e404c 100644 --- a/tsconfig.node.json +++ b/tsconfig.node.json @@ -8,8 +8,10 @@ "src/preload/**/*", "src/renderer/src/services/traceApi.ts", "src/renderer/src/types/*", + "packages/aiCore/src/**/*", "packages/mcp-trace/**/*", "packages/shared/**/*", + "packages/ai-sdk-provider/**/*" ], "compilerOptions": { "composite": true, @@ -26,7 +28,12 @@ "@types": ["./src/renderer/src/types/index.ts"], "@shared/*": ["./packages/shared/*"], "@mcp-trace/*": ["./packages/mcp-trace/*"], - "@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"] + "@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"], + "@cherrystudio/ai-core/provider": ["./packages/aiCore/src/core/providers/index.ts"], + "@cherrystudio/ai-core/built-in/plugins": ["./packages/aiCore/src/core/plugins/built-in/index.ts"], + "@cherrystudio/ai-core/*": ["./packages/aiCore/src/*"], + "@cherrystudio/ai-core": ["./packages/aiCore/src/index.ts"], + "@cherrystudio/ai-sdk-provider": ["./packages/ai-sdk-provider/src/index.ts"] }, "experimentalDecorators": true, "emitDecoratorMetadata": true,