mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-30 07:39:06 +08:00
Fix: custom parameters for Gemini models (#11456)
* Initial plan * fix(aiCore): extract AI SDK standard params from custom params for Gemini Custom parameters like topK, frequencyPenalty, presencePenalty, stopSequences, and seed should be passed as top-level streamText() parameters, not in providerOptions. This fixes the issue where these parameters were being ignored by the AI SDK's @ai-sdk/google module. Changes: - Add extractAiSdkStandardParams function to separate standard params - Update buildProviderOptions to return both providerOptions and standardParams - Update buildStreamTextParams to spread standardParams into params object - Update tests to reflect new return structure Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com> * refactor(aiCore): remove extractAiSdkStandardParams function and its tests, streamline parameter extraction logic * chore: type --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com> Co-authored-by: suyao <sy20010504@gmail.com>
This commit is contained in:
parent
5fb59d21ec
commit
c1f4b5b9b9
@ -106,7 +106,7 @@ export async function buildStreamTextParams(
|
||||
searchWithTime: store.getState().websearch.searchWithTime
|
||||
}
|
||||
|
||||
const providerOptions = buildProviderOptions(assistant, model, provider, {
|
||||
const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
|
||||
enableReasoning,
|
||||
enableWebSearch,
|
||||
enableGenerateImage
|
||||
@ -181,11 +181,16 @@ export async function buildStreamTextParams(
|
||||
}
|
||||
|
||||
// 构建基础参数
|
||||
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
|
||||
// are extracted from custom parameters and passed directly to streamText()
|
||||
// instead of being placed in providerOptions
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
maxOutputTokens: getMaxTokens(assistant, model),
|
||||
temperature: getTemperature(assistant, model),
|
||||
topP: getTopP(assistant, model),
|
||||
// Include AI SDK standard params extracted from custom parameters
|
||||
...standardParams,
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers,
|
||||
providerOptions,
|
||||
|
||||
@ -0,0 +1,652 @@
|
||||
/**
|
||||
* extractAiSdkStandardParams Unit Tests
|
||||
* Tests for extracting AI SDK standard parameters from custom parameters
|
||||
*/
|
||||
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { extractAiSdkStandardParams } from '../options'
|
||||
|
||||
// Mock logger to prevent errors
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
debug: vi.fn(),
|
||||
error: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
info: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock settings store
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: (state = { settings: {} }) => state
|
||||
}))
|
||||
|
||||
// Mock hooks to prevent uuid errors
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
getStoreSetting: vi.fn(() => ({}))
|
||||
}))
|
||||
|
||||
// Mock uuid to prevent errors
|
||||
vi.mock('uuid', () => ({
|
||||
v4: vi.fn(() => 'test-uuid')
|
||||
}))
|
||||
|
||||
// Mock AssistantService to prevent uuid errors
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getDefaultAssistant: vi.fn(() => ({
|
||||
id: 'test-assistant',
|
||||
name: 'Test Assistant',
|
||||
settings: {}
|
||||
})),
|
||||
getDefaultTopic: vi.fn(() => ({
|
||||
id: 'test-topic',
|
||||
assistantId: 'test-assistant',
|
||||
createdAt: new Date().toISOString()
|
||||
}))
|
||||
}))
|
||||
|
||||
// Mock provider service
|
||||
vi.mock('@renderer/services/ProviderService', () => ({
|
||||
getProviderById: vi.fn(() => ({
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider'
|
||||
}))
|
||||
}))
|
||||
|
||||
// Mock config modules
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
isOpenAIModel: vi.fn(() => false),
|
||||
isQwenMTModel: vi.fn(() => false),
|
||||
isSupportFlexServiceTierModel: vi.fn(() => false),
|
||||
isSupportVerbosityModel: vi.fn(() => false),
|
||||
getModelSupportedVerbosity: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/config/translate', () => ({
|
||||
mapLanguageToQwenMTModel: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/provider', () => ({
|
||||
isSupportServiceTierProvider: vi.fn(() => false),
|
||||
isSupportVerbosityProvider: vi.fn(() => false)
|
||||
}))
|
||||
|
||||
describe('extractAiSdkStandardParams', () => {
|
||||
describe('Positive cases - Standard parameters extraction', () => {
|
||||
it('should extract all AI SDK standard parameters', () => {
|
||||
const customParams = {
|
||||
maxOutputTokens: 1000,
|
||||
temperature: 0.7,
|
||||
topP: 0.9,
|
||||
topK: 40,
|
||||
presencePenalty: 0.5,
|
||||
frequencyPenalty: 0.3,
|
||||
stopSequences: ['STOP', 'END'],
|
||||
seed: 42
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
maxOutputTokens: 1000,
|
||||
temperature: 0.7,
|
||||
topP: 0.9,
|
||||
topK: 40,
|
||||
presencePenalty: 0.5,
|
||||
frequencyPenalty: 0.3,
|
||||
stopSequences: ['STOP', 'END'],
|
||||
seed: 42
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract single standard parameter', () => {
|
||||
const customParams = {
|
||||
temperature: 0.8
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.8
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract topK parameter', () => {
|
||||
const customParams = {
|
||||
topK: 50
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
topK: 50
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract frequencyPenalty parameter', () => {
|
||||
const customParams = {
|
||||
frequencyPenalty: 0.6
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
frequencyPenalty: 0.6
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract presencePenalty parameter', () => {
|
||||
const customParams = {
|
||||
presencePenalty: 0.4
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
presencePenalty: 0.4
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract stopSequences parameter', () => {
|
||||
const customParams = {
|
||||
stopSequences: ['HALT', 'TERMINATE']
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
stopSequences: ['HALT', 'TERMINATE']
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract seed parameter', () => {
|
||||
const customParams = {
|
||||
seed: 12345
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
seed: 12345
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract maxOutputTokens parameter', () => {
|
||||
const customParams = {
|
||||
maxOutputTokens: 2048
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
maxOutputTokens: 2048
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should extract topP parameter', () => {
|
||||
const customParams = {
|
||||
topP: 0.95
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
topP: 0.95
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Negative cases - Provider-specific parameters', () => {
|
||||
it('should place all non-standard parameters in providerParams', () => {
|
||||
const customParams = {
|
||||
customParam: 'value',
|
||||
anotherParam: 123,
|
||||
thirdParam: true
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customParam: 'value',
|
||||
anotherParam: 123,
|
||||
thirdParam: true
|
||||
})
|
||||
})
|
||||
|
||||
it('should place single provider-specific parameter in providerParams', () => {
|
||||
const customParams = {
|
||||
reasoningEffort: 'high'
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
reasoningEffort: 'high'
|
||||
})
|
||||
})
|
||||
|
||||
it('should place model-specific parameter in providerParams', () => {
|
||||
const customParams = {
|
||||
thinking: { type: 'enabled', budgetTokens: 5000 }
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
thinking: { type: 'enabled', budgetTokens: 5000 }
|
||||
})
|
||||
})
|
||||
|
||||
it('should place serviceTier in providerParams', () => {
|
||||
const customParams = {
|
||||
serviceTier: 'auto'
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
serviceTier: 'auto'
|
||||
})
|
||||
})
|
||||
|
||||
it('should place textVerbosity in providerParams', () => {
|
||||
const customParams = {
|
||||
textVerbosity: 'high'
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
textVerbosity: 'high'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Mixed parameters', () => {
|
||||
it('should correctly separate mixed standard and provider-specific parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
topK: 40,
|
||||
customParam: 'custom_value',
|
||||
reasoningEffort: 'medium',
|
||||
frequencyPenalty: 0.5,
|
||||
seed: 999
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40,
|
||||
frequencyPenalty: 0.5,
|
||||
seed: 999
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customParam: 'custom_value',
|
||||
reasoningEffort: 'medium'
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle complex mixed parameters with nested objects', () => {
|
||||
const customParams = {
|
||||
topP: 0.9,
|
||||
presencePenalty: 0.3,
|
||||
thinking: { type: 'enabled', budgetTokens: 5000 },
|
||||
stopSequences: ['STOP'],
|
||||
serviceTier: 'auto',
|
||||
maxOutputTokens: 4096
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
topP: 0.9,
|
||||
presencePenalty: 0.3,
|
||||
stopSequences: ['STOP'],
|
||||
maxOutputTokens: 4096
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
thinking: { type: 'enabled', budgetTokens: 5000 },
|
||||
serviceTier: 'auto'
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle all standard params with some provider params', () => {
|
||||
const customParams = {
|
||||
maxOutputTokens: 2000,
|
||||
temperature: 0.8,
|
||||
topP: 0.95,
|
||||
topK: 50,
|
||||
presencePenalty: 0.6,
|
||||
frequencyPenalty: 0.4,
|
||||
stopSequences: ['END', 'DONE'],
|
||||
seed: 777,
|
||||
customApiParam: 'value',
|
||||
anotherCustomParam: 123
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
maxOutputTokens: 2000,
|
||||
temperature: 0.8,
|
||||
topP: 0.95,
|
||||
topK: 50,
|
||||
presencePenalty: 0.6,
|
||||
frequencyPenalty: 0.4,
|
||||
stopSequences: ['END', 'DONE'],
|
||||
seed: 777
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customApiParam: 'value',
|
||||
anotherCustomParam: 123
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should handle empty object', () => {
|
||||
const customParams = {}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle zero values for numeric parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0,
|
||||
topK: 0,
|
||||
seed: 0
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0,
|
||||
topK: 0,
|
||||
seed: 0
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle negative values for numeric parameters', () => {
|
||||
const customParams = {
|
||||
presencePenalty: -0.5,
|
||||
frequencyPenalty: -0.3,
|
||||
seed: -1
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
presencePenalty: -0.5,
|
||||
frequencyPenalty: -0.3,
|
||||
seed: -1
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle empty arrays for stopSequences', () => {
|
||||
const customParams = {
|
||||
stopSequences: []
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
stopSequences: []
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle null values in mixed parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
customNull: null,
|
||||
topK: 40
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customNull: null
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle undefined values in mixed parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
customUndefined: undefined,
|
||||
topK: 40
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customUndefined: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle boolean values for standard parameters', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
customBoolean: false,
|
||||
topK: 40
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
customBoolean: false
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle very large numeric values', () => {
|
||||
const customParams = {
|
||||
maxOutputTokens: 999999,
|
||||
seed: 2147483647,
|
||||
topK: 10000
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
maxOutputTokens: 999999,
|
||||
seed: 2147483647,
|
||||
topK: 10000
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
|
||||
it('should handle decimal values with high precision', () => {
|
||||
const customParams = {
|
||||
temperature: 0.123456789,
|
||||
topP: 0.987654321,
|
||||
presencePenalty: 0.111111111
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.123456789,
|
||||
topP: 0.987654321,
|
||||
presencePenalty: 0.111111111
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Case sensitivity', () => {
|
||||
it('should NOT extract parameters with incorrect case - uppercase first letter', () => {
|
||||
const customParams = {
|
||||
Temperature: 0.7,
|
||||
TopK: 40,
|
||||
FrequencyPenalty: 0.5
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
Temperature: 0.7,
|
||||
TopK: 40,
|
||||
FrequencyPenalty: 0.5
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT extract parameters with incorrect case - all uppercase', () => {
|
||||
const customParams = {
|
||||
TEMPERATURE: 0.7,
|
||||
TOPK: 40,
|
||||
SEED: 42
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
TEMPERATURE: 0.7,
|
||||
TOPK: 40,
|
||||
SEED: 42
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT extract parameters with incorrect case - all lowercase', () => {
|
||||
const customParams = {
|
||||
maxoutputtokens: 1000,
|
||||
frequencypenalty: 0.5,
|
||||
stopsequences: ['STOP']
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
maxoutputtokens: 1000,
|
||||
frequencypenalty: 0.5,
|
||||
stopsequences: ['STOP']
|
||||
})
|
||||
})
|
||||
|
||||
it('should correctly extract exact case match while rejecting incorrect case', () => {
|
||||
const customParams = {
|
||||
temperature: 0.7,
|
||||
Temperature: 0.8,
|
||||
TEMPERATURE: 0.9,
|
||||
topK: 40,
|
||||
TopK: 50
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
temperature: 0.7,
|
||||
topK: 40
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
Temperature: 0.8,
|
||||
TEMPERATURE: 0.9,
|
||||
TopK: 50
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Parameter name variations', () => {
|
||||
it('should NOT extract similar but incorrect parameter names', () => {
|
||||
const customParams = {
|
||||
temp: 0.7, // should not match temperature
|
||||
top_k: 40, // should not match topK
|
||||
max_tokens: 1000, // should not match maxOutputTokens
|
||||
freq_penalty: 0.5 // should not match frequencyPenalty
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
temp: 0.7,
|
||||
top_k: 40,
|
||||
max_tokens: 1000,
|
||||
freq_penalty: 0.5
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT extract snake_case versions of standard parameters', () => {
|
||||
const customParams = {
|
||||
top_k: 40,
|
||||
top_p: 0.9,
|
||||
presence_penalty: 0.5,
|
||||
frequency_penalty: 0.3,
|
||||
stop_sequences: ['STOP'],
|
||||
max_output_tokens: 1000
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
top_k: 40,
|
||||
top_p: 0.9,
|
||||
presence_penalty: 0.5,
|
||||
frequency_penalty: 0.3,
|
||||
stop_sequences: ['STOP'],
|
||||
max_output_tokens: 1000
|
||||
})
|
||||
})
|
||||
|
||||
it('should extract exact camelCase parameters only', () => {
|
||||
const customParams = {
|
||||
topK: 40, // correct
|
||||
top_k: 50, // incorrect
|
||||
topP: 0.9, // correct
|
||||
top_p: 0.8, // incorrect
|
||||
frequencyPenalty: 0.5, // correct
|
||||
frequency_penalty: 0.4 // incorrect
|
||||
}
|
||||
|
||||
const result = extractAiSdkStandardParams(customParams)
|
||||
|
||||
expect(result.standardParams).toStrictEqual({
|
||||
topK: 40,
|
||||
topP: 0.9,
|
||||
frequencyPenalty: 0.5
|
||||
})
|
||||
expect(result.providerParams).toStrictEqual({
|
||||
top_k: 50,
|
||||
top_p: 0.8,
|
||||
frequency_penalty: 0.4
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -128,7 +128,20 @@ vi.mock('../reasoning', () => ({
|
||||
reasoningConfig: { type: 'enabled', budgetTokens: 5000 }
|
||||
})),
|
||||
getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })),
|
||||
getCustomParameters: vi.fn(() => ({}))
|
||||
getCustomParameters: vi.fn(() => ({})),
|
||||
extractAiSdkStandardParams: vi.fn((customParams: Record<string, any>) => {
|
||||
const AI_SDK_STANDARD_PARAMS = ['topK', 'frequencyPenalty', 'presencePenalty', 'stopSequences', 'seed']
|
||||
const standardParams: Record<string, any> = {}
|
||||
const providerParams: Record<string, any> = {}
|
||||
for (const [key, value] of Object.entries(customParams)) {
|
||||
if (AI_SDK_STANDARD_PARAMS.includes(key)) {
|
||||
standardParams[key] = value
|
||||
} else {
|
||||
providerParams[key] = value
|
||||
}
|
||||
}
|
||||
return { standardParams, providerParams }
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('../image', () => ({
|
||||
@ -184,8 +197,9 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('openai')
|
||||
expect(result.openai).toBeDefined()
|
||||
expect(result.providerOptions).toHaveProperty('openai')
|
||||
expect(result.providerOptions.openai).toBeDefined()
|
||||
expect(result.standardParams).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
@ -195,8 +209,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.openai).toHaveProperty('reasoningEffort')
|
||||
expect(result.openai.reasoningEffort).toBe('medium')
|
||||
expect(result.providerOptions.openai).toHaveProperty('reasoningEffort')
|
||||
expect(result.providerOptions.openai.reasoningEffort).toBe('medium')
|
||||
})
|
||||
|
||||
it('should include service tier when supported', () => {
|
||||
@ -211,8 +225,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.openai).toHaveProperty('serviceTier')
|
||||
expect(result.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
|
||||
expect(result.providerOptions.openai).toHaveProperty('serviceTier')
|
||||
expect(result.providerOptions.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
|
||||
})
|
||||
})
|
||||
|
||||
@ -239,8 +253,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('anthropic')
|
||||
expect(result.anthropic).toBeDefined()
|
||||
expect(result.providerOptions).toHaveProperty('anthropic')
|
||||
expect(result.providerOptions.anthropic).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
@ -250,8 +264,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.anthropic).toHaveProperty('thinking')
|
||||
expect(result.anthropic.thinking).toEqual({
|
||||
expect(result.providerOptions.anthropic).toHaveProperty('thinking')
|
||||
expect(result.providerOptions.anthropic.thinking).toEqual({
|
||||
type: 'enabled',
|
||||
budgetTokens: 5000
|
||||
})
|
||||
@ -282,8 +296,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('google')
|
||||
expect(result.google).toBeDefined()
|
||||
expect(result.providerOptions).toHaveProperty('google')
|
||||
expect(result.providerOptions.google).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
@ -293,8 +307,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.google).toHaveProperty('thinkingConfig')
|
||||
expect(result.google.thinkingConfig).toEqual({
|
||||
expect(result.providerOptions.google).toHaveProperty('thinkingConfig')
|
||||
expect(result.providerOptions.google.thinkingConfig).toEqual({
|
||||
include_thoughts: true
|
||||
})
|
||||
})
|
||||
@ -306,8 +320,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: true
|
||||
})
|
||||
|
||||
expect(result.google).toHaveProperty('responseModalities')
|
||||
expect(result.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
|
||||
expect(result.providerOptions.google).toHaveProperty('responseModalities')
|
||||
expect(result.providerOptions.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
|
||||
})
|
||||
})
|
||||
|
||||
@ -335,8 +349,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('xai')
|
||||
expect(result.xai).toBeDefined()
|
||||
expect(result.providerOptions).toHaveProperty('xai')
|
||||
expect(result.providerOptions.xai).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
@ -346,8 +360,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.xai).toHaveProperty('reasoningEffort')
|
||||
expect(result.xai.reasoningEffort).toBe('high')
|
||||
expect(result.providerOptions.xai).toHaveProperty('reasoningEffort')
|
||||
expect(result.providerOptions.xai.reasoningEffort).toBe('high')
|
||||
})
|
||||
})
|
||||
|
||||
@ -374,8 +388,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('deepseek')
|
||||
expect(result.deepseek).toBeDefined()
|
||||
expect(result.providerOptions).toHaveProperty('deepseek')
|
||||
expect(result.providerOptions.deepseek).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
@ -402,8 +416,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('openrouter')
|
||||
expect(result.openrouter).toBeDefined()
|
||||
expect(result.providerOptions).toHaveProperty('openrouter')
|
||||
expect(result.providerOptions.openrouter).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include web search parameters when enabled', () => {
|
||||
@ -413,12 +427,12 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.openrouter).toHaveProperty('enable_search')
|
||||
expect(result.providerOptions.openrouter).toHaveProperty('enable_search')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Custom parameters', () => {
|
||||
it('should merge custom parameters', async () => {
|
||||
it('should merge custom provider-specific parameters', async () => {
|
||||
const { getCustomParameters } = await import('../reasoning')
|
||||
|
||||
vi.mocked(getCustomParameters).mockReturnValue({
|
||||
@ -443,10 +457,88 @@ describe('options utils', () => {
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.openai).toHaveProperty('custom_param')
|
||||
expect(result.openai.custom_param).toBe('custom_value')
|
||||
expect(result.openai).toHaveProperty('another_param')
|
||||
expect(result.openai.another_param).toBe(123)
|
||||
expect(result.providerOptions.openai).toHaveProperty('custom_param')
|
||||
expect(result.providerOptions.openai.custom_param).toBe('custom_value')
|
||||
expect(result.providerOptions.openai).toHaveProperty('another_param')
|
||||
expect(result.providerOptions.openai.another_param).toBe(123)
|
||||
})
|
||||
|
||||
it('should extract AI SDK standard params from custom parameters', async () => {
|
||||
const { getCustomParameters } = await import('../reasoning')
|
||||
|
||||
vi.mocked(getCustomParameters).mockReturnValue({
|
||||
topK: 5,
|
||||
frequencyPenalty: 0.5,
|
||||
presencePenalty: 0.3,
|
||||
seed: 42,
|
||||
custom_param: 'custom_value'
|
||||
})
|
||||
|
||||
const result = buildProviderOptions(
|
||||
mockAssistant,
|
||||
mockModel,
|
||||
{
|
||||
id: SystemProviderIds.gemini,
|
||||
name: 'Google',
|
||||
type: 'gemini',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://generativelanguage.googleapis.com'
|
||||
} as Provider,
|
||||
{
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
}
|
||||
)
|
||||
|
||||
// Standard params should be extracted and returned separately
|
||||
expect(result.standardParams).toEqual({
|
||||
topK: 5,
|
||||
frequencyPenalty: 0.5,
|
||||
presencePenalty: 0.3,
|
||||
seed: 42
|
||||
})
|
||||
|
||||
// Provider-specific params should still be in providerOptions
|
||||
expect(result.providerOptions.google).toHaveProperty('custom_param')
|
||||
expect(result.providerOptions.google.custom_param).toBe('custom_value')
|
||||
|
||||
// Standard params should NOT be in providerOptions
|
||||
expect(result.providerOptions.google).not.toHaveProperty('topK')
|
||||
expect(result.providerOptions.google).not.toHaveProperty('frequencyPenalty')
|
||||
expect(result.providerOptions.google).not.toHaveProperty('presencePenalty')
|
||||
expect(result.providerOptions.google).not.toHaveProperty('seed')
|
||||
})
|
||||
|
||||
it('should handle stopSequences in custom parameters', async () => {
|
||||
const { getCustomParameters } = await import('../reasoning')
|
||||
|
||||
vi.mocked(getCustomParameters).mockReturnValue({
|
||||
stopSequences: ['STOP', 'END'],
|
||||
custom_param: 'value'
|
||||
})
|
||||
|
||||
const result = buildProviderOptions(
|
||||
mockAssistant,
|
||||
mockModel,
|
||||
{
|
||||
id: SystemProviderIds.gemini,
|
||||
name: 'Google',
|
||||
type: 'gemini',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://generativelanguage.googleapis.com'
|
||||
} as Provider,
|
||||
{
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.standardParams).toEqual({
|
||||
stopSequences: ['STOP', 'END']
|
||||
})
|
||||
expect(result.providerOptions.google).not.toHaveProperty('stopSequences')
|
||||
})
|
||||
})
|
||||
|
||||
@ -474,8 +566,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: true
|
||||
})
|
||||
|
||||
expect(result.google).toHaveProperty('thinkingConfig')
|
||||
expect(result.google).toHaveProperty('responseModalities')
|
||||
expect(result.providerOptions.google).toHaveProperty('thinkingConfig')
|
||||
expect(result.providerOptions.google).toHaveProperty('responseModalities')
|
||||
})
|
||||
|
||||
it('should handle all capabilities enabled', () => {
|
||||
@ -485,8 +577,8 @@ describe('options utils', () => {
|
||||
enableGenerateImage: true
|
||||
})
|
||||
|
||||
expect(result.google).toBeDefined()
|
||||
expect(Object.keys(result.google).length).toBeGreaterThan(0)
|
||||
expect(result.providerOptions.google).toBeDefined()
|
||||
expect(Object.keys(result.providerOptions.google).length).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
|
||||
@ -513,7 +605,7 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('google')
|
||||
expect(result.providerOptions).toHaveProperty('google')
|
||||
})
|
||||
|
||||
it('should map google-vertex-anthropic to anthropic', () => {
|
||||
@ -538,7 +630,7 @@ describe('options utils', () => {
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('anthropic')
|
||||
expect(result.providerOptions).toHaveProperty('anthropic')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -31,7 +31,7 @@ import {
|
||||
type Provider,
|
||||
type ServiceTier
|
||||
} from '@renderer/types'
|
||||
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||
import { type AiSdkParam, isAiSdkParam, type OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||
import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@renderer/utils/provider'
|
||||
import type { JSONValue } from 'ai'
|
||||
import { t } from 'i18next'
|
||||
@ -96,10 +96,39 @@ function getVerbosity(): OpenAIVerbosity {
|
||||
return openAI.verbosity
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract AI SDK standard parameters from custom parameters
|
||||
* These parameters should be passed directly to streamText() instead of providerOptions
|
||||
*/
|
||||
export function extractAiSdkStandardParams(customParams: Record<string, any>): {
|
||||
standardParams: Partial<Record<AiSdkParam, any>>
|
||||
providerParams: Record<string, any>
|
||||
} {
|
||||
const standardParams: Partial<Record<AiSdkParam, any>> = {}
|
||||
const providerParams: Record<string, any> = {}
|
||||
|
||||
for (const [key, value] of Object.entries(customParams)) {
|
||||
if (isAiSdkParam(key)) {
|
||||
standardParams[key] = value
|
||||
} else {
|
||||
providerParams[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return { standardParams, providerParams }
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 的 providerOptions
|
||||
* 按 provider 类型分离,保持类型安全
|
||||
* 返回格式:{ 'providerId': providerOptions }
|
||||
* 返回格式:{
|
||||
* providerOptions: { 'providerId': providerOptions },
|
||||
* standardParams: { topK, frequencyPenalty, presencePenalty, stopSequences, seed }
|
||||
* }
|
||||
*
|
||||
* Custom parameters are split into two categories:
|
||||
* 1. AI SDK standard parameters (topK, frequencyPenalty, etc.) - returned separately to be passed to streamText()
|
||||
* 2. Provider-specific parameters - merged into providerOptions
|
||||
*/
|
||||
export function buildProviderOptions(
|
||||
assistant: Assistant,
|
||||
@ -110,7 +139,10 @@ export function buildProviderOptions(
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, Record<string, JSONValue>> {
|
||||
): {
|
||||
providerOptions: Record<string, Record<string, JSONValue>>
|
||||
standardParams: Partial<Record<AiSdkParam, any>>
|
||||
} {
|
||||
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
|
||||
const rawProviderId = getAiSdkProviderId(actualProvider)
|
||||
// 构建 provider 特定的选项
|
||||
@ -202,10 +234,14 @@ export function buildProviderOptions(
|
||||
}
|
||||
}
|
||||
|
||||
// 合并自定义参数到 provider 特定的选项中
|
||||
// 获取自定义参数并分离标准参数和 provider 特定参数
|
||||
const customParams = getCustomParameters(assistant)
|
||||
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
|
||||
|
||||
// 合并 provider 特定的自定义参数到 providerSpecificOptions
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
...getCustomParameters(assistant)
|
||||
...providerParams
|
||||
}
|
||||
|
||||
let rawProviderKey =
|
||||
@ -220,9 +256,12 @@ export function buildProviderOptions(
|
||||
rawProviderKey = { gemini: 'google' }[actualProvider.type] || actualProvider.type
|
||||
}
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数
|
||||
return {
|
||||
[rawProviderKey]: providerSpecificOptions
|
||||
providerOptions: {
|
||||
[rawProviderKey]: providerSpecificOptions
|
||||
},
|
||||
standardParams
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ import type OpenAI from '@cherrystudio/openai'
|
||||
import type { NotNull, NotUndefined } from '@types'
|
||||
import type { ImageModel, LanguageModel } from 'ai'
|
||||
import type { generateObject, generateText, ModelMessage, streamObject, streamText } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model' | 'messages'> &
|
||||
(
|
||||
@ -42,3 +43,20 @@ export type OpenAIReasoningEffort = OpenAI.ReasoningEffort
|
||||
// I pick undefined as the unique falsy type since they seem like share the same meaning according to OpenAI API docs.
|
||||
// Parameter would not be passed into request if it's undefined.
|
||||
export type OpenAISummaryText = NotNull<OpenAI.Reasoning['summary']>
|
||||
|
||||
const AiSdkParamsSchema = z.enum([
|
||||
'maxOutputTokens',
|
||||
'temperature',
|
||||
'topP',
|
||||
'topK',
|
||||
'presencePenalty',
|
||||
'frequencyPenalty',
|
||||
'stopSequences',
|
||||
'seed'
|
||||
])
|
||||
|
||||
export type AiSdkParam = z.infer<typeof AiSdkParamsSchema>
|
||||
|
||||
export const isAiSdkParam = (param: string): param is AiSdkParam => {
|
||||
return AiSdkParamsSchema.safeParse(param).success
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user