mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-06 13:19:33 +08:00
fix: add CherryAI provider support and update API host formatting (#11135)
* fix: add CherryAI provider support and update API host formatting * format code * add ut * format code
This commit is contained in:
parent
abd5d3b96f
commit
346af4d338
@ -21,10 +21,44 @@ vi.mock('@renderer/store', () => ({
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/utils/api', () => ({
|
||||||
|
formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => {
|
||||||
|
if (isSupportedAPIVersion === false) {
|
||||||
|
return host // Return host as-is when isSupportedAPIVersion is false
|
||||||
|
}
|
||||||
|
return `${host}/v1` // Default behavior when isSupportedAPIVersion is true
|
||||||
|
}),
|
||||||
|
routeToEndpoint: vi.fn((host) => ({
|
||||||
|
baseURL: host,
|
||||||
|
endpoint: '/chat/completions'
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/config/providers', async (importOriginal) => {
|
||||||
|
const actual = (await importOriginal()) as any
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
isCherryAIProvider: vi.fn(),
|
||||||
|
isAnthropicProvider: vi.fn(() => false),
|
||||||
|
isAzureOpenAIProvider: vi.fn(() => false),
|
||||||
|
isGeminiProvider: vi.fn(() => false),
|
||||||
|
isNewApiProvider: vi.fn(() => false)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mock('@renderer/hooks/useVertexAI', () => ({
|
||||||
|
isVertexProvider: vi.fn(() => false),
|
||||||
|
isVertexAIConfigured: vi.fn(() => false),
|
||||||
|
createVertexProvider: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
import { isCherryAIProvider } from '@renderer/config/providers'
|
||||||
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { Model, Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
|
import { formatApiHost } from '@renderer/utils/api'
|
||||||
|
|
||||||
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
|
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
|
||||||
import { providerToAiSdkConfig } from '../providerConfig'
|
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'
|
||||||
|
|
||||||
const createWindowKeyv = () => {
|
const createWindowKeyv = () => {
|
||||||
const store = new Map<string, string>()
|
const store = new Map<string, string>()
|
||||||
@ -46,11 +80,21 @@ const createCopilotProvider = (): Provider => ({
|
|||||||
isSystem: true
|
isSystem: true
|
||||||
})
|
})
|
||||||
|
|
||||||
const createModel = (id: string, name = id): Model => ({
|
const createModel = (id: string, name = id, provider = 'copilot'): Model => ({
|
||||||
id,
|
id,
|
||||||
name,
|
name,
|
||||||
provider: 'copilot',
|
provider,
|
||||||
group: 'copilot'
|
group: provider
|
||||||
|
})
|
||||||
|
|
||||||
|
const createCherryAIProvider = (): Provider => ({
|
||||||
|
id: 'cherryai',
|
||||||
|
type: 'openai',
|
||||||
|
name: 'CherryAI',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://api.cherryai.com',
|
||||||
|
models: [],
|
||||||
|
isSystem: false
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('Copilot responses routing', () => {
|
describe('Copilot responses routing', () => {
|
||||||
@ -87,3 +131,67 @@ describe('Copilot responses routing', () => {
|
|||||||
expect(config.options.headers?.['Copilot-Integration-Id']).toBe(COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id'])
|
expect(config.options.headers?.['Copilot-Integration-Id']).toBe(COPILOT_DEFAULT_HEADERS['Copilot-Integration-Id'])
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('CherryAI provider configuration', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
;(globalThis as any).window = {
|
||||||
|
...(globalThis as any).window,
|
||||||
|
keyv: createWindowKeyv()
|
||||||
|
}
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('formats CherryAI provider apiHost with false parameter', () => {
|
||||||
|
const provider = createCherryAIProvider()
|
||||||
|
const model = createModel('gpt-4', 'GPT-4', 'cherryai')
|
||||||
|
|
||||||
|
// Mock the functions to simulate CherryAI provider detection
|
||||||
|
vi.mocked(isCherryAIProvider).mockReturnValue(true)
|
||||||
|
vi.mocked(getProviderByModel).mockReturnValue(provider)
|
||||||
|
|
||||||
|
// Call getActualProvider which should trigger formatProviderApiHost
|
||||||
|
const actualProvider = getActualProvider(model)
|
||||||
|
|
||||||
|
// Verify that formatApiHost was called with false as the second parameter
|
||||||
|
expect(formatApiHost).toHaveBeenCalledWith('https://api.cherryai.com', false)
|
||||||
|
expect(actualProvider.apiHost).toBe('https://api.cherryai.com')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('does not format non-CherryAI provider with false parameter', () => {
|
||||||
|
const provider = {
|
||||||
|
id: 'openai',
|
||||||
|
type: 'openai',
|
||||||
|
name: 'OpenAI',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://api.openai.com',
|
||||||
|
models: [],
|
||||||
|
isSystem: false
|
||||||
|
} as Provider
|
||||||
|
const model = createModel('gpt-4', 'GPT-4', 'openai')
|
||||||
|
|
||||||
|
// Mock the functions to simulate non-CherryAI provider
|
||||||
|
vi.mocked(isCherryAIProvider).mockReturnValue(false)
|
||||||
|
vi.mocked(getProviderByModel).mockReturnValue(provider)
|
||||||
|
|
||||||
|
// Call getActualProvider
|
||||||
|
const actualProvider = getActualProvider(model)
|
||||||
|
|
||||||
|
// Verify that formatApiHost was called with default parameters (true)
|
||||||
|
expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com')
|
||||||
|
expect(actualProvider.apiHost).toBe('https://api.openai.com/v1')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('handles CherryAI provider with empty apiHost', () => {
|
||||||
|
const provider = createCherryAIProvider()
|
||||||
|
provider.apiHost = ''
|
||||||
|
const model = createModel('gpt-4', 'GPT-4', 'cherryai')
|
||||||
|
|
||||||
|
vi.mocked(isCherryAIProvider).mockReturnValue(true)
|
||||||
|
vi.mocked(getProviderByModel).mockReturnValue(provider)
|
||||||
|
|
||||||
|
const actualProvider = getActualProvider(model)
|
||||||
|
|
||||||
|
expect(formatApiHost).toHaveBeenCalledWith('', false)
|
||||||
|
expect(actualProvider.apiHost).toBe('')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
|||||||
import {
|
import {
|
||||||
isAnthropicProvider,
|
isAnthropicProvider,
|
||||||
isAzureOpenAIProvider,
|
isAzureOpenAIProvider,
|
||||||
|
isCherryAIProvider,
|
||||||
isGeminiProvider,
|
isGeminiProvider,
|
||||||
isNewApiProvider
|
isNewApiProvider
|
||||||
} from '@renderer/config/providers'
|
} from '@renderer/config/providers'
|
||||||
@ -100,6 +101,8 @@ function formatProviderApiHost(provider: Provider): Provider {
|
|||||||
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
|
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
|
||||||
} else if (isVertexProvider(formatted)) {
|
} else if (isVertexProvider(formatted)) {
|
||||||
formatted.apiHost = formatVertexApiHost(formatted)
|
formatted.apiHost = formatVertexApiHost(formatted)
|
||||||
|
} else if (isCherryAIProvider(formatted)) {
|
||||||
|
formatted.apiHost = formatApiHost(formatted.apiHost, false)
|
||||||
} else {
|
} else {
|
||||||
formatted.apiHost = formatApiHost(formatted.apiHost)
|
formatted.apiHost = formatApiHost(formatted.apiHost)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1486,6 +1486,10 @@ export const isNewApiProvider = (provider: Provider) => {
|
|||||||
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api'
|
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function isCherryAIProvider(provider: Provider): boolean {
|
||||||
|
return provider.id === 'cherryai'
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 判断是否为 OpenAI 兼容的提供商
|
* 判断是否为 OpenAI 兼容的提供商
|
||||||
* @param {Provider} provider 提供商对象
|
* @param {Provider} provider 提供商对象
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user