Fix: custom parameters for Gemini models (#11456)

* Initial plan

* fix(aiCore): extract AI SDK standard params from custom params for Gemini

Custom parameters like topK, frequencyPenalty, presencePenalty,
stopSequences, and seed should be passed as top-level streamText()
parameters, not in providerOptions. This fixes the issue where these
parameters were being ignored by the AI SDK's @ai-sdk/google module.

Changes:
- Add extractAiSdkStandardParams function to separate standard params
- Update buildProviderOptions to return both providerOptions and standardParams
- Update buildStreamTextParams to spread standardParams into params object
- Update tests to reflect new return structure

Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>

* refactor(aiCore): remove extractAiSdkStandardParams function and its tests, streamline parameter extraction logic

* chore: type

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>
Co-authored-by: suyao <sy20010504@gmail.com>
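
A minimal sketch of where these custom parameters land after this change, assuming AI SDK v5's streamText() and the @ai-sdk/google provider; the model id, parameter values, and custom_param key below are illustrative, not taken from this commit:

import { google } from '@ai-sdk/google'
import { streamText } from 'ai'

// Before the fix, topK / frequencyPenalty / presencePenalty / stopSequences / seed set as
// custom parameters ended up inside providerOptions.google, where @ai-sdk/google ignores them.
// After the fix they are spread into streamText() as top-level settings:
const result = streamText({
  model: google('gemini-2.0-flash'), // hypothetical model id
  messages: [{ role: 'user', content: 'Hello' }],
  topK: 40, // extracted AI SDK standard params
  frequencyPenalty: 0.5,
  presencePenalty: 0.3,
  stopSequences: ['END'],
  seed: 42,
  providerOptions: {
    google: { custom_param: 'value' } // only provider-specific keys remain here
  }
})
// then consume result.textStream as usual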
Copilot 2025-11-26 13:16:58 +08:00 committed by GitHub
parent 5fb59d21ec
commit c1f4b5b9b9
5 changed files with 851 additions and 45 deletions


@ -106,7 +106,7 @@ export async function buildStreamTextParams(
searchWithTime: store.getState().websearch.searchWithTime
}
const providerOptions = buildProviderOptions(assistant, model, provider, {
const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
enableReasoning,
enableWebSearch,
enableGenerateImage
@ -181,11 +181,16 @@ export async function buildStreamTextParams(
}
// Build the base parameters
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
// are extracted from custom parameters and passed directly to streamText()
// instead of being placed in providerOptions
const params: StreamTextParams = {
messages: sdkMessages,
maxOutputTokens: getMaxTokens(assistant, model),
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
// Include AI SDK standard params extracted from custom parameters
...standardParams,
abortSignal: options.requestOptions?.signal,
headers,
providerOptions,


@ -0,0 +1,652 @@
/**
* extractAiSdkStandardParams Unit Tests
* Tests for extracting AI SDK standard parameters from custom parameters
*/
import { describe, expect, it, vi } from 'vitest'
import { extractAiSdkStandardParams } from '../options'
// Mock logger to prevent errors
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
// Mock settings store
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
// Mock hooks to prevent uuid errors
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(() => ({}))
}))
// Mock uuid to prevent errors
vi.mock('uuid', () => ({
v4: vi.fn(() => 'test-uuid')
}))
// Mock AssistantService to prevent uuid errors
vi.mock('@renderer/services/AssistantService', () => ({
getDefaultAssistant: vi.fn(() => ({
id: 'test-assistant',
name: 'Test Assistant',
settings: {}
})),
getDefaultTopic: vi.fn(() => ({
id: 'test-topic',
assistantId: 'test-assistant',
createdAt: new Date().toISOString()
}))
}))
// Mock provider service
vi.mock('@renderer/services/ProviderService', () => ({
getProviderById: vi.fn(() => ({
id: 'test-provider',
name: 'Test Provider'
}))
}))
// Mock config modules
vi.mock('@renderer/config/models', () => ({
isOpenAIModel: vi.fn(() => false),
isQwenMTModel: vi.fn(() => false),
isSupportFlexServiceTierModel: vi.fn(() => false),
isSupportVerbosityModel: vi.fn(() => false),
getModelSupportedVerbosity: vi.fn(() => [])
}))
vi.mock('@renderer/config/translate', () => ({
mapLanguageToQwenMTModel: vi.fn()
}))
vi.mock('@renderer/utils/provider', () => ({
isSupportServiceTierProvider: vi.fn(() => false),
isSupportVerbosityProvider: vi.fn(() => false)
}))
describe('extractAiSdkStandardParams', () => {
describe('Positive cases - Standard parameters extraction', () => {
it('should extract all AI SDK standard parameters', () => {
const customParams = {
maxOutputTokens: 1000,
temperature: 0.7,
topP: 0.9,
topK: 40,
presencePenalty: 0.5,
frequencyPenalty: 0.3,
stopSequences: ['STOP', 'END'],
seed: 42
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
maxOutputTokens: 1000,
temperature: 0.7,
topP: 0.9,
topK: 40,
presencePenalty: 0.5,
frequencyPenalty: 0.3,
stopSequences: ['STOP', 'END'],
seed: 42
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract single standard parameter', () => {
const customParams = {
temperature: 0.8
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.8
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract topK parameter', () => {
const customParams = {
topK: 50
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
topK: 50
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract frequencyPenalty parameter', () => {
const customParams = {
frequencyPenalty: 0.6
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
frequencyPenalty: 0.6
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract presencePenalty parameter', () => {
const customParams = {
presencePenalty: 0.4
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
presencePenalty: 0.4
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract stopSequences parameter', () => {
const customParams = {
stopSequences: ['HALT', 'TERMINATE']
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
stopSequences: ['HALT', 'TERMINATE']
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract seed parameter', () => {
const customParams = {
seed: 12345
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
seed: 12345
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract maxOutputTokens parameter', () => {
const customParams = {
maxOutputTokens: 2048
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
maxOutputTokens: 2048
})
expect(result.providerParams).toStrictEqual({})
})
it('should extract topP parameter', () => {
const customParams = {
topP: 0.95
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
topP: 0.95
})
expect(result.providerParams).toStrictEqual({})
})
})
describe('Negative cases - Provider-specific parameters', () => {
it('should place all non-standard parameters in providerParams', () => {
const customParams = {
customParam: 'value',
anotherParam: 123,
thirdParam: true
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
customParam: 'value',
anotherParam: 123,
thirdParam: true
})
})
it('should place single provider-specific parameter in providerParams', () => {
const customParams = {
reasoningEffort: 'high'
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
reasoningEffort: 'high'
})
})
it('should place model-specific parameter in providerParams', () => {
const customParams = {
thinking: { type: 'enabled', budgetTokens: 5000 }
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
thinking: { type: 'enabled', budgetTokens: 5000 }
})
})
it('should place serviceTier in providerParams', () => {
const customParams = {
serviceTier: 'auto'
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
serviceTier: 'auto'
})
})
it('should place textVerbosity in providerParams', () => {
const customParams = {
textVerbosity: 'high'
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
textVerbosity: 'high'
})
})
})
describe('Mixed parameters', () => {
it('should correctly separate mixed standard and provider-specific parameters', () => {
const customParams = {
temperature: 0.7,
topK: 40,
customParam: 'custom_value',
reasoningEffort: 'medium',
frequencyPenalty: 0.5,
seed: 999
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40,
frequencyPenalty: 0.5,
seed: 999
})
expect(result.providerParams).toStrictEqual({
customParam: 'custom_value',
reasoningEffort: 'medium'
})
})
it('should handle complex mixed parameters with nested objects', () => {
const customParams = {
topP: 0.9,
presencePenalty: 0.3,
thinking: { type: 'enabled', budgetTokens: 5000 },
stopSequences: ['STOP'],
serviceTier: 'auto',
maxOutputTokens: 4096
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
topP: 0.9,
presencePenalty: 0.3,
stopSequences: ['STOP'],
maxOutputTokens: 4096
})
expect(result.providerParams).toStrictEqual({
thinking: { type: 'enabled', budgetTokens: 5000 },
serviceTier: 'auto'
})
})
it('should handle all standard params with some provider params', () => {
const customParams = {
maxOutputTokens: 2000,
temperature: 0.8,
topP: 0.95,
topK: 50,
presencePenalty: 0.6,
frequencyPenalty: 0.4,
stopSequences: ['END', 'DONE'],
seed: 777,
customApiParam: 'value',
anotherCustomParam: 123
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
maxOutputTokens: 2000,
temperature: 0.8,
topP: 0.95,
topK: 50,
presencePenalty: 0.6,
frequencyPenalty: 0.4,
stopSequences: ['END', 'DONE'],
seed: 777
})
expect(result.providerParams).toStrictEqual({
customApiParam: 'value',
anotherCustomParam: 123
})
})
})
describe('Edge cases', () => {
it('should handle empty object', () => {
const customParams = {}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({})
})
it('should handle zero values for numeric parameters', () => {
const customParams = {
temperature: 0,
topK: 0,
seed: 0
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0,
topK: 0,
seed: 0
})
expect(result.providerParams).toStrictEqual({})
})
it('should handle negative values for numeric parameters', () => {
const customParams = {
presencePenalty: -0.5,
frequencyPenalty: -0.3,
seed: -1
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
presencePenalty: -0.5,
frequencyPenalty: -0.3,
seed: -1
})
expect(result.providerParams).toStrictEqual({})
})
it('should handle empty arrays for stopSequences', () => {
const customParams = {
stopSequences: []
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
stopSequences: []
})
expect(result.providerParams).toStrictEqual({})
})
it('should handle null values in mixed parameters', () => {
const customParams = {
temperature: 0.7,
customNull: null,
topK: 40
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40
})
expect(result.providerParams).toStrictEqual({
customNull: null
})
})
it('should handle undefined values in mixed parameters', () => {
const customParams = {
temperature: 0.7,
customUndefined: undefined,
topK: 40
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40
})
expect(result.providerParams).toStrictEqual({
customUndefined: undefined
})
})
it('should handle boolean values in provider-specific parameters', () => {
const customParams = {
temperature: 0.7,
customBoolean: false,
topK: 40
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40
})
expect(result.providerParams).toStrictEqual({
customBoolean: false
})
})
it('should handle very large numeric values', () => {
const customParams = {
maxOutputTokens: 999999,
seed: 2147483647,
topK: 10000
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
maxOutputTokens: 999999,
seed: 2147483647,
topK: 10000
})
expect(result.providerParams).toStrictEqual({})
})
it('should handle decimal values with high precision', () => {
const customParams = {
temperature: 0.123456789,
topP: 0.987654321,
presencePenalty: 0.111111111
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.123456789,
topP: 0.987654321,
presencePenalty: 0.111111111
})
expect(result.providerParams).toStrictEqual({})
})
})
describe('Case sensitivity', () => {
it('should NOT extract parameters with incorrect case - uppercase first letter', () => {
const customParams = {
Temperature: 0.7,
TopK: 40,
FrequencyPenalty: 0.5
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
Temperature: 0.7,
TopK: 40,
FrequencyPenalty: 0.5
})
})
it('should NOT extract parameters with incorrect case - all uppercase', () => {
const customParams = {
TEMPERATURE: 0.7,
TOPK: 40,
SEED: 42
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
TEMPERATURE: 0.7,
TOPK: 40,
SEED: 42
})
})
it('should NOT extract parameters with incorrect case - all lowercase', () => {
const customParams = {
maxoutputtokens: 1000,
frequencypenalty: 0.5,
stopsequences: ['STOP']
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
maxoutputtokens: 1000,
frequencypenalty: 0.5,
stopsequences: ['STOP']
})
})
it('should correctly extract exact case match while rejecting incorrect case', () => {
const customParams = {
temperature: 0.7,
Temperature: 0.8,
TEMPERATURE: 0.9,
topK: 40,
TopK: 50
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
temperature: 0.7,
topK: 40
})
expect(result.providerParams).toStrictEqual({
Temperature: 0.8,
TEMPERATURE: 0.9,
TopK: 50
})
})
})
describe('Parameter name variations', () => {
it('should NOT extract similar but incorrect parameter names', () => {
const customParams = {
temp: 0.7, // should not match temperature
top_k: 40, // should not match topK
max_tokens: 1000, // should not match maxOutputTokens
freq_penalty: 0.5 // should not match frequencyPenalty
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
temp: 0.7,
top_k: 40,
max_tokens: 1000,
freq_penalty: 0.5
})
})
it('should NOT extract snake_case versions of standard parameters', () => {
const customParams = {
top_k: 40,
top_p: 0.9,
presence_penalty: 0.5,
frequency_penalty: 0.3,
stop_sequences: ['STOP'],
max_output_tokens: 1000
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({})
expect(result.providerParams).toStrictEqual({
top_k: 40,
top_p: 0.9,
presence_penalty: 0.5,
frequency_penalty: 0.3,
stop_sequences: ['STOP'],
max_output_tokens: 1000
})
})
it('should extract exact camelCase parameters only', () => {
const customParams = {
topK: 40, // correct
top_k: 50, // incorrect
topP: 0.9, // correct
top_p: 0.8, // incorrect
frequencyPenalty: 0.5, // correct
frequency_penalty: 0.4 // incorrect
}
const result = extractAiSdkStandardParams(customParams)
expect(result.standardParams).toStrictEqual({
topK: 40,
topP: 0.9,
frequencyPenalty: 0.5
})
expect(result.providerParams).toStrictEqual({
top_k: 50,
top_p: 0.8,
frequency_penalty: 0.4
})
})
})
})


@ -128,7 +128,20 @@ vi.mock('../reasoning', () => ({
reasoningConfig: { type: 'enabled', budgetTokens: 5000 }
})),
getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })),
getCustomParameters: vi.fn(() => ({}))
getCustomParameters: vi.fn(() => ({})),
extractAiSdkStandardParams: vi.fn((customParams: Record<string, any>) => {
const AI_SDK_STANDARD_PARAMS = ['topK', 'frequencyPenalty', 'presencePenalty', 'stopSequences', 'seed']
const standardParams: Record<string, any> = {}
const providerParams: Record<string, any> = {}
for (const [key, value] of Object.entries(customParams)) {
if (AI_SDK_STANDARD_PARAMS.includes(key)) {
standardParams[key] = value
} else {
providerParams[key] = value
}
}
return { standardParams, providerParams }
})
}))
vi.mock('../image', () => ({
@ -184,8 +197,9 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result).toHaveProperty('openai')
expect(result.openai).toBeDefined()
expect(result.providerOptions).toHaveProperty('openai')
expect(result.providerOptions.openai).toBeDefined()
expect(result.standardParams).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
@ -195,8 +209,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('reasoningEffort')
expect(result.openai.reasoningEffort).toBe('medium')
expect(result.providerOptions.openai).toHaveProperty('reasoningEffort')
expect(result.providerOptions.openai.reasoningEffort).toBe('medium')
})
it('should include service tier when supported', () => {
@ -211,8 +225,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('serviceTier')
expect(result.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
expect(result.providerOptions.openai).toHaveProperty('serviceTier')
expect(result.providerOptions.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
})
})
@ -239,8 +253,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
expect(result.anthropic).toBeDefined()
expect(result.providerOptions).toHaveProperty('anthropic')
expect(result.providerOptions.anthropic).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
@ -250,8 +264,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result.anthropic).toHaveProperty('thinking')
expect(result.anthropic.thinking).toEqual({
expect(result.providerOptions.anthropic).toHaveProperty('thinking')
expect(result.providerOptions.anthropic.thinking).toEqual({
type: 'enabled',
budgetTokens: 5000
})
@ -282,8 +296,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
expect(result.google).toBeDefined()
expect(result.providerOptions).toHaveProperty('google')
expect(result.providerOptions.google).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
@ -293,8 +307,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google.thinkingConfig).toEqual({
expect(result.providerOptions.google).toHaveProperty('thinkingConfig')
expect(result.providerOptions.google.thinkingConfig).toEqual({
include_thoughts: true
})
})
@ -306,8 +320,8 @@ describe('options utils', () => {
enableGenerateImage: true
})
expect(result.google).toHaveProperty('responseModalities')
expect(result.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
expect(result.providerOptions.google).toHaveProperty('responseModalities')
expect(result.providerOptions.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
})
})
@ -335,8 +349,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result).toHaveProperty('xai')
expect(result.xai).toBeDefined()
expect(result.providerOptions).toHaveProperty('xai')
expect(result.providerOptions.xai).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
@ -346,8 +360,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result.xai).toHaveProperty('reasoningEffort')
expect(result.xai.reasoningEffort).toBe('high')
expect(result.providerOptions.xai).toHaveProperty('reasoningEffort')
expect(result.providerOptions.xai.reasoningEffort).toBe('high')
})
})
@ -374,8 +388,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result).toHaveProperty('deepseek')
expect(result.deepseek).toBeDefined()
expect(result.providerOptions).toHaveProperty('deepseek')
expect(result.providerOptions.deepseek).toBeDefined()
})
})
@ -402,8 +416,8 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result).toHaveProperty('openrouter')
expect(result.openrouter).toBeDefined()
expect(result.providerOptions).toHaveProperty('openrouter')
expect(result.providerOptions.openrouter).toBeDefined()
})
it('should include web search parameters when enabled', () => {
@ -413,12 +427,12 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result.openrouter).toHaveProperty('enable_search')
expect(result.providerOptions.openrouter).toHaveProperty('enable_search')
})
})
describe('Custom parameters', () => {
it('should merge custom parameters', async () => {
it('should merge custom provider-specific parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
@ -443,10 +457,88 @@ describe('options utils', () => {
}
)
expect(result.openai).toHaveProperty('custom_param')
expect(result.openai.custom_param).toBe('custom_value')
expect(result.openai).toHaveProperty('another_param')
expect(result.openai.another_param).toBe(123)
expect(result.providerOptions.openai).toHaveProperty('custom_param')
expect(result.providerOptions.openai.custom_param).toBe('custom_value')
expect(result.providerOptions.openai).toHaveProperty('another_param')
expect(result.providerOptions.openai.another_param).toBe(123)
})
it('should extract AI SDK standard params from custom parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
topK: 5,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
seed: 42,
custom_param: 'custom_value'
})
const result = buildProviderOptions(
mockAssistant,
mockModel,
{
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com'
} as Provider,
{
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
}
)
// Standard params should be extracted and returned separately
expect(result.standardParams).toEqual({
topK: 5,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
seed: 42
})
// Provider-specific params should still be in providerOptions
expect(result.providerOptions.google).toHaveProperty('custom_param')
expect(result.providerOptions.google.custom_param).toBe('custom_value')
// Standard params should NOT be in providerOptions
expect(result.providerOptions.google).not.toHaveProperty('topK')
expect(result.providerOptions.google).not.toHaveProperty('frequencyPenalty')
expect(result.providerOptions.google).not.toHaveProperty('presencePenalty')
expect(result.providerOptions.google).not.toHaveProperty('seed')
})
it('should handle stopSequences in custom parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
stopSequences: ['STOP', 'END'],
custom_param: 'value'
})
const result = buildProviderOptions(
mockAssistant,
mockModel,
{
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com'
} as Provider,
{
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
}
)
expect(result.standardParams).toEqual({
stopSequences: ['STOP', 'END']
})
expect(result.providerOptions.google).not.toHaveProperty('stopSequences')
})
})
@ -474,8 +566,8 @@ describe('options utils', () => {
enableGenerateImage: true
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google).toHaveProperty('responseModalities')
expect(result.providerOptions.google).toHaveProperty('thinkingConfig')
expect(result.providerOptions.google).toHaveProperty('responseModalities')
})
it('should handle all capabilities enabled', () => {
@ -485,8 +577,8 @@ describe('options utils', () => {
enableGenerateImage: true
})
expect(result.google).toBeDefined()
expect(Object.keys(result.google).length).toBeGreaterThan(0)
expect(result.providerOptions.google).toBeDefined()
expect(Object.keys(result.providerOptions.google).length).toBeGreaterThan(0)
})
})
@ -513,7 +605,7 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
expect(result.providerOptions).toHaveProperty('google')
})
it('should map google-vertex-anthropic to anthropic', () => {
@ -538,7 +630,7 @@ describe('options utils', () => {
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
expect(result.providerOptions).toHaveProperty('anthropic')
})
})
})


@ -31,7 +31,7 @@ import {
type Provider,
type ServiceTier
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { type AiSdkParam, isAiSdkParam, type OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@renderer/utils/provider'
import type { JSONValue } from 'ai'
import { t } from 'i18next'
@ -96,10 +96,39 @@ function getVerbosity(): OpenAIVerbosity {
return openAI.verbosity
}
/**
* Extract AI SDK standard parameters from custom parameters
* These parameters should be passed directly to streamText() instead of providerOptions
*/
export function extractAiSdkStandardParams(customParams: Record<string, any>): {
standardParams: Partial<Record<AiSdkParam, any>>
providerParams: Record<string, any>
} {
const standardParams: Partial<Record<AiSdkParam, any>> = {}
const providerParams: Record<string, any> = {}
for (const [key, value] of Object.entries(customParams)) {
if (isAiSdkParam(key)) {
standardParams[key] = value
} else {
providerParams[key] = value
}
}
return { standardParams, providerParams }
}
/**
* Build AI SDK providerOptions for the current provider
* Returns: { 'providerId': providerOptions }
* {
* providerOptions: { 'providerId': providerOptions },
* standardParams: { topK, frequencyPenalty, presencePenalty, stopSequences, seed }
* }
*
* Custom parameters are split into two categories:
* 1. AI SDK standard parameters (topK, frequencyPenalty, etc.) - returned separately to be passed to streamText()
* 2. Provider-specific parameters - merged into providerOptions
*/
export function buildProviderOptions(
assistant: Assistant,
@ -110,7 +139,10 @@ export function buildProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, Record<string, JSONValue>> {
): {
providerOptions: Record<string, Record<string, JSONValue>>
standardParams: Partial<Record<AiSdkParam, any>>
} {
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
const rawProviderId = getAiSdkProviderId(actualProvider)
// Build provider-specific options
@ -202,10 +234,14 @@ export function buildProviderOptions(
}
}
// Merge custom parameters into the provider-specific options
// Get custom parameters and separate standard params from provider-specific params
const customParams = getCustomParameters(assistant)
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
// Merge provider-specific custom parameters into providerSpecificOptions
providerSpecificOptions = {
...providerSpecificOptions,
...getCustomParameters(assistant)
...providerParams
}
let rawProviderKey =
@ -220,9 +256,12 @@ export function buildProviderOptions(
rawProviderKey = { gemini: 'google' }[actualProvider.type] || actualProvider.type
}
// Return the format required by the AI Core SDK: { 'providerId': providerOptions }
// Return the format required by the AI Core SDK: { 'providerId': providerOptions }, plus the extracted standard params
return {
[rawProviderKey]: providerSpecificOptions
providerOptions: {
[rawProviderKey]: providerSpecificOptions
},
standardParams
}
}


@ -2,6 +2,7 @@ import type OpenAI from '@cherrystudio/openai'
import type { NotNull, NotUndefined } from '@types'
import type { ImageModel, LanguageModel } from 'ai'
import type { generateObject, generateText, ModelMessage, streamObject, streamText } from 'ai'
import * as z from 'zod'
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model' | 'messages'> &
(
@ -42,3 +43,20 @@ export type OpenAIReasoningEffort = OpenAI.ReasoningEffort
// I pick undefined as the unique falsy type since they seem to share the same meaning according to the OpenAI API docs.
// The parameter is not passed into the request if it's undefined.
export type OpenAISummaryText = NotNull<OpenAI.Reasoning['summary']>
const AiSdkParamsSchema = z.enum([
'maxOutputTokens',
'temperature',
'topP',
'topK',
'presencePenalty',
'frequencyPenalty',
'stopSequences',
'seed'
])
export type AiSdkParam = z.infer<typeof AiSdkParamsSchema>
export const isAiSdkParam = (param: string): param is AiSdkParam => {
return AiSdkParamsSchema.safeParse(param).success
}
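
As a quick, hypothetical illustration of how this guard routes keys:

isAiSdkParam('topK') // true: extracted and passed to streamText()
isAiSdkParam('reasoningEffort') // false: left in providerOptions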