From c1f4b5b9b9f8171fd8d05efeb5002b171ae65609 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 26 Nov 2025 13:16:58 +0800 Subject: [PATCH] Fix: custom parameters for Gemini models (#11456) * Initial plan * fix(aiCore): extract AI SDK standard params from custom params for Gemini Custom parameters like topK, frequencyPenalty, presencePenalty, stopSequences, and seed should be passed as top-level streamText() parameters, not in providerOptions. This fixes the issue where these parameters were being ignored by the AI SDK's @ai-sdk/google module. Changes: - Add extractAiSdkStandardParams function to separate standard params - Update buildProviderOptions to return both providerOptions and standardParams - Update buildStreamTextParams to spread standardParams into params object - Update tests to reflect new return structure Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com> * refactor(aiCore): remove extractAiSdkStandardParams function and its tests, streamline parameter extraction logic * chore: type --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com> Co-authored-by: suyao --- .../aiCore/prepareParams/parameterBuilder.ts | 7 +- .../extractAiSdkStandardParams.test.ts | 652 ++++++++++++++++++ .../aiCore/utils/__tests__/options.test.ts | 166 ++++- src/renderer/src/aiCore/utils/options.ts | 53 +- src/renderer/src/types/aiCoreTypes.ts | 18 + 5 files changed, 851 insertions(+), 45 deletions(-) create mode 100644 src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts index c9a9d20b3c..dda3bd0b47 100644 --- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts +++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts @@ -106,7 +106,7 @@ export async function buildStreamTextParams( searchWithTime: store.getState().websearch.searchWithTime } - const providerOptions = buildProviderOptions(assistant, model, provider, { + const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, { enableReasoning, enableWebSearch, enableGenerateImage @@ -181,11 +181,16 @@ export async function buildStreamTextParams( } // 构建基础参数 + // Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed) + // are extracted from custom parameters and passed directly to streamText() + // instead of being placed in providerOptions const params: StreamTextParams = { messages: sdkMessages, maxOutputTokens: getMaxTokens(assistant, model), temperature: getTemperature(assistant, model), topP: getTopP(assistant, model), + // Include AI SDK standard params extracted from custom parameters + ...standardParams, abortSignal: options.requestOptions?.signal, headers, providerOptions, diff --git a/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts b/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts new file mode 100644 index 0000000000..288cc2e4a5 --- /dev/null +++ b/src/renderer/src/aiCore/utils/__tests__/extractAiSdkStandardParams.test.ts @@ -0,0 +1,652 @@ +/** + * extractAiSdkStandardParams Unit Tests + * Tests for extracting AI SDK standard parameters from custom parameters + */ + +import { describe, expect, it, vi } from 'vitest' + +import { extractAiSdkStandardParams } from '../options' + +// Mock logger to prevent errors +vi.mock('@logger', () => ({ + loggerService: { + withContext: () => ({ + debug: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + info: vi.fn() + }) + } +})) + +// Mock settings store +vi.mock('@renderer/store/settings', () => ({ + default: (state = { settings: {} }) => state +})) + +// Mock hooks to prevent uuid errors +vi.mock('@renderer/hooks/useSettings', () => ({ + getStoreSetting: vi.fn(() => ({})) +})) + +// Mock uuid to prevent errors +vi.mock('uuid', () => ({ + v4: vi.fn(() => 'test-uuid') +})) + +// Mock AssistantService to prevent uuid errors +vi.mock('@renderer/services/AssistantService', () => ({ + getDefaultAssistant: vi.fn(() => ({ + id: 'test-assistant', + name: 'Test Assistant', + settings: {} + })), + getDefaultTopic: vi.fn(() => ({ + id: 'test-topic', + assistantId: 'test-assistant', + createdAt: new Date().toISOString() + })) +})) + +// Mock provider service +vi.mock('@renderer/services/ProviderService', () => ({ + getProviderById: vi.fn(() => ({ + id: 'test-provider', + name: 'Test Provider' + })) +})) + +// Mock config modules +vi.mock('@renderer/config/models', () => ({ + isOpenAIModel: vi.fn(() => false), + isQwenMTModel: vi.fn(() => false), + isSupportFlexServiceTierModel: vi.fn(() => false), + isSupportVerbosityModel: vi.fn(() => false), + getModelSupportedVerbosity: vi.fn(() => []) +})) + +vi.mock('@renderer/config/translate', () => ({ + mapLanguageToQwenMTModel: vi.fn() +})) + +vi.mock('@renderer/utils/provider', () => ({ + isSupportServiceTierProvider: vi.fn(() => false), + isSupportVerbosityProvider: vi.fn(() => false) +})) + +describe('extractAiSdkStandardParams', () => { + describe('Positive cases - Standard parameters extraction', () => { + it('should extract all AI SDK standard parameters', () => { + const customParams = { + maxOutputTokens: 1000, + temperature: 0.7, + topP: 0.9, + topK: 40, + presencePenalty: 0.5, + frequencyPenalty: 0.3, + stopSequences: ['STOP', 'END'], + seed: 42 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + maxOutputTokens: 1000, + temperature: 0.7, + topP: 0.9, + topK: 40, + presencePenalty: 0.5, + frequencyPenalty: 0.3, + stopSequences: ['STOP', 'END'], + seed: 42 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should extract single standard parameter', () => { + const customParams = { + temperature: 0.8 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + temperature: 0.8 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should extract topK parameter', () => { + const customParams = { + topK: 50 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + topK: 50 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should extract frequencyPenalty parameter', () => { + const customParams = { + frequencyPenalty: 0.6 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + frequencyPenalty: 0.6 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should extract presencePenalty parameter', () => { + const customParams = { + presencePenalty: 0.4 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + presencePenalty: 0.4 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should extract stopSequences parameter', () => { + const customParams = { + stopSequences: ['HALT', 'TERMINATE'] + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + stopSequences: ['HALT', 'TERMINATE'] + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should extract seed parameter', () => { + const customParams = { + seed: 12345 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + seed: 12345 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should extract maxOutputTokens parameter', () => { + const customParams = { + maxOutputTokens: 2048 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + maxOutputTokens: 2048 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should extract topP parameter', () => { + const customParams = { + topP: 0.95 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + topP: 0.95 + }) + expect(result.providerParams).toStrictEqual({}) + }) + }) + + describe('Negative cases - Provider-specific parameters', () => { + it('should place all non-standard parameters in providerParams', () => { + const customParams = { + customParam: 'value', + anotherParam: 123, + thirdParam: true + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + customParam: 'value', + anotherParam: 123, + thirdParam: true + }) + }) + + it('should place single provider-specific parameter in providerParams', () => { + const customParams = { + reasoningEffort: 'high' + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + reasoningEffort: 'high' + }) + }) + + it('should place model-specific parameter in providerParams', () => { + const customParams = { + thinking: { type: 'enabled', budgetTokens: 5000 } + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + thinking: { type: 'enabled', budgetTokens: 5000 } + }) + }) + + it('should place serviceTier in providerParams', () => { + const customParams = { + serviceTier: 'auto' + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + serviceTier: 'auto' + }) + }) + + it('should place textVerbosity in providerParams', () => { + const customParams = { + textVerbosity: 'high' + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + textVerbosity: 'high' + }) + }) + }) + + describe('Mixed parameters', () => { + it('should correctly separate mixed standard and provider-specific parameters', () => { + const customParams = { + temperature: 0.7, + topK: 40, + customParam: 'custom_value', + reasoningEffort: 'medium', + frequencyPenalty: 0.5, + seed: 999 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + temperature: 0.7, + topK: 40, + frequencyPenalty: 0.5, + seed: 999 + }) + expect(result.providerParams).toStrictEqual({ + customParam: 'custom_value', + reasoningEffort: 'medium' + }) + }) + + it('should handle complex mixed parameters with nested objects', () => { + const customParams = { + topP: 0.9, + presencePenalty: 0.3, + thinking: { type: 'enabled', budgetTokens: 5000 }, + stopSequences: ['STOP'], + serviceTier: 'auto', + maxOutputTokens: 4096 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + topP: 0.9, + presencePenalty: 0.3, + stopSequences: ['STOP'], + maxOutputTokens: 4096 + }) + expect(result.providerParams).toStrictEqual({ + thinking: { type: 'enabled', budgetTokens: 5000 }, + serviceTier: 'auto' + }) + }) + + it('should handle all standard params with some provider params', () => { + const customParams = { + maxOutputTokens: 2000, + temperature: 0.8, + topP: 0.95, + topK: 50, + presencePenalty: 0.6, + frequencyPenalty: 0.4, + stopSequences: ['END', 'DONE'], + seed: 777, + customApiParam: 'value', + anotherCustomParam: 123 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + maxOutputTokens: 2000, + temperature: 0.8, + topP: 0.95, + topK: 50, + presencePenalty: 0.6, + frequencyPenalty: 0.4, + stopSequences: ['END', 'DONE'], + seed: 777 + }) + expect(result.providerParams).toStrictEqual({ + customApiParam: 'value', + anotherCustomParam: 123 + }) + }) + }) + + describe('Edge cases', () => { + it('should handle empty object', () => { + const customParams = {} + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should handle zero values for numeric parameters', () => { + const customParams = { + temperature: 0, + topK: 0, + seed: 0 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + temperature: 0, + topK: 0, + seed: 0 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should handle negative values for numeric parameters', () => { + const customParams = { + presencePenalty: -0.5, + frequencyPenalty: -0.3, + seed: -1 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + presencePenalty: -0.5, + frequencyPenalty: -0.3, + seed: -1 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should handle empty arrays for stopSequences', () => { + const customParams = { + stopSequences: [] + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + stopSequences: [] + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should handle null values in mixed parameters', () => { + const customParams = { + temperature: 0.7, + customNull: null, + topK: 40 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + temperature: 0.7, + topK: 40 + }) + expect(result.providerParams).toStrictEqual({ + customNull: null + }) + }) + + it('should handle undefined values in mixed parameters', () => { + const customParams = { + temperature: 0.7, + customUndefined: undefined, + topK: 40 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + temperature: 0.7, + topK: 40 + }) + expect(result.providerParams).toStrictEqual({ + customUndefined: undefined + }) + }) + + it('should handle boolean values for standard parameters', () => { + const customParams = { + temperature: 0.7, + customBoolean: false, + topK: 40 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + temperature: 0.7, + topK: 40 + }) + expect(result.providerParams).toStrictEqual({ + customBoolean: false + }) + }) + + it('should handle very large numeric values', () => { + const customParams = { + maxOutputTokens: 999999, + seed: 2147483647, + topK: 10000 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + maxOutputTokens: 999999, + seed: 2147483647, + topK: 10000 + }) + expect(result.providerParams).toStrictEqual({}) + }) + + it('should handle decimal values with high precision', () => { + const customParams = { + temperature: 0.123456789, + topP: 0.987654321, + presencePenalty: 0.111111111 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + temperature: 0.123456789, + topP: 0.987654321, + presencePenalty: 0.111111111 + }) + expect(result.providerParams).toStrictEqual({}) + }) + }) + + describe('Case sensitivity', () => { + it('should NOT extract parameters with incorrect case - uppercase first letter', () => { + const customParams = { + Temperature: 0.7, + TopK: 40, + FrequencyPenalty: 0.5 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + Temperature: 0.7, + TopK: 40, + FrequencyPenalty: 0.5 + }) + }) + + it('should NOT extract parameters with incorrect case - all uppercase', () => { + const customParams = { + TEMPERATURE: 0.7, + TOPK: 40, + SEED: 42 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + TEMPERATURE: 0.7, + TOPK: 40, + SEED: 42 + }) + }) + + it('should NOT extract parameters with incorrect case - all lowercase', () => { + const customParams = { + maxoutputtokens: 1000, + frequencypenalty: 0.5, + stopsequences: ['STOP'] + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + maxoutputtokens: 1000, + frequencypenalty: 0.5, + stopsequences: ['STOP'] + }) + }) + + it('should correctly extract exact case match while rejecting incorrect case', () => { + const customParams = { + temperature: 0.7, + Temperature: 0.8, + TEMPERATURE: 0.9, + topK: 40, + TopK: 50 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + temperature: 0.7, + topK: 40 + }) + expect(result.providerParams).toStrictEqual({ + Temperature: 0.8, + TEMPERATURE: 0.9, + TopK: 50 + }) + }) + }) + + describe('Parameter name variations', () => { + it('should NOT extract similar but incorrect parameter names', () => { + const customParams = { + temp: 0.7, // should not match temperature + top_k: 40, // should not match topK + max_tokens: 1000, // should not match maxOutputTokens + freq_penalty: 0.5 // should not match frequencyPenalty + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + temp: 0.7, + top_k: 40, + max_tokens: 1000, + freq_penalty: 0.5 + }) + }) + + it('should NOT extract snake_case versions of standard parameters', () => { + const customParams = { + top_k: 40, + top_p: 0.9, + presence_penalty: 0.5, + frequency_penalty: 0.3, + stop_sequences: ['STOP'], + max_output_tokens: 1000 + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({}) + expect(result.providerParams).toStrictEqual({ + top_k: 40, + top_p: 0.9, + presence_penalty: 0.5, + frequency_penalty: 0.3, + stop_sequences: ['STOP'], + max_output_tokens: 1000 + }) + }) + + it('should extract exact camelCase parameters only', () => { + const customParams = { + topK: 40, // correct + top_k: 50, // incorrect + topP: 0.9, // correct + top_p: 0.8, // incorrect + frequencyPenalty: 0.5, // correct + frequency_penalty: 0.4 // incorrect + } + + const result = extractAiSdkStandardParams(customParams) + + expect(result.standardParams).toStrictEqual({ + topK: 40, + topP: 0.9, + frequencyPenalty: 0.5 + }) + expect(result.providerParams).toStrictEqual({ + top_k: 50, + top_p: 0.8, + frequency_penalty: 0.4 + }) + }) + }) +}) diff --git a/src/renderer/src/aiCore/utils/__tests__/options.test.ts b/src/renderer/src/aiCore/utils/__tests__/options.test.ts index 4bf8447d65..8f2629f4d8 100644 --- a/src/renderer/src/aiCore/utils/__tests__/options.test.ts +++ b/src/renderer/src/aiCore/utils/__tests__/options.test.ts @@ -128,7 +128,20 @@ vi.mock('../reasoning', () => ({ reasoningConfig: { type: 'enabled', budgetTokens: 5000 } })), getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })), - getCustomParameters: vi.fn(() => ({})) + getCustomParameters: vi.fn(() => ({})), + extractAiSdkStandardParams: vi.fn((customParams: Record) => { + const AI_SDK_STANDARD_PARAMS = ['topK', 'frequencyPenalty', 'presencePenalty', 'stopSequences', 'seed'] + const standardParams: Record = {} + const providerParams: Record = {} + for (const [key, value] of Object.entries(customParams)) { + if (AI_SDK_STANDARD_PARAMS.includes(key)) { + standardParams[key] = value + } else { + providerParams[key] = value + } + } + return { standardParams, providerParams } + }) })) vi.mock('../image', () => ({ @@ -184,8 +197,9 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result).toHaveProperty('openai') - expect(result.openai).toBeDefined() + expect(result.providerOptions).toHaveProperty('openai') + expect(result.providerOptions.openai).toBeDefined() + expect(result.standardParams).toBeDefined() }) it('should include reasoning parameters when enabled', () => { @@ -195,8 +209,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result.openai).toHaveProperty('reasoningEffort') - expect(result.openai.reasoningEffort).toBe('medium') + expect(result.providerOptions.openai).toHaveProperty('reasoningEffort') + expect(result.providerOptions.openai.reasoningEffort).toBe('medium') }) it('should include service tier when supported', () => { @@ -211,8 +225,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result.openai).toHaveProperty('serviceTier') - expect(result.openai.serviceTier).toBe(OpenAIServiceTiers.auto) + expect(result.providerOptions.openai).toHaveProperty('serviceTier') + expect(result.providerOptions.openai.serviceTier).toBe(OpenAIServiceTiers.auto) }) }) @@ -239,8 +253,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result).toHaveProperty('anthropic') - expect(result.anthropic).toBeDefined() + expect(result.providerOptions).toHaveProperty('anthropic') + expect(result.providerOptions.anthropic).toBeDefined() }) it('should include reasoning parameters when enabled', () => { @@ -250,8 +264,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result.anthropic).toHaveProperty('thinking') - expect(result.anthropic.thinking).toEqual({ + expect(result.providerOptions.anthropic).toHaveProperty('thinking') + expect(result.providerOptions.anthropic.thinking).toEqual({ type: 'enabled', budgetTokens: 5000 }) @@ -282,8 +296,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result).toHaveProperty('google') - expect(result.google).toBeDefined() + expect(result.providerOptions).toHaveProperty('google') + expect(result.providerOptions.google).toBeDefined() }) it('should include reasoning parameters when enabled', () => { @@ -293,8 +307,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result.google).toHaveProperty('thinkingConfig') - expect(result.google.thinkingConfig).toEqual({ + expect(result.providerOptions.google).toHaveProperty('thinkingConfig') + expect(result.providerOptions.google.thinkingConfig).toEqual({ include_thoughts: true }) }) @@ -306,8 +320,8 @@ describe('options utils', () => { enableGenerateImage: true }) - expect(result.google).toHaveProperty('responseModalities') - expect(result.google.responseModalities).toEqual(['TEXT', 'IMAGE']) + expect(result.providerOptions.google).toHaveProperty('responseModalities') + expect(result.providerOptions.google.responseModalities).toEqual(['TEXT', 'IMAGE']) }) }) @@ -335,8 +349,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result).toHaveProperty('xai') - expect(result.xai).toBeDefined() + expect(result.providerOptions).toHaveProperty('xai') + expect(result.providerOptions.xai).toBeDefined() }) it('should include reasoning parameters when enabled', () => { @@ -346,8 +360,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result.xai).toHaveProperty('reasoningEffort') - expect(result.xai.reasoningEffort).toBe('high') + expect(result.providerOptions.xai).toHaveProperty('reasoningEffort') + expect(result.providerOptions.xai.reasoningEffort).toBe('high') }) }) @@ -374,8 +388,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result).toHaveProperty('deepseek') - expect(result.deepseek).toBeDefined() + expect(result.providerOptions).toHaveProperty('deepseek') + expect(result.providerOptions.deepseek).toBeDefined() }) }) @@ -402,8 +416,8 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result).toHaveProperty('openrouter') - expect(result.openrouter).toBeDefined() + expect(result.providerOptions).toHaveProperty('openrouter') + expect(result.providerOptions.openrouter).toBeDefined() }) it('should include web search parameters when enabled', () => { @@ -413,12 +427,12 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result.openrouter).toHaveProperty('enable_search') + expect(result.providerOptions.openrouter).toHaveProperty('enable_search') }) }) describe('Custom parameters', () => { - it('should merge custom parameters', async () => { + it('should merge custom provider-specific parameters', async () => { const { getCustomParameters } = await import('../reasoning') vi.mocked(getCustomParameters).mockReturnValue({ @@ -443,10 +457,88 @@ describe('options utils', () => { } ) - expect(result.openai).toHaveProperty('custom_param') - expect(result.openai.custom_param).toBe('custom_value') - expect(result.openai).toHaveProperty('another_param') - expect(result.openai.another_param).toBe(123) + expect(result.providerOptions.openai).toHaveProperty('custom_param') + expect(result.providerOptions.openai.custom_param).toBe('custom_value') + expect(result.providerOptions.openai).toHaveProperty('another_param') + expect(result.providerOptions.openai.another_param).toBe(123) + }) + + it('should extract AI SDK standard params from custom parameters', async () => { + const { getCustomParameters } = await import('../reasoning') + + vi.mocked(getCustomParameters).mockReturnValue({ + topK: 5, + frequencyPenalty: 0.5, + presencePenalty: 0.3, + seed: 42, + custom_param: 'custom_value' + }) + + const result = buildProviderOptions( + mockAssistant, + mockModel, + { + id: SystemProviderIds.gemini, + name: 'Google', + type: 'gemini', + apiKey: 'test-key', + apiHost: 'https://generativelanguage.googleapis.com' + } as Provider, + { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + } + ) + + // Standard params should be extracted and returned separately + expect(result.standardParams).toEqual({ + topK: 5, + frequencyPenalty: 0.5, + presencePenalty: 0.3, + seed: 42 + }) + + // Provider-specific params should still be in providerOptions + expect(result.providerOptions.google).toHaveProperty('custom_param') + expect(result.providerOptions.google.custom_param).toBe('custom_value') + + // Standard params should NOT be in providerOptions + expect(result.providerOptions.google).not.toHaveProperty('topK') + expect(result.providerOptions.google).not.toHaveProperty('frequencyPenalty') + expect(result.providerOptions.google).not.toHaveProperty('presencePenalty') + expect(result.providerOptions.google).not.toHaveProperty('seed') + }) + + it('should handle stopSequences in custom parameters', async () => { + const { getCustomParameters } = await import('../reasoning') + + vi.mocked(getCustomParameters).mockReturnValue({ + stopSequences: ['STOP', 'END'], + custom_param: 'value' + }) + + const result = buildProviderOptions( + mockAssistant, + mockModel, + { + id: SystemProviderIds.gemini, + name: 'Google', + type: 'gemini', + apiKey: 'test-key', + apiHost: 'https://generativelanguage.googleapis.com' + } as Provider, + { + enableReasoning: false, + enableWebSearch: false, + enableGenerateImage: false + } + ) + + expect(result.standardParams).toEqual({ + stopSequences: ['STOP', 'END'] + }) + expect(result.providerOptions.google).not.toHaveProperty('stopSequences') }) }) @@ -474,8 +566,8 @@ describe('options utils', () => { enableGenerateImage: true }) - expect(result.google).toHaveProperty('thinkingConfig') - expect(result.google).toHaveProperty('responseModalities') + expect(result.providerOptions.google).toHaveProperty('thinkingConfig') + expect(result.providerOptions.google).toHaveProperty('responseModalities') }) it('should handle all capabilities enabled', () => { @@ -485,8 +577,8 @@ describe('options utils', () => { enableGenerateImage: true }) - expect(result.google).toBeDefined() - expect(Object.keys(result.google).length).toBeGreaterThan(0) + expect(result.providerOptions.google).toBeDefined() + expect(Object.keys(result.providerOptions.google).length).toBeGreaterThan(0) }) }) @@ -513,7 +605,7 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result).toHaveProperty('google') + expect(result.providerOptions).toHaveProperty('google') }) it('should map google-vertex-anthropic to anthropic', () => { @@ -538,7 +630,7 @@ describe('options utils', () => { enableGenerateImage: false }) - expect(result).toHaveProperty('anthropic') + expect(result.providerOptions).toHaveProperty('anthropic') }) }) }) diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index 47e7d2510c..f420908ba6 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -31,7 +31,7 @@ import { type Provider, type ServiceTier } from '@renderer/types' -import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes' +import { type AiSdkParam, isAiSdkParam, type OpenAIVerbosity } from '@renderer/types/aiCoreTypes' import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@renderer/utils/provider' import type { JSONValue } from 'ai' import { t } from 'i18next' @@ -96,10 +96,39 @@ function getVerbosity(): OpenAIVerbosity { return openAI.verbosity } +/** + * Extract AI SDK standard parameters from custom parameters + * These parameters should be passed directly to streamText() instead of providerOptions + */ +export function extractAiSdkStandardParams(customParams: Record): { + standardParams: Partial> + providerParams: Record +} { + const standardParams: Partial> = {} + const providerParams: Record = {} + + for (const [key, value] of Object.entries(customParams)) { + if (isAiSdkParam(key)) { + standardParams[key] = value + } else { + providerParams[key] = value + } + } + + return { standardParams, providerParams } +} + /** * 构建 AI SDK 的 providerOptions * 按 provider 类型分离,保持类型安全 - * 返回格式:{ 'providerId': providerOptions } + * 返回格式:{ + * providerOptions: { 'providerId': providerOptions }, + * standardParams: { topK, frequencyPenalty, presencePenalty, stopSequences, seed } + * } + * + * Custom parameters are split into two categories: + * 1. AI SDK standard parameters (topK, frequencyPenalty, etc.) - returned separately to be passed to streamText() + * 2. Provider-specific parameters - merged into providerOptions */ export function buildProviderOptions( assistant: Assistant, @@ -110,7 +139,10 @@ export function buildProviderOptions( enableWebSearch: boolean enableGenerateImage: boolean } -): Record> { +): { + providerOptions: Record> + standardParams: Partial> +} { logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities }) const rawProviderId = getAiSdkProviderId(actualProvider) // 构建 provider 特定的选项 @@ -202,10 +234,14 @@ export function buildProviderOptions( } } - // 合并自定义参数到 provider 特定的选项中 + // 获取自定义参数并分离标准参数和 provider 特定参数 + const customParams = getCustomParameters(assistant) + const { standardParams, providerParams } = extractAiSdkStandardParams(customParams) + + // 合并 provider 特定的自定义参数到 providerSpecificOptions providerSpecificOptions = { ...providerSpecificOptions, - ...getCustomParameters(assistant) + ...providerParams } let rawProviderKey = @@ -220,9 +256,12 @@ export function buildProviderOptions( rawProviderKey = { gemini: 'google' }[actualProvider.type] || actualProvider.type } - // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } + // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数 return { - [rawProviderKey]: providerSpecificOptions + providerOptions: { + [rawProviderKey]: providerSpecificOptions + }, + standardParams } } diff --git a/src/renderer/src/types/aiCoreTypes.ts b/src/renderer/src/types/aiCoreTypes.ts index 6327fe6835..2e4c09348b 100644 --- a/src/renderer/src/types/aiCoreTypes.ts +++ b/src/renderer/src/types/aiCoreTypes.ts @@ -2,6 +2,7 @@ import type OpenAI from '@cherrystudio/openai' import type { NotNull, NotUndefined } from '@types' import type { ImageModel, LanguageModel } from 'ai' import type { generateObject, generateText, ModelMessage, streamObject, streamText } from 'ai' +import * as z from 'zod' export type StreamTextParams = Omit[0], 'model' | 'messages'> & ( @@ -42,3 +43,20 @@ export type OpenAIReasoningEffort = OpenAI.ReasoningEffort // I pick undefined as the unique falsy type since they seem like share the same meaning according to OpenAI API docs. // Parameter would not be passed into request if it's undefined. export type OpenAISummaryText = NotNull + +const AiSdkParamsSchema = z.enum([ + 'maxOutputTokens', + 'temperature', + 'topP', + 'topK', + 'presencePenalty', + 'frequencyPenalty', + 'stopSequences', + 'seed' +]) + +export type AiSdkParam = z.infer + +export const isAiSdkParam = (param: string): param is AiSdkParam => { + return AiSdkParamsSchema.safeParse(param).success +}