Fix custom parameters placement for Vercel AI Gateway (#11605)

* Initial plan

* Fix custom parameters placement for Vercel AI Gateway

For the AI Gateway provider, custom parameters are now placed at the body level
instead of being nested inside providerOptions.gateway. This fixes the issue
where parameters like 'tools' were incorrectly added to providerOptions.gateway
when they should sit at the same level as providerOptions, as the sketch below illustrates.
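
A minimal before/after sketch of the request shape (illustrative object literals only; `tools` and `providerOptions` follow the AI SDK call-options format, and the routing values are hypothetical):

```ts
// Hypothetical request objects illustrating the fix, not the project's actual builder output.

// Before: custom parameters such as `tools` ended up nested under providerOptions.gateway.
const before = {
  providerOptions: {
    gateway: {
      order: ['vertex', 'anthropic'], // gateway routing config does belong here
      tools: { web_search: {} } // wrong: a request-level parameter, not a gateway option
    }
  }
}

// After: request-level parameters sit at the body level, alongside providerOptions.
const after = {
  tools: { web_search: {} }, // correct: same level as providerOptions
  providerOptions: {
    gateway: {
      order: ['vertex', 'anthropic'] // only routing options stay scoped to the gateway
    }
  }
}
```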

Fixes #4197

Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>

* Revert "Fix custom parameters placement for Vercel AI Gateway"

This reverts commit b14e48dd78.

* fix: rename 'ai-gateway' to 'gateway' across the codebase and update related configurations

* fix: resolve PR review issues for custom parameters field

- Fix Migration 174: use the string literal 'ai-gateway' instead of a non-existent constant for historical compatibility (see the sketch after this list)
- Fix Migration 180: update model.provider references to prevent orphaned models when renaming the provider ID
- Add logging in mapVertexAIGatewayModelToProviderId when an unknown model type is encountered
- Replace `any` with `Record<string, unknown>` in the buildAIGatewayOptions return type for better type safety
- Add a gateway mapping to the getAiSdkProviderId mock in options.test.ts to match production behavior
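
A minimal sketch of why the literal matters, assuming a simplified persisted-state shape (the interfaces and function below are hypothetical, not the project's actual migration code):

```ts
// Migrations run against historical persisted state, so they must match the ID
// that was stored at the time. SystemProviderIds['ai-gateway'] no longer exists
// after the rename to 'gateway', but old states still contain the literal string.
interface PersistedProvider {
  id: string
}

interface PersistedState {
  providers: PersistedProvider[]
}

function migrate174(state: PersistedState): PersistedState {
  const alreadyAdded = state.providers.some((p) => p.id === 'ai-gateway')
  if (!alreadyAdded) {
    // Historical ID frozen as a string literal; migration 181 renames it to 'gateway' later.
    state.providers.push({ id: 'ai-gateway' })
  }
  return state
}
```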

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: version

* fix(options): enhance custom parameter handling for proxy providers

* fix(options): add support for the cherryin provider with custom parameter handling

* chore

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>
Co-authored-by: suyao <sy20010504@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Authored by Copilot on 2025-12-04 21:19:30 +08:00; committed by GitHub.
parent 981bb9f451
commit a2a6c62f48
18 changed files with 746 additions and 160 deletions

View File

@ -35,7 +35,6 @@ export interface WebSearchPluginConfig {
anthropic?: AnthropicSearchConfig
xai?: ProviderOptionsMap['xai']['searchParameters']
google?: GoogleSearchConfig
'google-vertex'?: GoogleSearchConfig
openrouter?: OpenRouterSearchConfig
}
@ -44,7 +43,6 @@ export interface WebSearchPluginConfig {
*/
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
google: {},
'google-vertex': {},
openai: {},
'openai-chat': {},
xai: {
@ -97,55 +95,28 @@ export type WebSearchToolInputSchema = {
'openai-chat': InferToolInput<OpenAIChatWebSearchTool>
}
export const switchWebSearchTool = (providerId: string, config: WebSearchPluginConfig, params: any) => {
switch (providerId) {
case 'openai': {
if (config.openai) {
if (!params.tools) params.tools = {}
params.tools.web_search = openai.tools.webSearch(config.openai)
}
break
}
case 'openai-chat': {
if (config['openai-chat']) {
if (!params.tools) params.tools = {}
params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat'])
}
break
}
case 'anthropic': {
if (config.anthropic) {
if (!params.tools) params.tools = {}
params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
}
break
}
case 'google': {
// case 'google-vertex':
if (!params.tools) params.tools = {}
params.tools.web_search = google.tools.googleSearch(config.google || {})
break
}
case 'xai': {
if (config.xai) {
const searchOptions = createXaiOptions({
searchParameters: { ...config.xai, mode: 'on' }
})
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
}
break
}
case 'openrouter': {
if (config.openrouter) {
const searchOptions = createOpenRouterOptions(config.openrouter)
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
}
break
}
export const switchWebSearchTool = (config: WebSearchPluginConfig, params: any) => {
if (config.openai) {
if (!params.tools) params.tools = {}
params.tools.web_search = openai.tools.webSearch(config.openai)
} else if (config['openai-chat']) {
if (!params.tools) params.tools = {}
params.tools.web_search_preview = openai.tools.webSearchPreview(config['openai-chat'])
} else if (config.anthropic) {
if (!params.tools) params.tools = {}
params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
} else if (config.google) {
// case 'google-vertex':
if (!params.tools) params.tools = {}
params.tools.web_search = google.tools.googleSearch(config.google || {})
} else if (config.xai) {
const searchOptions = createXaiOptions({
searchParameters: { ...config.xai, mode: 'on' }
})
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
} else if (config.openrouter) {
const searchOptions = createOpenRouterOptions(config.openrouter)
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
}
return params
}

View File

@ -4,7 +4,6 @@
*/
import { definePlugin } from '../../'
import type { AiRequestContext } from '../../types'
import type { WebSearchPluginConfig } from './helper'
import { DEFAULT_WEB_SEARCH_CONFIG, switchWebSearchTool } from './helper'
@ -18,15 +17,8 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
name: 'webSearch',
enforce: 'pre',
transformParams: async (params: any, context: AiRequestContext) => {
const { providerId } = context
switchWebSearchTool(providerId, config, params)
if (providerId === 'cherryin' || providerId === 'cherryin-chat') {
// cherryin.gemini
const _providerId = params.model.provider.split('.')[1]
switchWebSearchTool(_providerId, config, params)
}
transformParams: async (params: any) => {
switchWebSearchTool(config, params)
return params
}
})

View File

@ -189,7 +189,7 @@ export default class ModernAiProvider {
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
// AI Gateway is not an image/generation endpoint, so skip the legacy path for now
if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds['ai-gateway']) {
if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds.gateway) {
// Use the legacy implementation for image generation (supports advanced features such as image editing)
if (!config.uiMessages) {
throw new Error('uiMessages is required for image generation endpoint')
@ -480,7 +480,7 @@ export default class ModernAiProvider {
// 代理其他方法到原有实现
public async models() {
if (this.actualProvider.id === SystemProviderIds['ai-gateway']) {
if (this.actualProvider.id === SystemProviderIds.gateway) {
const formatModel = function (models: GatewayLanguageModelEntry[]): Model[] {
return models.map((m) => ({
id: m.id,

View File

@ -11,11 +11,15 @@ import { vertex } from '@ai-sdk/google-vertex/edge'
import { combineHeaders } from '@ai-sdk/provider-utils'
import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
import type { BaseProviderId } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import {
isAnthropicModel,
isFixedReasoningModel,
isGeminiModel,
isGenerateImageModel,
isGrokModel,
isOpenAIModel,
isOpenRouterBuiltInWebSearchModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel,
@ -24,11 +28,12 @@ import {
import { getDefaultModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
import type { Model } from '@renderer/types'
import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
import { replacePromptVariables } from '@renderer/utils/prompt'
import { isAwsBedrockProvider } from '@renderer/utils/provider'
import { isAIGatewayProvider, isAwsBedrockProvider } from '@renderer/utils/provider'
import type { ModelMessage, Tool } from 'ai'
import { stepCountIs } from 'ai'
@ -43,6 +48,25 @@ const logger = loggerService.withContext('parameterBuilder')
type ProviderDefinedTool = Extract<Tool<any, any>, { type: 'provider-defined' }>
function mapVertexAIGatewayModelToProviderId(model: Model): BaseProviderId | undefined {
if (isAnthropicModel(model)) {
return 'anthropic'
}
if (isGeminiModel(model)) {
return 'google'
}
if (isGrokModel(model)) {
return 'xai'
}
if (isOpenAIModel(model)) {
return 'openai'
}
logger.warn(
`[mapVertexAIGatewayModelToProviderId] Unknown model type for AI Gateway: ${model.id}. Web search will not be enabled.`
)
return undefined
}
/**
* AI SDK
*
@ -117,6 +141,11 @@ export async function buildStreamTextParams(
if (enableWebSearch) {
if (isBaseProvider(aiSdkProviderId)) {
webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model)
} else if (isAIGatewayProvider(provider) || SystemProviderIds.gateway === provider.id) {
const aiSdkProviderId = mapVertexAIGatewayModelToProviderId(model)
if (aiSdkProviderId) {
webSearchPluginConfig = buildProviderBuiltinWebSearchConfig(aiSdkProviderId, webSearchConfig, model)
}
}
if (!tools) {
tools = {}

View File

@ -1,5 +1,6 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import * as z from 'zod'
const logger = loggerService.withContext('ProviderConfigs')
@ -81,12 +82,12 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
aliases: ['hf', 'hugging-face']
},
{
id: 'ai-gateway',
name: 'AI Gateway',
id: 'gateway',
name: 'Vercel AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['gateway']
aliases: ['ai-gateway']
},
{
id: 'cerebras',
@ -104,6 +105,9 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
}
] as const
export const registeredNewProviderIds = NEW_PROVIDER_CONFIGS.map((config) => config.id)
export const registeredNewProviderIdSchema = z.enum(registeredNewProviderIds)
/**
* Providers
* Registers providers using aiCore's dynamic registration feature

View File

@ -27,7 +27,8 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
'xai',
'deepseek',
'openrouter',
'openai-compatible'
'openai-compatible',
'cherryin'
]
if (baseProviders.includes(id)) {
return { success: true, data: id }
@ -37,7 +38,15 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
},
customProviderIdSchema: {
safeParse: vi.fn((id) => {
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
const customProviders = [
'google-vertex',
'google-vertex-anthropic',
'bedrock',
'gateway',
'aihubmix',
'newapi',
'ollama'
]
if (customProviders.includes(id)) {
return { success: true, data: id }
}
@ -47,20 +56,7 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
}
})
vi.mock('../provider/factory', () => ({
getAiSdkProviderId: vi.fn((provider) => {
// Simulate the provider ID mapping
const mapping: Record<string, string> = {
[SystemProviderIds.gemini]: 'google',
[SystemProviderIds.openai]: 'openai',
[SystemProviderIds.anthropic]: 'anthropic',
[SystemProviderIds.grok]: 'xai',
[SystemProviderIds.deepseek]: 'deepseek',
[SystemProviderIds.openrouter]: 'openrouter'
}
return mapping[provider.id] || provider.id
})
}))
// Don't mock getAiSdkProviderId - use real implementation for more accurate tests
vi.mock('@renderer/config/models', async (importOriginal) => ({
...(await importOriginal()),
@ -179,8 +175,11 @@ describe('options utils', () => {
provider: SystemProviderIds.openai
} as Model
beforeEach(() => {
beforeEach(async () => {
vi.clearAllMocks()
// Reset getCustomParameters to return empty object by default
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({})
})
describe('buildProviderOptions', () => {
@ -391,7 +390,6 @@ describe('options utils', () => {
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions).toHaveProperty('deepseek')
expect(result.providerOptions.deepseek).toBeDefined()
})
@ -461,10 +459,14 @@ describe('options utils', () => {
}
)
expect(result.providerOptions.openai).toHaveProperty('custom_param')
expect(result.providerOptions.openai.custom_param).toBe('custom_value')
expect(result.providerOptions.openai).toHaveProperty('another_param')
expect(result.providerOptions.openai.another_param).toBe(123)
expect(result.providerOptions).toStrictEqual({
openai: {
custom_param: 'custom_value',
another_param: 123,
serviceTier: undefined,
textVerbosity: undefined
}
})
})
it('should extract AI SDK standard params from custom parameters', async () => {
@ -696,5 +698,459 @@ describe('options utils', () => {
})
})
})
describe('AI Gateway provider', () => {
const gatewayProvider: Provider = {
id: SystemProviderIds.gateway,
name: 'Vercel AI Gateway',
type: 'gateway',
apiKey: 'test-key',
apiHost: 'https://gateway.vercel.com',
isSystem: true
} as Provider
it('should build OpenAI options for OpenAI models through gateway', () => {
const openaiModel: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.gateway
} as Model
const result = buildProviderOptions(mockAssistant, openaiModel, gatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions).toHaveProperty('openai')
expect(result.providerOptions.openai).toBeDefined()
})
it('should build Anthropic options for Anthropic models through gateway', () => {
const anthropicModel: Model = {
id: 'anthropic/claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.gateway
} as Model
const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions).toHaveProperty('anthropic')
expect(result.providerOptions.anthropic).toBeDefined()
})
it('should build Google options for Gemini models through gateway', () => {
const geminiModel: Model = {
id: 'google/gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gateway
} as Model
const result = buildProviderOptions(mockAssistant, geminiModel, gatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions).toHaveProperty('google')
expect(result.providerOptions.google).toBeDefined()
})
it('should build xAI options for Grok models through gateway', () => {
const grokModel: Model = {
id: 'xai/grok-2-latest',
name: 'Grok 2',
provider: SystemProviderIds.gateway
} as Model
const result = buildProviderOptions(mockAssistant, grokModel, gatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions).toHaveProperty('xai')
expect(result.providerOptions.xai).toBeDefined()
})
it('should include reasoning parameters for Anthropic models when enabled', () => {
const anthropicModel: Model = {
id: 'anthropic/claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.gateway
} as Model
const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions.anthropic).toHaveProperty('thinking')
expect(result.providerOptions.anthropic.thinking).toEqual({
type: 'enabled',
budgetTokens: 5000
})
})
it('should merge gateway routing options from custom parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
gateway: {
order: ['vertex', 'anthropic'],
only: ['vertex', 'anthropic']
}
})
const anthropicModel: Model = {
id: 'anthropic/claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.gateway
} as Model
const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Should have both anthropic provider options and gateway routing options
expect(result.providerOptions).toHaveProperty('anthropic')
expect(result.providerOptions).toHaveProperty('gateway')
expect(result.providerOptions.gateway).toEqual({
order: ['vertex', 'anthropic'],
only: ['vertex', 'anthropic']
})
})
it('should combine provider-specific options with gateway routing options', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
gateway: {
order: ['openai', 'anthropic']
}
})
const openaiModel: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.gateway
} as Model
const result = buildProviderOptions(mockAssistant, openaiModel, gatewayProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
// Should have OpenAI provider options with reasoning
expect(result.providerOptions.openai).toBeDefined()
expect(result.providerOptions.openai).toHaveProperty('reasoningEffort')
// Should also have gateway routing options
expect(result.providerOptions.gateway).toBeDefined()
expect(result.providerOptions.gateway.order).toEqual(['openai', 'anthropic'])
})
it('should build generic options for unknown model types through gateway', () => {
const unknownModel: Model = {
id: 'unknown-provider/model-name',
name: 'Unknown Model',
provider: SystemProviderIds.gateway
} as Model
const result = buildProviderOptions(mockAssistant, unknownModel, gatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions).toHaveProperty('openai-compatible')
expect(result.providerOptions['openai-compatible']).toBeDefined()
})
})
describe('Proxy provider custom parameters mapping', () => {
it('should map cherryin provider ID to actual AI SDK provider ID (Google)', async () => {
const { getCustomParameters } = await import('../reasoning')
// Mock Cherry In provider that uses Google SDK
const cherryinProvider = {
id: 'cherryin',
name: 'Cherry In',
type: 'gemini', // Using Google SDK
apiKey: 'test-key',
apiHost: 'https://cherryin.com',
models: [] as Model[]
} as Provider
const geminiModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: 'cherryin'
} as Model
// User provides custom parameters with Cherry Studio provider ID
vi.mocked(getCustomParameters).mockReturnValue({
cherryin: {
customOption1: 'value1',
customOption2: 'value2'
}
})
const result = buildProviderOptions(mockAssistant, geminiModel, cherryinProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Should map to 'google' AI SDK provider, not 'cherryin'
expect(result.providerOptions).toHaveProperty('google')
expect(result.providerOptions).not.toHaveProperty('cherryin')
expect(result.providerOptions.google).toMatchObject({
customOption1: 'value1',
customOption2: 'value2'
})
})
it('should map cherryin provider ID to actual AI SDK provider ID (OpenAI)', async () => {
const { getCustomParameters } = await import('../reasoning')
// Mock Cherry In provider that uses OpenAI SDK
const cherryinProvider = {
id: 'cherryin',
name: 'Cherry In',
type: 'openai-response', // Using OpenAI SDK
apiKey: 'test-key',
apiHost: 'https://cherryin.com',
models: [] as Model[]
} as Provider
const openaiModel: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: 'cherryin'
} as Model
// User provides custom parameters with Cherry Studio provider ID
vi.mocked(getCustomParameters).mockReturnValue({
cherryin: {
customOpenAIOption: 'openai_value'
}
})
const result = buildProviderOptions(mockAssistant, openaiModel, cherryinProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Should map to 'openai' AI SDK provider, not 'cherryin'
expect(result.providerOptions).toHaveProperty('openai')
expect(result.providerOptions).not.toHaveProperty('cherryin')
expect(result.providerOptions.openai).toMatchObject({
customOpenAIOption: 'openai_value'
})
})
it('should allow direct AI SDK provider ID in custom parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
const geminiProvider = {
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com',
models: [] as Model[]
} as Provider
const geminiModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
// User provides custom parameters directly with AI SDK provider ID
vi.mocked(getCustomParameters).mockReturnValue({
google: {
directGoogleOption: 'google_value'
}
})
const result = buildProviderOptions(mockAssistant, geminiModel, geminiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Should merge directly to 'google' provider
expect(result.providerOptions.google).toMatchObject({
directGoogleOption: 'google_value'
})
})
it('should map gateway provider custom parameters to actual AI SDK provider', async () => {
const { getCustomParameters } = await import('../reasoning')
const gatewayProvider: Provider = {
id: SystemProviderIds.gateway,
name: 'Vercel AI Gateway',
type: 'gateway',
apiKey: 'test-key',
apiHost: 'https://gateway.vercel.com',
isSystem: true
} as Provider
const anthropicModel: Model = {
id: 'anthropic/claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.gateway
} as Model
// User provides gateway routing options plus a non-scoped custom parameter
vi.mocked(getCustomParameters).mockReturnValue({
gateway: {
order: ['vertex', 'anthropic'],
only: ['vertex']
},
customParam: 'should_go_to_anthropic'
})
const result = buildProviderOptions(mockAssistant, anthropicModel, gatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Gateway routing options should be preserved
expect(result.providerOptions.gateway).toEqual({
order: ['vertex', 'anthropic'],
only: ['vertex']
})
// Custom parameters should go to the actual AI SDK provider (anthropic)
expect(result.providerOptions.anthropic).toMatchObject({
customParam: 'should_go_to_anthropic'
})
})
it('should handle mixed custom parameters (AI SDK provider ID + custom params)', async () => {
const { getCustomParameters } = await import('../reasoning')
const openaiProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai-response',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
// User provides both direct AI SDK provider params and custom params
vi.mocked(getCustomParameters).mockReturnValue({
openai: {
providerSpecific: 'value1'
},
customParam1: 'value2',
customParam2: 123
})
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Should merge both into 'openai' provider options
expect(result.providerOptions.openai).toMatchObject({
providerSpecific: 'value1',
customParam1: 'value2',
customParam2: 123
})
})
// Note: For proxy providers like aihubmix/newapi, users should write AI SDK provider ID (google/anthropic)
// instead of the Cherry Studio provider ID for custom parameters to work correctly
it('should handle cherryin fallback to openai-compatible with custom parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
// Mock cherryin provider that falls back to openai-compatible (default case)
const cherryinProvider = {
id: 'cherryin',
name: 'Cherry In',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://cherryin.com',
models: [] as Model[]
} as Provider
const testModel: Model = {
id: 'some-model',
name: 'Some Model',
provider: 'cherryin'
} as Model
// User provides custom parameters with cherryin provider ID
vi.mocked(getCustomParameters).mockReturnValue({
customCherryinOption: 'cherryin_value'
})
const result = buildProviderOptions(mockAssistant, testModel, cherryinProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// When cherryin falls back to the default case, it should use rawProviderId (cherryin)
// User's cherryin params should merge with the provider options
expect(result.providerOptions).toHaveProperty('cherryin')
expect(result.providerOptions.cherryin).toMatchObject({
customCherryinOption: 'cherryin_value'
})
})
it('should handle cross-provider configurations', async () => {
const { getCustomParameters } = await import('../reasoning')
const openaiProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai-response',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
// User provides parameters for multiple providers
// In real usage, anthropic/google params would be treated as regular params for openai provider
vi.mocked(getCustomParameters).mockReturnValue({
openai: {
openaiSpecific: 'openai_value'
},
customParam: 'value'
})
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Should have openai provider options with both scoped and custom params
expect(result.providerOptions).toHaveProperty('openai')
expect(result.providerOptions.openai).toMatchObject({
openaiSpecific: 'openai_value',
customParam: 'value'
})
})
})
})
})

View File

@ -1,5 +1,5 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
@ -7,6 +7,9 @@ import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-c
import { loggerService } from '@logger'
import {
getModelSupportedVerbosity,
isAnthropicModel,
isGeminiModel,
isGrokModel,
isOpenAIModel,
isQwenMTModel,
isSupportFlexServiceTierModel,
@ -158,8 +161,8 @@ export function buildProviderOptions(
providerOptions: Record<string, Record<string, JSONValue>>
standardParams: Partial<Record<AiSdkParam, any>>
} {
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
const rawProviderId = getAiSdkProviderId(actualProvider)
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities, rawProviderId })
// Build provider-specific options
let providerSpecificOptions: Record<string, any> = {}
const serviceTier = getServiceTier(model, actualProvider)
@ -174,14 +177,13 @@ export function buildProviderOptions(
case 'azure':
case 'azure-responses':
{
const options: OpenAIResponsesProviderOptions = buildOpenAIProviderOptions(
providerSpecificOptions = buildOpenAIProviderOptions(
assistant,
model,
capabilities,
serviceTier,
textVerbosity
)
providerSpecificOptions = options
}
break
case 'anthropic':
@ -199,10 +201,13 @@ export function buildProviderOptions(
case 'openrouter':
case 'openai-compatible': {
// For other providers, use the generic build logic
const genericOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities)
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier,
textVerbosity
[rawProviderId]: {
...genericOptions[rawProviderId],
serviceTier,
textVerbosity
}
}
break
}
@ -241,50 +246,105 @@ export function buildProviderOptions(
case SystemProviderIds.ollama:
providerSpecificOptions = buildOllamaProviderOptions(assistant, capabilities)
break
case SystemProviderIds.gateway:
providerSpecificOptions = buildAIGatewayOptions(assistant, model, capabilities, serviceTier, textVerbosity)
break
default:
// For other providers, use the generic build logic
providerSpecificOptions = buildGenericProviderOptions(rawProviderId, assistant, model, capabilities)
// Merge serviceTier and textVerbosity
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier,
textVerbosity
...providerSpecificOptions,
[rawProviderId]: {
...providerSpecificOptions[rawProviderId],
serviceTier,
textVerbosity
}
}
}
} else {
throw error
}
}
// Get custom parameters and separate standard params from provider-specific params
logger.debug('Built providerSpecificOptions', { providerSpecificOptions })
/**
* Retrieve custom parameters and separate standard parameters from provider-specific parameters.
*/
const customParams = getCustomParameters(assistant)
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
logger.debug('Extracted standardParams and providerParams', { standardParams, providerParams })
// Merge provider-specific custom parameters into providerSpecificOptions
providerSpecificOptions = {
...providerSpecificOptions,
...providerParams
}
let rawProviderKey =
{
'google-vertex': 'google',
'google-vertex-anthropic': 'anthropic',
'azure-anthropic': 'anthropic',
'ai-gateway': 'gateway',
azure: 'openai',
'azure-responses': 'openai'
}[rawProviderId] || rawProviderId
if (rawProviderKey === 'cherryin') {
rawProviderKey =
{ gemini: 'google', ['openai-response']: 'openai', openai: 'cherryin' }[actualProvider.type] ||
actualProvider.type
/**
* Get the actual AI SDK provider ID(s) from the already-built providerSpecificOptions.
* For proxy providers (cherryin, aihubmix, newapi), this will be the actual SDK provider (e.g., 'google', 'openai', 'anthropic')
* For regular providers, this will be the provider itself
*/
const actualAiSdkProviderIds = Object.keys(providerSpecificOptions)
const primaryAiSdkProviderId = actualAiSdkProviderIds[0] // Use the first one as primary for non-scoped params
/**
* Merge custom parameters into providerSpecificOptions.
* Simple logic:
* 1. If key is in actualAiSdkProviderIds → merge directly (user knows the actual AI SDK provider ID)
* 2. If key == rawProviderId:
*    - If it's gateway/ollama → preserve (they need their own config for routing/options)
*    - Otherwise → map to primary (this is a proxy provider like cherryin)
* 3. Otherwise → treat as regular parameter, merge to primary provider
*
* Example:
* - User writes `cherryin: { opt: 'val' }` → mapped to `google: { opt: 'val' }` (case 2, proxy)
* - User writes `gateway: { order: [...] }` → stays as `gateway: { order: [...] }` (case 2, routing config)
* - User writes `google: { opt: 'val' }` → stays as `google: { opt: 'val' }` (case 1)
* - User writes `customKey: 'val'` → merged to `google: { customKey: 'val' }` (case 3)
*/
for (const key of Object.keys(providerParams)) {
if (actualAiSdkProviderIds.includes(key)) {
// Case 1: Key is an actual AI SDK provider ID - merge directly
providerSpecificOptions = {
...providerSpecificOptions,
[key]: {
...providerSpecificOptions[key],
...providerParams[key]
}
}
} else if (key === rawProviderId && !actualAiSdkProviderIds.includes(rawProviderId)) {
// Case 2: Key is the current provider (not in actualAiSdkProviderIds, so it's a proxy or special provider)
// Gateway is special: it needs routing config preserved
if (key === SystemProviderIds.gateway) {
// Preserve gateway config for routing
providerSpecificOptions = {
...providerSpecificOptions,
[key]: {
...providerSpecificOptions[key],
...providerParams[key]
}
}
} else {
// Proxy provider (cherryin, etc.) - map to actual AI SDK provider
providerSpecificOptions = {
...providerSpecificOptions,
[primaryAiSdkProviderId]: {
...providerSpecificOptions[primaryAiSdkProviderId],
...providerParams[key]
}
}
}
} else {
// Case 3: Regular parameter - merge to primary provider
providerSpecificOptions = {
...providerSpecificOptions,
[primaryAiSdkProviderId]: {
...providerSpecificOptions[primaryAiSdkProviderId],
[key]: providerParams[key]
}
}
}
}
logger.debug('Final providerSpecificOptions after merging providerParams', { providerSpecificOptions })
// Return the format required by the AI Core SDK ({ 'providerId': providerOptions }) along with the extracted standard parameters
return {
providerOptions: {
[rawProviderKey]: providerSpecificOptions
},
providerOptions: providerSpecificOptions,
standardParams
}
}
@ -302,7 +362,7 @@ function buildOpenAIProviderOptions(
},
serviceTier: OpenAIServiceTier,
textVerbosity?: OpenAIVerbosity
): OpenAIResponsesProviderOptions {
): Record<string, OpenAIResponsesProviderOptions> {
const { enableReasoning } = capabilities
let providerOptions: OpenAIResponsesProviderOptions = {}
// OpenAI reasoning parameters
@ -341,7 +401,9 @@ function buildOpenAIProviderOptions(
textVerbosity
}
return providerOptions
return {
openai: providerOptions
}
}
/**
@ -355,7 +417,7 @@ function buildAnthropicProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): AnthropicProviderOptions {
): Record<string, AnthropicProviderOptions> {
const { enableReasoning } = capabilities
let providerOptions: AnthropicProviderOptions = {}
@ -368,7 +430,11 @@ function buildAnthropicProviderOptions(
}
}
return providerOptions
return {
anthropic: {
...providerOptions
}
}
}
/**
@ -382,7 +448,7 @@ function buildGeminiProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): GoogleGenerativeAIProviderOptions {
): Record<string, GoogleGenerativeAIProviderOptions> {
const { enableReasoning, enableGenerateImage } = capabilities
let providerOptions: GoogleGenerativeAIProviderOptions = {}
@ -402,7 +468,11 @@ function buildGeminiProviderOptions(
}
}
return providerOptions
return {
google: {
...providerOptions
}
}
}
function buildXAIProviderOptions(
@ -413,7 +483,7 @@ function buildXAIProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): XaiProviderOptions {
): Record<string, XaiProviderOptions> {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
@ -425,7 +495,11 @@ function buildXAIProviderOptions(
}
}
return providerOptions
return {
xai: {
...providerOptions
}
}
}
function buildCherryInProviderOptions(
@ -439,21 +513,19 @@ function buildCherryInProviderOptions(
actualProvider: Provider,
serviceTier: OpenAIServiceTier,
textVerbosity: OpenAIVerbosity
): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions {
): Record<string, OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions> {
switch (actualProvider.type) {
case 'openai':
return buildGenericProviderOptions(assistant, model, capabilities)
return buildGenericProviderOptions('cherryin', assistant, model, capabilities)
case 'openai-response':
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
case 'anthropic':
return buildAnthropicProviderOptions(assistant, model, capabilities)
case 'gemini':
return buildGeminiProviderOptions(assistant, model, capabilities)
default:
return buildGenericProviderOptions(assistant, model, capabilities)
return buildGenericProviderOptions('cherryin', assistant, model, capabilities)
}
}
@ -468,7 +540,7 @@ function buildBedrockProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): BedrockProviderOptions {
): Record<string, BedrockProviderOptions> {
const { enableReasoning } = capabilities
let providerOptions: BedrockProviderOptions = {}
@ -485,7 +557,9 @@ function buildBedrockProviderOptions(
providerOptions.anthropicBeta = betaHeaders
}
return providerOptions
return {
bedrock: providerOptions
}
}
function buildOllamaProviderOptions(
@ -495,20 +569,23 @@ function buildOllamaProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): OllamaCompletionProviderOptions {
): Record<string, OllamaCompletionProviderOptions> {
const { enableReasoning } = capabilities
const providerOptions: OllamaCompletionProviderOptions = {}
const reasoningEffort = assistant.settings?.reasoning_effort
if (enableReasoning) {
providerOptions.think = !['none', undefined].includes(reasoningEffort)
}
return providerOptions
return {
ollama: providerOptions
}
}
/**
* Build generic providerOptions for the given provider
*/
function buildGenericProviderOptions(
providerId: string,
assistant: Assistant,
model: Model,
capabilities: {
@ -551,5 +628,37 @@ function buildGenericProviderOptions(
}
}
return providerOptions
return {
[providerId]: providerOptions
}
}
function buildAIGatewayOptions(
assistant: Assistant,
model: Model,
capabilities: {
enableReasoning: boolean
enableWebSearch: boolean
enableGenerateImage: boolean
},
serviceTier: OpenAIServiceTier,
textVerbosity?: OpenAIVerbosity
): Record<
string,
| OpenAIResponsesProviderOptions
| AnthropicProviderOptions
| GoogleGenerativeAIProviderOptions
| Record<string, unknown>
> {
if (isAnthropicModel(model)) {
return buildAnthropicProviderOptions(assistant, model, capabilities)
} else if (isOpenAIModel(model)) {
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
} else if (isGeminiModel(model)) {
return buildGeminiProviderOptions(assistant, model, capabilities)
} else if (isGrokModel(model)) {
return buildXAIProviderOptions(assistant, model, capabilities)
} else {
return buildGenericProviderOptions('openai-compatible', assistant, model, capabilities)
}
}

View File

@ -1853,7 +1853,7 @@ export const SYSTEM_MODELS: Record<SystemProviderId | 'defaultModel', Model[]> =
}
],
huggingface: [],
'ai-gateway': [],
gateway: [],
cerebras: [
{
id: 'gpt-oss-120b',

View File

@ -179,6 +179,11 @@ export const isGeminiModel = (model: Model) => {
return modelId.includes('gemini')
}
export const isGrokModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('grok')
}
// Zhipu vision reasoning models use this set of special tokens to mark reasoning results
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const

View File

@ -676,10 +676,10 @@ export const SYSTEM_PROVIDERS_CONFIG: Record<SystemProviderId, SystemProvider> =
isSystem: true,
enabled: false
},
'ai-gateway': {
id: 'ai-gateway',
name: 'AI Gateway',
type: 'ai-gateway',
gateway: {
id: 'gateway',
name: 'Vercel AI Gateway',
type: 'gateway',
apiKey: '',
apiHost: 'https://ai-gateway.vercel.sh/v1/ai',
models: [],
@ -762,7 +762,7 @@ export const PROVIDER_LOGO_MAP: AtLeast<SystemProviderId, string> = {
longcat: LongCatProviderLogo,
huggingface: HuggingfaceProviderLogo,
sophnet: SophnetProviderLogo,
'ai-gateway': AIGatewayProviderLogo,
gateway: AIGatewayProviderLogo,
cerebras: CerebrasProviderLogo
} as const
@ -1413,7 +1413,7 @@ export const PROVIDER_URLS: Record<SystemProviderId, ProviderUrls> = {
models: 'https://huggingface.co/models'
}
},
'ai-gateway': {
gateway: {
api: {
url: 'https://ai-gateway.vercel.sh/v1/ai'
},

View File

@ -87,7 +87,7 @@ const providerKeyMap = {
longcat: 'provider.longcat',
huggingface: 'provider.huggingface',
sophnet: 'provider.sophnet',
'ai-gateway': 'provider.ai-gateway',
gateway: 'provider.ai-gateway',
cerebras: 'provider.cerebras'
} as const

View File

@ -2531,7 +2531,7 @@
},
"provider": {
"302ai": "302.AI",
"ai-gateway": "AI Gateway",
"ai-gateway": "Vercel AI Gateway",
"aihubmix": "AiHubMix",
"aionly": "AiOnly",
"alayanew": "Alaya NeW",

View File

@ -2531,7 +2531,7 @@
},
"provider": {
"302ai": "302.AI",
"ai-gateway": "AI Gateway",
"ai-gateway": "Vercel AI Gateway",
"aihubmix": "AiHubMix",
"aionly": "唯一AI (AiOnly)",
"alayanew": "Alaya NeW",

View File

@ -67,7 +67,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 180,
version: 181,
blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs', 'toolPermissions'],
migrate
},

View File

@ -2810,7 +2810,7 @@ const migrateConfig = {
try {
addProvider(state, SystemProviderIds.longcat)
addProvider(state, SystemProviderIds['ai-gateway'])
addProvider(state, 'gateway')
addProvider(state, 'cerebras')
state.llm.providers.forEach((provider) => {
if (provider.id === SystemProviderIds.minimax) {
@ -2932,6 +2932,26 @@ const migrateConfig = {
logger.error('migrate 180 error', error as Error)
return state
}
},
'181': (state: RootState) => {
try {
state.llm.providers.forEach((provider) => {
if (provider.id === 'ai-gateway') {
provider.id = SystemProviderIds.gateway
}
// Also update model.provider references to avoid orphaned models
provider.models?.forEach((model) => {
if (model.provider === 'ai-gateway') {
model.provider = SystemProviderIds.gateway
}
})
})
logger.info('migrate 181 success')
return state
} catch (error) {
logger.error('migrate 181 error', error as Error)
return state
}
}
}

View File

@ -15,7 +15,7 @@ export const ProviderTypeSchema = z.enum([
'aws-bedrock',
'vertex-anthropic',
'new-api',
'ai-gateway',
'gateway',
'ollama'
])
@ -188,7 +188,7 @@ export const SystemProviderIdSchema = z.enum([
'longcat',
'huggingface',
'sophnet',
'ai-gateway',
'gateway',
'cerebras'
])
@ -257,7 +257,7 @@ export const SystemProviderIds = {
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
'ai-gateway': 'ai-gateway',
gateway: 'gateway',
cerebras: 'cerebras'
} as const satisfies Record<SystemProviderId, SystemProviderId>

View File

@ -189,7 +189,7 @@ describe('provider utils', () => {
expect(isAnthropicProvider(createProvider({ type: 'anthropic' }))).toBe(true)
expect(isGeminiProvider(createProvider({ type: 'gemini' }))).toBe(true)
expect(isAIGatewayProvider(createProvider({ type: 'ai-gateway' }))).toBe(true)
expect(isAIGatewayProvider(createProvider({ type: 'gateway' }))).toBe(true)
})
it('computes API version support', () => {

View File

@ -172,7 +172,7 @@ export function isGeminiProvider(provider: Provider): boolean {
}
export function isAIGatewayProvider(provider: Provider): boolean {
return provider.type === 'ai-gateway'
return provider.type === 'gateway'
}
export function isOllamaProvider(provider: Provider): boolean {