diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts index a50356130d..61e6f49b81 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts @@ -35,7 +35,6 @@ export interface WebSearchPluginConfig { anthropic?: AnthropicSearchConfig xai?: ProviderOptionsMap['xai']['searchParameters'] google?: GoogleSearchConfig - 'google-vertex'?: GoogleSearchConfig openrouter?: OpenRouterSearchConfig } @@ -44,7 +43,6 @@ export interface WebSearchPluginConfig { */ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = { google: {}, - 'google-vertex': {}, openai: {}, 'openai-chat': {}, xai: { @@ -97,55 +95,28 @@ export type WebSearchToolInputSchema = { 'openai-chat': InferToolInput } -export const switchWebSearchTool = (providerId: string, config: WebSearchPluginConfig, params: any) => { - switch (providerId) { - case 'openai': { - if (config.openai) { - if (!params.tools) params.tools = {} - params.tools.web_search = openai.tools.webSearch(config.openai) - } - break - } - case 'openai-chat': { - if (config['openai-chat']) { - if (!params.tools) params.tools = {} - params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat']) - } - break - } - - case 'anthropic': { - if (config.anthropic) { - if (!params.tools) params.tools = {} - params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic) - } - break - } - - case 'google': { - // case 'google-vertex': - if (!params.tools) params.tools = {} - params.tools.web_search = google.tools.googleSearch(config.google || {}) - break - } - - case 'xai': { - if (config.xai) { - const searchOptions = createXaiOptions({ - searchParameters: { ...config.xai, mode: 'on' } - }) - params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) - } - break - } - - case 'openrouter': { - if (config.openrouter) { - const searchOptions = createOpenRouterOptions(config.openrouter) - params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) - } - break - } +export const switchWebSearchTool = (config: WebSearchPluginConfig, params: any) => { + if (config.openai) { + if (!params.tools) params.tools = {} + params.tools.web_search = openai.tools.webSearch(config.openai) + } else if (config['openai-chat']) { + if (!params.tools) params.tools = {} + params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat']) + } else if (config.anthropic) { + if (!params.tools) params.tools = {} + params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic) + } else if (config.google) { + // case 'google-vertex': + if (!params.tools) params.tools = {} + params.tools.web_search = google.tools.googleSearch(config.google || {}) + } else if (config.xai) { + const searchOptions = createXaiOptions({ + searchParameters: { ...config.xai, mode: 'on' } + }) + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) + } else if (config.openrouter) { + const searchOptions = createOpenRouterOptions(config.openrouter) + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) } return params } diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts index 75692cdf36..a46df7dd4c 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts @@ -4,7 +4,6 @@ */ import { definePlugin } from '../../' -import type { AiRequestContext } from '../../types' import type { WebSearchPluginConfig } from './helper' import { DEFAULT_WEB_SEARCH_CONFIG, switchWebSearchTool } from './helper' @@ -18,15 +17,8 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR name: 'webSearch', enforce: 'pre', - transformParams: async (params: any, context: AiRequestContext) => { - const { providerId } = context - switchWebSearchTool(providerId, config, params) - - if (providerId === 'cherryin' || providerId === 'cherryin-chat') { - // cherryin.gemini - const _providerId = params.model.provider.split('.')[1] - switchWebSearchTool(_providerId, config, params) - } + transformParams: async (params: any) => { + switchWebSearchTool(config, params) return params } }) diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index ed92d4ddd8..8c031f7754 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -189,7 +189,7 @@ export default class ModernAiProvider { config: ModernAiProviderConfig ): Promise { // ai-gateway不是image/generation 端点,所以就先不走legacy了 - if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds['ai-gateway']) { + if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds.gateway) { // 使用 legacy 实现处理图像生成(支持图片编辑等高级功能) if (!config.uiMessages) { throw new Error('uiMessages is required for image generation endpoint') @@ -480,7 +480,7 @@ export default class ModernAiProvider { // 代理其他方法到原有实现 public async models() { - if (this.actualProvider.id === SystemProviderIds['ai-gateway']) { + if (this.actualProvider.id === SystemProviderIds.gateway) { const formatModel = function (models: GatewayLanguageModelEntry[]): Model[] { return models.map((m) => ({ id: m.id, diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts index 5749e6c7e6..cba7fcdb10 100644 --- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts +++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts @@ -11,11 +11,15 @@ import { vertex } from '@ai-sdk/google-vertex/edge' import { combineHeaders } from '@ai-sdk/provider-utils' import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins' import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas' +import type { BaseProviderId } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' import { isAnthropicModel, isFixedReasoningModel, + isGeminiModel, isGenerateImageModel, + isGrokModel, + isOpenAIModel, isOpenRouterBuiltInWebSearchModel, isSupportedReasoningEffortModel, isSupportedThinkingTokenModel, @@ -24,11 +28,12 @@ import { import { getDefaultModel } from '@renderer/services/AssistantService' import store from '@renderer/store' import type { CherryWebSearchConfig } from '@renderer/store/websearch' -import { type Assistant, type MCPTool, type Provider } from '@renderer/types' +import type { Model } from '@renderer/types' +import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types' import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern' import { replacePromptVariables } from '@renderer/utils/prompt' -import { isAwsBedrockProvider } from '@renderer/utils/provider' +import { isAIGatewayProvider, isAwsBedrockProvider } from '@renderer/utils/provider' import type { ModelMessage, Tool } from 'ai' import { stepCountIs } from 'ai' @@ -43,6 +48,25 @@ const logger = loggerService.withContext('parameterBuilder') type ProviderDefinedTool = Extract, { type: 'provider-defined' }> +function mapVertexAIGatewayModelToProviderId(model: Model): BaseProviderId | undefined { + if (isAnthropicModel(model)) { + return 'anthropic' + } + if (isGeminiModel(model)) { + return 'google' + } + if (isGrokModel(model)) { + return 'xai' + } + if (isOpenAIModel(model)) { + return 'openai' + } + logger.warn( + `[mapVertexAIGatewayModelToProviderId] Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.` + ) + return undefined +} + /** * 构建 AI SDK 流式参数 * 这是主要的参数构建函数,整合所有转换逻辑 @@ -117,6 +141,11 @@ export async function buildStreamTextParams( if (enableWebSearch) { if (isBaseProvider(aiSdkProviderId)) { webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model) + } else if (isAIGatewayProvider(provider) || SystemProviderIds.gateway === provider.id) { + const aiSdkProviderId = mapVertexAIGatewayModelToProviderId(model) + if (aiSdkProviderId) { + webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model) + } } if (!tools) { tools = {} diff --git a/src/renderer/src/aiCore/provider/providerInitialization.ts b/src/renderer/src/aiCore/provider/providerInitialization.ts index a42e2ac659..51176c1e60 100644 --- a/src/renderer/src/aiCore/provider/providerInitialization.ts +++ b/src/renderer/src/aiCore/provider/providerInitialization.ts @@ -1,5 +1,6 @@ import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' +import * as z from 'zod' const logger = loggerService.withContext('ProviderConfigs') @@ -81,12 +82,12 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [ aliases: ['hf', 'hugging-face'] }, { - id: 'ai-gateway', - name: 'AI Gateway', + id: 'gateway', + name: 'Vercel AI Gateway', import: () => import('@ai-sdk/gateway'), creatorFunctionName: 'createGateway', supportsImageGeneration: true, - aliases: ['gateway'] + aliases: ['ai-gateway'] }, { id: 'cerebras', @@ -104,6 +105,9 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [ } ] as const +export const registeredNewProviderIds = NEW_PROVIDER_CONFIGS.map((config) => config.id) +export const registeredNewProviderIdSchema = z.enum(registeredNewProviderIds) + /** * 初始化新的Providers * 使用aiCore的动态注册功能 diff --git a/src/renderer/src/aiCore/utils/__tests__/options.test.ts b/src/renderer/src/aiCore/utils/__tests__/options.test.ts index ca6b883d74..9eeeac725b 100644 --- a/src/renderer/src/aiCore/utils/__tests__/options.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/options.test.ts @@ -27,7 +27,8 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => { 'xai', 'deepseek', 'openrouter', - 'openai-compatible' + 'openai-compatible', + 'cherryin' ] if (baseProviders.includes(id)) { return { success: true, data: id } @@ -37,7 +38,15 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => { }, customProviderIdSchema: { safeParse: vi.fn((id) => { - const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock'] + const customProviders = [ + 'google-vertex', + 'google-vertex-anthropic', + 'bedrock', + 'gateway', + 'aihubmix', + 'newapi', + 'ollama' + ] if (customProviders.includes(id)) { return { success: true, data: id } } @@ -47,20 +56,7 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => { } }) -vi.mock('../provider/factory', () => ({ - getAiSdkProviderId: vi.fn((provider) => { - // Simulate the provider ID mapping - const mapping: Record = { - [SystemProviderIds.gemini]: 'google', - [SystemProviderIds.openai]: 'openai', - [SystemProviderIds.anthropic]: 'anthropic', - [SystemProviderIds.grok]: 'xai', - [SystemProviderIds.deepseek]: 'deepseek', - [SystemProviderIds.openrouter]: 'openrouter' - } - return mapping[provider.id] || provider.id - }) -})) +// Don't mock getAiSdkProviderId - use real implementation for more accurate tests vi.mock('@renderer/config/models', async (importOriginal) => ({ ...(await importOriginal()), @@ -179,8 +175,11 @@ describe('options utils', () => { provider: SystemProviderIds.openai } as Model - beforeEach(() => { + beforeEach(async () => { vi.clearAllMocks() + // Reset getCustomParameters to return empty object by default + const { getCustomParameters } = await import('../reasoning') + vi.mocked(getCustomParameters).mockReturnValue({}) }) describe('buildProviderOptions', () => { @@ -391,7 +390,6 @@ describe('options utils', () => { enableWebSearch: false, enableGenerateImage: false }) - expect(result.providerOptions).toHaveProperty('deepseek') expect(result.providerOptions.deepseek).toBeDefined() }) @@ -461,10 +459,14 @@ describe('options utils', () => { } ) - expect(result.providerOptions.openai).toHaveProperty('custom_param') - expect(result.providerOptions.openai.custom_param).toBe('custom_value') - expect(result.providerOptions.openai).toHaveProperty('another_param') - expect(result.providerOptions.openai.another_param).toBe(123) + expect(result.providerOptions).toStrictEqual({ + openai: { + custom_param: 'custom_value', + another_param: 123, + serviceTier: undefined, + textVerbosity: undefined + } + }) }) it('should extract AI SDK standard params from custom parameters', async () => { @@ -696,5 +698,459 @@ describe('options utils', () => { }) }) }) + + describe('AI Gateway provider', () => { + const gatewayProvider: Provider = { + id: SystemProviderIds.gateway, + name: 'Vercel AI Gateway', + type: 'gateway', + apiKey: 'test-key', + apiHost: 'https://gateway.vercel.com', + isSystem: true + } as Provider + + it('should build OpenAI options for OpenAI models through gateway', () => { + const openaiModel: Model = { + id: 'openai/gpt-4', + name: 'GPT-4', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, openaiModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('openai') + expect(result.providerOptions.openai).toBeDefined() + }) + + it('should build Anthropic options for Anthropic models through gateway', () => { + const anthropicModel: Model = { + id: 'anthropic/claude-3-5-sonnet-20241022', + name: 'Claude 3.5 Sonnet', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('anthropic') + expect(result.providerOptions.anthropic).toBeDefined() + }) + + it('should build Google options for Gemini models through gateway', () => { + const geminiModel: Model = { + id: 'google/gemini-2.0-flash-exp', + name: 'Gemini 2.0 Flash', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, geminiModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('google') + expect(result.providerOptions.google).toBeDefined() + }) + + it('should build xAI options for Grok models through gateway', () => { + const grokModel: Model = { + id: 'xai/grok-2-latest', + name: 'Grok 2', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, grokModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('xai') + expect(result.providerOptions.xai).toBeDefined() + }) + + it('should include reasoning parameters for Anthropic models when enabled', () => { + const anthropicModel: Model = { + id: 'anthropic/claude-3-5-sonnet-20241022', + name: 'Claude 3.5 Sonnet', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, { + enableReasoning: true, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions.anthropic).toHaveProperty('thinking') + expect(result.providerOptions.anthropic.thinking).toEqual({ + type: 'enabled', + budgetTokens: 5000 + }) + }) + + it('should merge gateway routing options from custom parameters', async () => { + const { getCustomParameters } = await import('../reasoning') + + vi.mocked(getCustomParameters).mockReturnValue({ + gateway: { + order: ['vertex', 'anthropic'], + only: ['vertex', 'anthropic'] + } + }) + + const anthropicModel: Model = { + id: 'anthropic/claude-3-5-sonnet-20241022', + name: 'Claude 3.5 Sonnet', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should have both anthropic provider options and gateway routing options + expect(result.providerOptions).toHaveProperty('anthropic') + expect(result.providerOptions).toHaveProperty('gateway') + expect(result.providerOptions.gateway).toEqual({ + order: ['vertex', 'anthropic'], + only: ['vertex', 'anthropic'] + }) + }) + + it('should combine provider-specific options with gateway routing options', async () => { + const { getCustomParameters } = await import('../reasoning') + + vi.mocked(getCustomParameters).mockReturnValue({ + gateway: { + order: ['openai', 'anthropic'] + } + }) + + const openaiModel: Model = { + id: 'openai/gpt-4', + name: 'GPT-4', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, openaiModel, gatewayProvider, { + enableReasoning: true, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should have OpenAI provider options with reasoning + expect(result.providerOptions.openai).toBeDefined() + expect(result.providerOptions.openai).toHaveProperty('reasoningEffort') + + // Should also have gateway routing options + expect(result.providerOptions.gateway).toBeDefined() + expect(result.providerOptions.gateway.order).toEqual(['openai', 'anthropic']) + }) + + it('should build generic options for unknown model types through gateway', () => { + const unknownModel: Model = { + id: 'unknown-provider/model-name', + name: 'Unknown Model', + provider: SystemProviderIds.gateway + } as Model + + const result = buildProviderOptions(mockAssistant, unknownModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + expect(result.providerOptions).toHaveProperty('openai-compatible') + expect(result.providerOptions['openai-compatible']).toBeDefined() + }) + }) + + describe('Proxy provider custom parameters mapping', () => { + it('should map cherryin provider ID to actual AI SDK provider ID (Google)', async () => { + const { getCustomParameters } = await import('../reasoning') + + // Mock Cherry In provider that uses Google SDK + const cherryinProvider = { + id: 'cherryin', + name: 'Cherry In', + type: 'gemini', // Using Google SDK + apiKey: 'test-key', + apiHost: 'https://cherryin.com', + models: [] as Model[] + } as Provider + + const geminiModel: Model = { + id: 'gemini-2.0-flash-exp', + name: 'Gemini 2.0 Flash', + provider: 'cherryin' + } as Model + + // User provides custom parameters with Cherry Studio provider ID + vi.mocked(getCustomParameters).mockReturnValue({ + cherryin: { + customOption1: 'value1', + customOption2: 'value2' + } + }) + + const result = buildProviderOptions(mockAssistant, geminiModel, cherryinProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should map to 'google' AI SDK provider, not 'cherryin' + expect(result.providerOptions).toHaveProperty('google') + expect(result.providerOptions).not.toHaveProperty('cherryin') + expect(result.providerOptions.google).toMatchObject({ + customOption1: 'value1', + customOption2: 'value2' + }) + }) + + it('should map cherryin provider ID to actual AI SDK provider ID (OpenAI)', async () => { + const { getCustomParameters } = await import('../reasoning') + + // Mock Cherry In provider that uses OpenAI SDK + const cherryinProvider = { + id: 'cherryin', + name: 'Cherry In', + type: 'openai-response', // Using OpenAI SDK + apiKey: 'test-key', + apiHost: 'https://cherryin.com', + models: [] as Model[] + } as Provider + + const openaiModel: Model = { + id: 'gpt-4', + name: 'GPT-4', + provider: 'cherryin' + } as Model + + // User provides custom parameters with Cherry Studio provider ID + vi.mocked(getCustomParameters).mockReturnValue({ + cherryin: { + customOpenAIOption: 'openai_value' + } + }) + + const result = buildProviderOptions(mockAssistant, openaiModel, cherryinProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should map to 'openai' AI SDK provider, not 'cherryin' + expect(result.providerOptions).toHaveProperty('openai') + expect(result.providerOptions).not.toHaveProperty('cherryin') + expect(result.providerOptions.openai).toMatchObject({ + customOpenAIOption: 'openai_value' + }) + }) + + it('should allow direct AI SDK provider ID in custom parameters', async () => { + const { getCustomParameters } = await import('../reasoning') + + const geminiProvider = { + id: SystemProviderIds.gemini, + name: 'Google', + type: 'gemini', + apiKey: 'test-key', + apiHost: 'https://generativelanguage.googleapis.com', + models: [] as Model[] + } as Provider + + const geminiModel: Model = { + id: 'gemini-2.0-flash-exp', + name: 'Gemini 2.0 Flash', + provider: SystemProviderIds.gemini + } as Model + + // User provides custom parameters directly with AI SDK provider ID + vi.mocked(getCustomParameters).mockReturnValue({ + google: { + directGoogleOption: 'google_value' + } + }) + + const result = buildProviderOptions(mockAssistant, geminiModel, geminiProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should merge directly to 'google' provider + expect(result.providerOptions.google).toMatchObject({ + directGoogleOption: 'google_value' + }) + }) + + it('should map gateway provider custom parameters to actual AI SDK provider', async () => { + const { getCustomParameters } = await import('../reasoning') + + const gatewayProvider: Provider = { + id: SystemProviderIds.gateway, + name: 'Vercel AI Gateway', + type: 'gateway', + apiKey: 'test-key', + apiHost: 'https://gateway.vercel.com', + isSystem: true + } as Provider + + const anthropicModel: Model = { + id: 'anthropic/claude-3-5-sonnet-20241022', + name: 'Claude 3.5 Sonnet', + provider: SystemProviderIds.gateway + } as Model + + // User provides both gateway routing options and gateway-scoped custom parameters + vi.mocked(getCustomParameters).mockReturnValue({ + gateway: { + order: ['vertex', 'anthropic'], + only: ['vertex'] + }, + customParam: 'should_go_to_anthropic' + }) + + const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Gateway routing options should be preserved + expect(result.providerOptions.gateway).toEqual({ + order: ['vertex', 'anthropic'], + only: ['vertex'] + }) + + // Custom parameters should go to the actual AI SDK provider (anthropic) + expect(result.providerOptions.anthropic).toMatchObject({ + customParam: 'should_go_to_anthropic' + }) + }) + + it('should handle mixed custom parameters (AI SDK provider ID + custom params)', async () => { + const { getCustomParameters } = await import('../reasoning') + + const openaiProvider: Provider = { + id: SystemProviderIds.openai, + name: 'OpenAI', + type: 'openai-response', + apiKey: 'test-key', + apiHost: 'https://api.openai.com/v1', + isSystem: true + } as Provider + + // User provides both direct AI SDK provider params and custom params + vi.mocked(getCustomParameters).mockReturnValue({ + openai: { + providerSpecific: 'value1' + }, + customParam1: 'value2', + customParam2: 123 + }) + + const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should merge both into 'openai' provider options + expect(result.providerOptions.openai).toMatchObject({ + providerSpecific: 'value1', + customParam1: 'value2', + customParam2: 123 + }) + }) + + // Note: For proxy providers like aihubmix/newapi, users should write AI SDK provider ID (google/anthropic) + // instead of the Cherry Studio provider ID for custom parameters to work correctly + + it('should handle cherryin fallback to openai-compatible with custom parameters', async () => { + const { getCustomParameters } = await import('../reasoning') + + // Mock cherryin provider that falls back to openai-compatible (default case) + const cherryinProvider = { + id: 'cherryin', + name: 'Cherry In', + type: 'openai', + apiKey: 'test-key', + apiHost: 'https://cherryin.com', + models: [] as Model[] + } as Provider + + const testModel: Model = { + id: 'some-model', + name: 'Some Model', + provider: 'cherryin' + } as Model + + // User provides custom parameters with cherryin provider ID + vi.mocked(getCustomParameters).mockReturnValue({ + customCherryinOption: 'cherryin_value' + }) + + const result = buildProviderOptions(mockAssistant, testModel, cherryinProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // When cherryin falls back to default case, it should use rawProviderId (cherryin) + // User's cherryin params should merge with the provider options + expect(result.providerOptions).toHaveProperty('cherryin') + expect(result.providerOptions.cherryin).toMatchObject({ + customCherryinOption: 'cherryin_value' + }) + }) + + it('should handle cross-provider configurations', async () => { + const { getCustomParameters } = await import('../reasoning') + + const openaiProvider: Provider = { + id: SystemProviderIds.openai, + name: 'OpenAI', + type: 'openai-response', + apiKey: 'test-key', + apiHost: 'https://api.openai.com/v1', + isSystem: true + } as Provider + + // User provides parameters for multiple providers + // In real usage, anthropic/google params would be treated as regular params for openai provider + vi.mocked(getCustomParameters).mockReturnValue({ + openai: { + openaiSpecific: 'openai_value' + }, + customParam: 'value' + }) + + const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + }) + + // Should have openai provider options with both scoped and custom params + expect(result.providerOptions).toHaveProperty('openai') + expect(result.providerOptions.openai).toMatchObject({ + openaiSpecific: 'openai_value', + customParam: 'value' + }) + }) + }) }) }) diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index 2efaa69cd0..8ec46c9df2 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -1,5 +1,5 @@ import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock' -import type { AnthropicProviderOptions } from '@ai-sdk/anthropic' +import { type AnthropicProviderOptions } from '@ai-sdk/anthropic' import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google' import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai' import type { XaiProviderOptions } from '@ai-sdk/xai' @@ -7,6 +7,9 @@ import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-c import { loggerService } from '@logger' import { getModelSupportedVerbosity, + isAnthropicModel, + isGeminiModel, + isGrokModel, isOpenAIModel, isQwenMTModel, isSupportFlexServiceTierModel, @@ -158,8 +161,8 @@ export function buildProviderOptions( providerOptions: Record> standardParams: Partial> } { - logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities }) const rawProviderId = getAiSdkProviderId(actualProvider) + logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities, rawProviderId }) // 构建 provider 特定的选项 let providerSpecificOptions: Record = {} const serviceTier = getServiceTier(model, actualProvider) @@ -174,14 +177,13 @@ export function buildProviderOptions( case 'azure': case 'azure-responses': { - const options: OpenAIResponsesProviderOptions = buildOpenAIProviderOptions( + providerSpecificOptions = buildOpenAIProviderOptions( assistant, model, capabilities, serviceTier, textVerbosity ) - providerSpecificOptions = options } break case 'anthropic': @@ -199,10 +201,13 @@ export function buildProviderOptions( case 'openrouter': case 'openai-compatible': { // 对于其他 provider,使用通用的构建逻辑 + const genericOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities) providerSpecificOptions = { - ...buildGenericProviderOptions(assistant, model, capabilities), - serviceTier, - textVerbosity + [rawProviderId]: { + ...genericOptions[rawProviderId], + serviceTier, + textVerbosity + } } break } @@ -241,50 +246,105 @@ export function buildProviderOptions( case SystemProviderIds.ollama: providerSpecificOptions = buildOllamaProviderOptions(assistant, capabilities) break + case SystemProviderIds.gateway: + providerSpecificOptions = buildAIGatewayOptions(assistant, model, capabilities, serviceTier, textVerbosity) + break default: // 对于其他 provider,使用通用的构建逻辑 + providerSpecificOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities) + // Merge serviceTier and textVerbosity providerSpecificOptions = { - ...buildGenericProviderOptions(assistant, model, capabilities), - serviceTier, - textVerbosity + ...providerSpecificOptions, + [rawProviderId]: { + ...providerSpecificOptions[rawProviderId], + serviceTier, + textVerbosity + } } } } else { throw error } } - - // 获取自定义参数并分离标准参数和 provider 特定参数 + logger.debug('Built providerSpecificOptions', { providerSpecificOptions }) + /** + * Retrieve custom parameters and separate standard parameters from provider-specific parameters. + */ const customParams = getCustomParameters(assistant) const { standardParams, providerParams } = extractAiSdkStandardParams(customParams) + logger.debug('Extracted standardParams and providerParams', { standardParams, providerParams }) - // 合并 provider 特定的自定义参数到 providerSpecificOptions - providerSpecificOptions = { - ...providerSpecificOptions, - ...providerParams - } - - let rawProviderKey = - { - 'google-vertex': 'google', - 'google-vertex-anthropic': 'anthropic', - 'azure-anthropic': 'anthropic', - 'ai-gateway': 'gateway', - azure: 'openai', - 'azure-responses': 'openai' - }[rawProviderId] || rawProviderId - - if (rawProviderKey === 'cherryin') { - rawProviderKey = - { gemini: 'google', ['openai-response']: 'openai', openai: 'cherryin' }[actualProvider.type] || - actualProvider.type + /** + * Get the actual AI SDK provider ID(s) from the already-built providerSpecificOptions. + * For proxy providers (cherryin, aihubmix, newapi), this will be the actual SDK provider (e.g., 'google', 'openai', 'anthropic') + * For regular providers, this will be the provider itself + */ + const actualAiSdkProviderIds = Object.keys(providerSpecificOptions) + const primaryAiSdkProviderId = actualAiSdkProviderIds[0] // Use the first one as primary for non-scoped params + + /** + * Merge custom parameters into providerSpecificOptions. + * Simple logic: + * 1. If key is in actualAiSdkProviderIds → merge directly (user knows the actual AI SDK provider ID) + * 2. If key == rawProviderId: + * - If it's gateway/ollama → preserve (they need their own config for routing/options) + * - Otherwise → map to primary (this is a proxy provider like cherryin) + * 3. Otherwise → treat as regular parameter, merge to primary provider + * + * Example: + * - User writes `cherryin: { opt: 'val' }` → mapped to `google: { opt: 'val' }` (case 2, proxy) + * - User writes `gateway: { order: [...] }` → stays as `gateway: { order: [...] }` (case 2, routing config) + * - User writes `google: { opt: 'val' }` → stays as `google: { opt: 'val' }` (case 1) + * - User writes `customKey: 'val'` → merged to `google: { customKey: 'val' }` (case 3) + */ + for (const key of Object.keys(providerParams)) { + if (actualAiSdkProviderIds.includes(key)) { + // Case 1: Key is an actual AI SDK provider ID - merge directly + providerSpecificOptions = { + ...providerSpecificOptions, + [key]: { + ...providerSpecificOptions[key], + ...providerParams[key] + } + } + } else if (key === rawProviderId && !actualAiSdkProviderIds.includes(rawProviderId)) { + // Case 2: Key is the current provider (not in actualAiSdkProviderIds, so it's a proxy or special provider) + // Gateway is special: it needs routing config preserved + if (key === SystemProviderIds.gateway) { + // Preserve gateway config for routing + providerSpecificOptions = { + ...providerSpecificOptions, + [key]: { + ...providerSpecificOptions[key], + ...providerParams[key] + } + } + } else { + // Proxy provider (cherryin, etc.) - map to actual AI SDK provider + providerSpecificOptions = { + ...providerSpecificOptions, + [primaryAiSdkProviderId]: { + ...providerSpecificOptions[primaryAiSdkProviderId], + ...providerParams[key] + } + } + } + } else { + // Case 3: Regular parameter - merge to primary provider + providerSpecificOptions = { + ...providerSpecificOptions, + [primaryAiSdkProviderId]: { + ...providerSpecificOptions[primaryAiSdkProviderId], + [key]: providerParams[key] + } + } + } } + logger.debug('Final providerSpecificOptions after merging providerParams', { providerSpecificOptions }) // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数 return { - providerOptions: { - [rawProviderKey]: providerSpecificOptions - }, + providerOptions: providerSpecificOptions, standardParams } } @@ -302,7 +362,7 @@ function buildOpenAIProviderOptions( }, serviceTier: OpenAIServiceTier, textVerbosity?: OpenAIVerbosity -): OpenAIResponsesProviderOptions { +): Record { const { enableReasoning } = capabilities let providerOptions: OpenAIResponsesProviderOptions = {} // OpenAI 推理参数 @@ -341,7 +401,9 @@ function buildOpenAIProviderOptions( textVerbosity } - return providerOptions + return { + openai: providerOptions + } } /** @@ -355,7 +417,7 @@ function buildAnthropicProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): AnthropicProviderOptions { +): Record { const { enableReasoning } = capabilities let providerOptions: AnthropicProviderOptions = {} @@ -368,7 +430,11 @@ function buildAnthropicProviderOptions( } } - return providerOptions + return { + anthropic: { + ...providerOptions + } + } } /** @@ -382,7 +448,7 @@ function buildGeminiProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): GoogleGenerativeAIProviderOptions { +): Record { const { enableReasoning, enableGenerateImage } = capabilities let providerOptions: GoogleGenerativeAIProviderOptions = {} @@ -402,7 +468,11 @@ function buildGeminiProviderOptions( } } - return providerOptions + return { + google: { + ...providerOptions + } + } } function buildXAIProviderOptions( @@ -413,7 +483,7 @@ function buildXAIProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): XaiProviderOptions { +): Record { const { enableReasoning } = capabilities let providerOptions: Record = {} @@ -425,7 +495,11 @@ function buildXAIProviderOptions( } } - return providerOptions + return { + xai: { + ...providerOptions + } + } } function buildCherryInProviderOptions( @@ -439,21 +513,19 @@ function buildCherryInProviderOptions( actualProvider: Provider, serviceTier: OpenAIServiceTier, textVerbosity: OpenAIVerbosity -): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions { +): Record { switch (actualProvider.type) { case 'openai': - return buildGenericProviderOptions(assistant, model, capabilities) + return buildGenericProviderOptions('cherryin', assistant, model, capabilities) case 'openai-response': return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity) - case 'anthropic': return buildAnthropicProviderOptions(assistant, model, capabilities) - case 'gemini': return buildGeminiProviderOptions(assistant, model, capabilities) default: - return buildGenericProviderOptions(assistant, model, capabilities) + return buildGenericProviderOptions('cherryin', assistant, model, capabilities) } } @@ -468,7 +540,7 @@ function buildBedrockProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): BedrockProviderOptions { +): Record { const { enableReasoning } = capabilities let providerOptions: BedrockProviderOptions = {} @@ -485,7 +557,9 @@ function buildBedrockProviderOptions( providerOptions.anthropicBeta = betaHeaders } - return providerOptions + return { + bedrock: providerOptions + } } function buildOllamaProviderOptions( @@ -495,20 +569,23 @@ function buildOllamaProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): OllamaCompletionProviderOptions { +): Record { const { enableReasoning } = capabilities const providerOptions: OllamaCompletionProviderOptions = {} const reasoningEffort = assistant.settings?.reasoning_effort if (enableReasoning) { providerOptions.think = !['none', undefined].includes(reasoningEffort) } - return providerOptions + return { + ollama: providerOptions + } } /** * 构建通用的 providerOptions(用于其他 provider) */ function buildGenericProviderOptions( + providerId: string, assistant: Assistant, model: Model, capabilities: { @@ -551,5 +628,37 @@ function buildGenericProviderOptions( } } - return providerOptions + return { + [providerId]: providerOptions + } +} + +function buildAIGatewayOptions( + assistant: Assistant, + model: Model, + capabilities: { + enableReasoning: boolean + enableWebSearch: boolean + enableGenerateImage: boolean + }, + serviceTier: OpenAIServiceTier, + textVerbosity?: OpenAIVerbosity +): Record< + string, + | OpenAIResponsesProviderOptions + | AnthropicProviderOptions + | GoogleGenerativeAIProviderOptions + | Record +> { + if (isAnthropicModel(model)) { + return buildAnthropicProviderOptions(assistant, model, capabilities) + } else if (isOpenAIModel(model)) { + return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity) + } else if (isGeminiModel(model)) { + return buildGeminiProviderOptions(assistant, model, capabilities) + } else if (isGrokModel(model)) { + return buildXAIProviderOptions(assistant, model, capabilities) + } else { + return buildGenericProviderOptions('openai-compatible', assistant, model, capabilities) + } } diff --git a/src/renderer/src/config/models/default.ts b/src/renderer/src/config/models/default.ts index 3adf0da53d..b0bb00f23c 100644 --- a/src/renderer/src/config/models/default.ts +++ b/src/renderer/src/config/models/default.ts @@ -1853,7 +1853,7 @@ export const SYSTEM_MODELS: Record = } ], huggingface: [], - 'ai-gateway': [], + gateway: [], cerebras: [ { id: 'gpt-oss-120b', diff --git a/src/renderer/src/config/models/utils.ts b/src/renderer/src/config/models/utils.ts index accf85e2cd..625a5b5c63 100644 --- a/src/renderer/src/config/models/utils.ts +++ b/src/renderer/src/config/models/utils.ts @@ -179,6 +179,11 @@ export const isGeminiModel = (model: Model) => { return modelId.includes('gemini') } +export const isGrokModel = (model: Model) => { + const modelId = getLowerBaseModelName(model.id) + return modelId.includes('grok') +} + // zhipu 视觉推理模型用这组 special token 标记推理结果 export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index 6f50bbfaea..bc32ef3490 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -676,10 +676,10 @@ export const SYSTEM_PROVIDERS_CONFIG: Record = isSystem: true, enabled: false }, - 'ai-gateway': { - id: 'ai-gateway', - name: 'AI Gateway', - type: 'ai-gateway', + gateway: { + id: 'gateway', + name: 'Vercel AI Gateway', + type: 'gateway', apiKey: '', apiHost: 'https://ai-gateway.vercel.sh/v1/ai', models: [], @@ -762,7 +762,7 @@ export const PROVIDER_LOGO_MAP: AtLeast = { longcat: LongCatProviderLogo, huggingface: HuggingfaceProviderLogo, sophnet: SophnetProviderLogo, - 'ai-gateway': AIGatewayProviderLogo, + gateway: AIGatewayProviderLogo, cerebras: CerebrasProviderLogo } as const @@ -1413,7 +1413,7 @@ export const PROVIDER_URLS: Record = { models: 'https://huggingface.co/models' } }, - 'ai-gateway': { + gateway: { api: { url: 'https://ai-gateway.vercel.sh/v1/ai' }, diff --git a/src/renderer/src/i18n/label.ts b/src/renderer/src/i18n/label.ts index bd74ecd452..7a6ad843d4 100644 --- a/src/renderer/src/i18n/label.ts +++ b/src/renderer/src/i18n/label.ts @@ -87,7 +87,7 @@ const providerKeyMap = { longcat: 'provider.longcat', huggingface: 'provider.huggingface', sophnet: 'provider.sophnet', - 'ai-gateway': 'provider.ai-gateway', + gateway: 'provider.ai-gateway', cerebras: 'provider.cerebras' } as const diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index 85e76b5cf4..427cbdcffd 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -2531,7 +2531,7 @@ }, "provider": { "302ai": "302.AI", - "ai-gateway": "AI Gateway", + "ai-gateway": "Vercel AI Gateway", "aihubmix": "AiHubMix", "aionly": "AiOnly", "alayanew": "Alaya NeW", diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index 0ccfa0b16d..69f4e63ee4 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -2531,7 +2531,7 @@ }, "provider": { "302ai": "302.AI", - "ai-gateway": "AI Gateway", + "ai-gateway": "Vercel AI Gateway", "aihubmix": "AiHubMix", "aionly": "唯一AI (AiOnly)", "alayanew": "Alaya NeW", diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index 94b51474b9..516d66cdc3 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -67,7 +67,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 180, + version: 181, blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs', 'toolPermissions'], migrate }, diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 0e89227907..d10e2dfcbd 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -2810,7 +2810,7 @@ const migrateConfig = { try { addProvider(state, SystemProviderIds.longcat) - addProvider(state, SystemProviderIds['ai-gateway']) + addProvider(state, 'gateway') addProvider(state, 'cerebras') state.llm.providers.forEach((provider) => { if (provider.id === SystemProviderIds.minimax) { @@ -2932,6 +2932,26 @@ const migrateConfig = { logger.error('migrate 180 error', error as Error) return state } + }, + '181': (state: RootState) => { + try { + state.llm.providers.forEach((provider) => { + if (provider.id === 'ai-gateway') { + provider.id = SystemProviderIds.gateway + } + // Also update model.provider references to avoid orphaned models + provider.models?.forEach((model) => { + if (model.provider === 'ai-gateway') { + model.provider = SystemProviderIds.gateway + } + }) + }) + logger.info('migrate 181 success') + return state + } catch (error) { + logger.error('migrate 181 error', error as Error) + return state + } } } diff --git a/src/renderer/src/types/provider.ts b/src/renderer/src/types/provider.ts index aea72fa287..4e3e34760c 100644 --- a/src/renderer/src/types/provider.ts +++ b/src/renderer/src/types/provider.ts @@ -15,7 +15,7 @@ export const ProviderTypeSchema = z.enum([ 'aws-bedrock', 'vertex-anthropic', 'new-api', - 'ai-gateway', + 'gateway', 'ollama' ]) @@ -188,7 +188,7 @@ export const SystemProviderIdSchema = z.enum([ 'longcat', 'huggingface', 'sophnet', - 'ai-gateway', + 'gateway', 'cerebras' ]) @@ -257,7 +257,7 @@ export const SystemProviderIds = { aionly: 'aionly', longcat: 'longcat', huggingface: 'huggingface', - 'ai-gateway': 'ai-gateway', + gateway: 'gateway', cerebras: 'cerebras' } as const satisfies Record diff --git a/src/renderer/src/utils/__tests__/provider.test.ts b/src/renderer/src/utils/__tests__/provider.test.ts index a7823eda06..269c384901 100644 --- a/src/renderer/src/utils/__tests__/provider.test.ts +++ b/src/renderer/src/utils/__tests__/provider.test.ts @@ -189,7 +189,7 @@ describe('provider utils', () => { expect(isAnthropicProvider(createProvider({ type: 'anthropic' }))).toBe(true) expect(isGeminiProvider(createProvider({ type: 'gemini' }))).toBe(true) - expect(isAIGatewayProvider(createProvider({ type: 'ai-gateway' }))).toBe(true) + expect(isAIGatewayProvider(createProvider({ type: 'gateway' }))).toBe(true) }) it('computes API version support', () => { diff --git a/src/renderer/src/utils/provider.ts b/src/renderer/src/utils/provider.ts index 0af511b97e..0586099cff 100644 --- a/src/renderer/src/utils/provider.ts +++ b/src/renderer/src/utils/provider.ts @@ -172,7 +172,7 @@ export function isGeminiProvider(provider: Provider): boolean { } export function isAIGatewayProvider(provider: Provider): boolean { - return provider.type === 'ai-gateway' + return provider.type === 'gateway' } export function isOllamaProvider(provider: Provider): boolean {