From 346af4d338fbd3719a3d0af9f2c3380273fb3e47 Mon Sep 17 00:00:00 2001 From: beyondkmp Date: Tue, 4 Nov 2025 12:59:14 +0800 Subject: [PATCH] fix: add CherryAI provider support and update API host formatting (#11135) * fix: add CherryAI provider support and update API host formatting * format code * add ut * format code --- .../provider/__tests__/providerConfig.test.ts | 116 +++++++++++++++++- .../src/aiCore/provider/providerConfig.ts | 3 + src/renderer/src/config/providers.ts | 4 + 3 files changed, 119 insertions(+), 4 deletions(-) diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts index eb6e73c8ae..cc5f20c63e 100644 --- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts +++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts @@ -21,10 +21,44 @@ vi.mock('@renderer/store', () => ({ } })) +vi.mock('@renderer/utils/api', () => ({ + formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => { + if (isSupportedAPIVersion === false) { + return host // Return host as-is when isSupportedAPIVersion is false + } + return `${host}/v1` // Default behavior when isSupportedAPIVersion is true + }), + routeToEndpoint: vi.fn((host) => ({ + baseURL: host, + endpoint: '/chat/completions' + })) +})) + +vi.mock('@renderer/config/providers', async (importOriginal) => { + const actual = (await importOriginal()) as any + return { + ...actual, + isCherryAIProvider: vi.fn(), + isAnthropicProvider: vi.fn(() => false), + isAzureOpenAIProvider: vi.fn(() => false), + isGeminiProvider: vi.fn(() => false), + isNewApiProvider: vi.fn(() => false) + } +}) + +vi.mock('@renderer/hooks/useVertexAI', () => ({ + isVertexProvider: vi.fn(() => false), + isVertexAIConfigured: vi.fn(() => false), + createVertexProvider: vi.fn() +})) + +import { isCherryAIProvider } from '@renderer/config/providers' +import { getProviderByModel } from '@renderer/services/AssistantService' import type { Model, Provider } from '@renderer/types' +import { formatApiHost } from '@renderer/utils/api' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' -import { providerToAiSdkConfig } from '../providerConfig' +import { getActualProvider, providerToAiSdkConfig } from '../providerConfig' const createWindowKeyv = () => { const store = new Map() @@ -46,11 +80,21 @@ const createCopilotProvider = (): Provider => ({ isSystem: true }) -const createModel = (id: string, name = id): Model => ({ +const createModel = (id: string, name = id, provider = 'copilot'): Model => ({ id, name, - provider: 'copilot', - group: 'copilot' + provider, + group: provider +}) + +const createCherryAIProvider = (): Provider => ({ + id: 'cherryai', + type: 'openai', + name: 'CherryAI', + apiKey: 'test-key', + apiHost: 'https://api.cherryai.com', + models: [], + isSystem: false }) describe('Copilot responses routing', () => { @@ -87,3 +131,67 @@ describe('Copilot responses routing', () => { expect(config.options.headers?.['Copilot-Integration-Id']).toBe(COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id']) }) }) + +describe('CherryAI provider configuration', () => { + beforeEach(() => { + ;(globalThis as any).window = { + ...(globalThis as any).window, + keyv: createWindowKeyv() + } + vi.clearAllMocks() + }) + + it('formats CherryAI provider apiHost with false parameter', () => { + const provider = createCherryAIProvider() + const model = createModel('gpt-4', 'GPT-4', 'cherryai') + + // Mock the functions to simulate CherryAI provider detection + vi.mocked(isCherryAIProvider).mockReturnValue(true) + vi.mocked(getProviderByModel).mockReturnValue(provider) + + // Call getActualProvider which should trigger formatProviderApiHost + const actualProvider = getActualProvider(model) + + // Verify that formatApiHost was called with false as the second parameter + expect(formatApiHost).toHaveBeenCalledWith('https://api.cherryai.com', false) + expect(actualProvider.apiHost).toBe('https://api.cherryai.com') + }) + + it('does not format non-CherryAI provider with false parameter', () => { + const provider = { + id: 'openai', + type: 'openai', + name: 'OpenAI', + apiKey: 'test-key', + apiHost: 'https://api.openai.com', + models: [], + isSystem: false + } as Provider + const model = createModel('gpt-4', 'GPT-4', 'openai') + + // Mock the functions to simulate non-CherryAI provider + vi.mocked(isCherryAIProvider).mockReturnValue(false) + vi.mocked(getProviderByModel).mockReturnValue(provider) + + // Call getActualProvider + const actualProvider = getActualProvider(model) + + // Verify that formatApiHost was called with default parameters (true) + expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com') + expect(actualProvider.apiHost).toBe('https://api.openai.com/v1') + }) + + it('handles CherryAI provider with empty apiHost', () => { + const provider = createCherryAIProvider() + provider.apiHost = '' + const model = createModel('gpt-4', 'GPT-4', 'cherryai') + + vi.mocked(isCherryAIProvider).mockReturnValue(true) + vi.mocked(getProviderByModel).mockReturnValue(provider) + + const actualProvider = getActualProvider(model) + + expect(formatApiHost).toHaveBeenCalledWith('', false) + expect(actualProvider.apiHost).toBe('') + }) +}) diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index c8447671d1..4669d4c851 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -9,6 +9,7 @@ import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { isAnthropicProvider, isAzureOpenAIProvider, + isCherryAIProvider, isGeminiProvider, isNewApiProvider } from '@renderer/config/providers' @@ -100,6 +101,8 @@ function formatProviderApiHost(provider: Provider): Provider { formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost) } else if (isVertexProvider(formatted)) { formatted.apiHost = formatVertexApiHost(formatted) + } else if (isCherryAIProvider(formatted)) { + formatted.apiHost = formatApiHost(formatted.apiHost, false) } else { formatted.apiHost = formatApiHost(formatted.apiHost) } diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index 2b9fe2b21f..5fbae73dbf 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -1486,6 +1486,10 @@ export const isNewApiProvider = (provider: Provider) => { return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api' } +export function isCherryAIProvider(provider: Provider): boolean { + return provider.id === 'cherryai' +} + /** * 判断是否为 OpenAI 兼容的提供商 * @param {Provider} provider 提供商对象