mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
Fix custom parameters placement for Vercel AI Gateway
For AI Gateway provider, custom parameters are now placed at the body level instead of being nested inside providerOptions.gateway. This fixes the issue where parameters like 'tools' were being incorrectly added to providerOptions.gateway when they should be at the same level as providerOptions. Fixes #4197 Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>
This commit is contained in:
parent
64fde27f9e
commit
b14e48dd78
@ -107,7 +107,7 @@ export async function buildStreamTextParams(
|
||||
searchWithTime: store.getState().websearch.searchWithTime
|
||||
}
|
||||
|
||||
const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
|
||||
const { providerOptions, standardParams, bodyParams } = buildProviderOptions(assistant, model, provider, {
|
||||
enableReasoning,
|
||||
enableWebSearch,
|
||||
enableGenerateImage
|
||||
@ -185,6 +185,7 @@ export async function buildStreamTextParams(
|
||||
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
|
||||
// are extracted from custom parameters and passed directly to streamText()
|
||||
// instead of being placed in providerOptions
|
||||
// Note: bodyParams are custom parameters for AI Gateway that should be at body level
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
maxOutputTokens: getMaxTokens(assistant, model),
|
||||
@ -192,6 +193,8 @@ export async function buildStreamTextParams(
|
||||
topP: getTopP(assistant, model),
|
||||
// Include AI SDK standard params extracted from custom parameters
|
||||
...standardParams,
|
||||
// Include body-level params for AI Gateway custom parameters
|
||||
...bodyParams,
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers,
|
||||
providerOptions,
|
||||
|
||||
@ -37,7 +37,7 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
|
||||
},
|
||||
customProviderIdSchema: {
|
||||
safeParse: vi.fn((id) => {
|
||||
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
|
||||
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock', 'ai-gateway']
|
||||
if (customProviders.includes(id)) {
|
||||
return { success: true, data: id }
|
||||
}
|
||||
@ -56,7 +56,8 @@ vi.mock('../provider/factory', () => ({
|
||||
[SystemProviderIds.anthropic]: 'anthropic',
|
||||
[SystemProviderIds.grok]: 'xai',
|
||||
[SystemProviderIds.deepseek]: 'deepseek',
|
||||
[SystemProviderIds.openrouter]: 'openrouter'
|
||||
[SystemProviderIds.openrouter]: 'openrouter',
|
||||
[SystemProviderIds['ai-gateway']]: 'ai-gateway'
|
||||
}
|
||||
return mapping[provider.id] || provider.id
|
||||
})
|
||||
@ -204,6 +205,8 @@ describe('options utils', () => {
|
||||
expect(result.providerOptions).toHaveProperty('openai')
|
||||
expect(result.providerOptions.openai).toBeDefined()
|
||||
expect(result.standardParams).toBeDefined()
|
||||
expect(result.bodyParams).toBeDefined()
|
||||
expect(result.bodyParams).toEqual({})
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
@ -696,5 +699,90 @@ describe('options utils', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('AI Gateway provider', () => {
|
||||
const aiGatewayProvider: Provider = {
|
||||
id: SystemProviderIds['ai-gateway'],
|
||||
name: 'AI Gateway',
|
||||
type: 'ai-gateway',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://ai-gateway.vercel.sh/v1/ai',
|
||||
isSystem: true,
|
||||
models: [] as Model[]
|
||||
} as Provider
|
||||
|
||||
const aiGatewayModel: Model = {
|
||||
id: 'openai/gpt-4',
|
||||
name: 'GPT-4',
|
||||
provider: SystemProviderIds['ai-gateway']
|
||||
} as Model
|
||||
|
||||
it('should build basic AI Gateway options with empty bodyParams', () => {
|
||||
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.providerOptions).toHaveProperty('gateway')
|
||||
expect(result.providerOptions.gateway).toBeDefined()
|
||||
expect(result.bodyParams).toEqual({})
|
||||
})
|
||||
|
||||
it('should place custom parameters in bodyParams for AI Gateway instead of providerOptions', async () => {
|
||||
const { getCustomParameters } = await import('../reasoning')
|
||||
|
||||
vi.mocked(getCustomParameters).mockReturnValue({
|
||||
tools: [{ id: 'openai.image_generation' }],
|
||||
custom_param: 'custom_value'
|
||||
})
|
||||
|
||||
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
// Custom parameters should be in bodyParams, NOT in providerOptions.gateway
|
||||
expect(result.bodyParams).toHaveProperty('tools')
|
||||
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
|
||||
expect(result.bodyParams).toHaveProperty('custom_param')
|
||||
expect(result.bodyParams.custom_param).toBe('custom_value')
|
||||
|
||||
// providerOptions.gateway should NOT contain custom parameters
|
||||
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
|
||||
expect(result.providerOptions.gateway).not.toHaveProperty('custom_param')
|
||||
})
|
||||
|
||||
it('should still extract AI SDK standard params from custom parameters for AI Gateway', async () => {
|
||||
const { getCustomParameters } = await import('../reasoning')
|
||||
|
||||
vi.mocked(getCustomParameters).mockReturnValue({
|
||||
topK: 5,
|
||||
frequencyPenalty: 0.5,
|
||||
tools: [{ id: 'openai.image_generation' }]
|
||||
})
|
||||
|
||||
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
// Standard params should be extracted and returned separately
|
||||
expect(result.standardParams).toEqual({
|
||||
topK: 5,
|
||||
frequencyPenalty: 0.5
|
||||
})
|
||||
|
||||
// Custom params (non-standard) should be in bodyParams
|
||||
expect(result.bodyParams).toHaveProperty('tools')
|
||||
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
|
||||
|
||||
// Neither should be in providerOptions.gateway
|
||||
expect(result.providerOptions.gateway).not.toHaveProperty('topK')
|
||||
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -155,6 +155,7 @@ export function buildProviderOptions(
|
||||
): {
|
||||
providerOptions: Record<string, Record<string, JSONValue>>
|
||||
standardParams: Partial<Record<AiSdkParam, any>>
|
||||
bodyParams: Record<string, any>
|
||||
} {
|
||||
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
|
||||
const rawProviderId = getAiSdkProviderId(actualProvider)
|
||||
@ -253,12 +254,6 @@ export function buildProviderOptions(
|
||||
const customParams = getCustomParameters(assistant)
|
||||
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
|
||||
|
||||
// 合并 provider 特定的自定义参数到 providerSpecificOptions
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
...providerParams
|
||||
}
|
||||
|
||||
let rawProviderKey =
|
||||
{
|
||||
'google-vertex': 'google',
|
||||
@ -273,12 +268,27 @@ export function buildProviderOptions(
|
||||
rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type
|
||||
}
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数
|
||||
// For AI Gateway, custom parameters should be placed at body level, not inside providerOptions.gateway
|
||||
// See: https://github.com/CherryHQ/cherry-studio/issues/4197
|
||||
let bodyParams: Record<string, any> = {}
|
||||
if (rawProviderKey === 'gateway') {
|
||||
// Custom parameters go to body level for AI Gateway
|
||||
bodyParams = providerParams
|
||||
} else {
|
||||
// For other providers, merge custom parameters into providerSpecificOptions
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
...providerParams
|
||||
}
|
||||
}
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数和 body 参数
|
||||
return {
|
||||
providerOptions: {
|
||||
[rawProviderKey]: providerSpecificOptions
|
||||
},
|
||||
standardParams
|
||||
standardParams,
|
||||
bodyParams
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user