diff --git a/src/renderer/src/aiCore/clients/ApiClientFactory.ts b/src/renderer/src/aiCore/clients/ApiClientFactory.ts index d815dc923f..991bdccf34 100644 --- a/src/renderer/src/aiCore/clients/ApiClientFactory.ts +++ b/src/renderer/src/aiCore/clients/ApiClientFactory.ts @@ -72,6 +72,7 @@ export class ApiClientFactory { } } -export function isOpenAIProvider(provider: Provider) { - return !['anthropic', 'gemini'].includes(provider.type) -} +// 移除这个函数,它已经移动到 utils/index.ts +// export function isOpenAIProvider(provider: Provider) { +// return !['anthropic', 'gemini'].includes(provider.type) +// } diff --git a/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts b/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts new file mode 100644 index 0000000000..9fdfcb2687 --- /dev/null +++ b/src/renderer/src/aiCore/clients/__tests__/ApiClientFactory.test.ts @@ -0,0 +1,208 @@ +import { Provider } from '@renderer/types' +import { isOpenAIProvider } from '@renderer/utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { AihubmixAPIClient } from '../AihubmixAPIClient' +import { AnthropicAPIClient } from '../anthropic/AnthropicAPIClient' +import { ApiClientFactory } from '../ApiClientFactory' +import { GeminiAPIClient } from '../gemini/GeminiAPIClient' +import { VertexAPIClient } from '../gemini/VertexAPIClient' +import { NewAPIClient } from '../NewAPIClient' +import { OpenAIAPIClient } from '../openai/OpenAIApiClient' +import { OpenAIResponseAPIClient } from '../openai/OpenAIResponseAPIClient' +import { PPIOAPIClient } from '../ppio/PPIOAPIClient' + +// 为工厂测试创建最小化 provider 的辅助函数 +// ApiClientFactory 只使用 'id' 和 'type' 字段来决定创建哪个客户端 +// 其他字段会传递给客户端构造函数,但不影响工厂逻辑 +const createTestProvider = (id: string, type: string): Provider => ({ + id, + type: type as Provider['type'], + name: '', + apiKey: '', + apiHost: '', + models: [] +}) + +// Mock 所有客户端模块 +vi.mock('../AihubmixAPIClient', () => ({ + AihubmixAPIClient: vi.fn().mockImplementation(() => ({})) +})) +vi.mock('../anthropic/AnthropicAPIClient', () => ({ + AnthropicAPIClient: vi.fn().mockImplementation(() => ({})) +})) +vi.mock('../gemini/GeminiAPIClient', () => ({ + GeminiAPIClient: vi.fn().mockImplementation(() => ({})) +})) +vi.mock('../gemini/VertexAPIClient', () => ({ + VertexAPIClient: vi.fn().mockImplementation(() => ({})) +})) +vi.mock('../NewAPIClient', () => ({ + NewAPIClient: vi.fn().mockImplementation(() => ({})) +})) +vi.mock('../openai/OpenAIApiClient', () => ({ + OpenAIAPIClient: vi.fn().mockImplementation(() => ({})) +})) +vi.mock('../openai/OpenAIResponseAPIClient', () => ({ + OpenAIResponseAPIClient: vi.fn().mockImplementation(() => ({ + getClient: vi.fn().mockReturnThis() + })) +})) +vi.mock('../ppio/PPIOAPIClient', () => ({ + PPIOAPIClient: vi.fn().mockImplementation(() => ({})) +})) + +describe('ApiClientFactory', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('create', () => { + // 测试特殊 ID 的客户端创建 + it('should create AihubmixAPIClient for aihubmix provider', () => { + const provider = createTestProvider('aihubmix', 'openai') + + const client = ApiClientFactory.create(provider) + + expect(AihubmixAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + it('should create NewAPIClient for new-api provider', () => { + const provider = createTestProvider('new-api', 'openai') + + const client = ApiClientFactory.create(provider) + + expect(NewAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + it('should create PPIOAPIClient for ppio provider', () => { + const provider = createTestProvider('ppio', 'openai') + + const client = ApiClientFactory.create(provider) + + expect(PPIOAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + // 测试标准类型的客户端创建 + it('should create OpenAIAPIClient for openai type', () => { + const provider = createTestProvider('custom-openai', 'openai') + + const client = ApiClientFactory.create(provider) + + expect(OpenAIAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + it('should create OpenAIResponseAPIClient for azure-openai type', () => { + const provider = createTestProvider('azure-openai', 'azure-openai') + + const client = ApiClientFactory.create(provider) + + expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + it('should create OpenAIResponseAPIClient for openai-response type', () => { + const provider = createTestProvider('response', 'openai-response') + + const client = ApiClientFactory.create(provider) + + expect(OpenAIResponseAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + it('should create GeminiAPIClient for gemini type', () => { + const provider = createTestProvider('gemini', 'gemini') + + const client = ApiClientFactory.create(provider) + + expect(GeminiAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + it('should create VertexAPIClient for vertexai type', () => { + const provider = createTestProvider('vertex', 'vertexai') + + const client = ApiClientFactory.create(provider) + + expect(VertexAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + it('should create AnthropicAPIClient for anthropic type', () => { + const provider = createTestProvider('anthropic', 'anthropic') + + const client = ApiClientFactory.create(provider) + + expect(AnthropicAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + // 测试默认情况 + it('should create OpenAIAPIClient as default for unknown type', () => { + const provider = createTestProvider('unknown', 'unknown-type') + + const client = ApiClientFactory.create(provider) + + expect(OpenAIAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + // 测试边界条件 + it('should handle provider with minimal configuration', () => { + const provider = createTestProvider('minimal', 'openai') + + const client = ApiClientFactory.create(provider) + + expect(OpenAIAPIClient).toHaveBeenCalledWith(provider) + expect(client).toBeDefined() + }) + + // 测试特殊 ID 优先级高于类型 + it('should prioritize special ID over type', () => { + const provider = createTestProvider('aihubmix', 'anthropic') // 即使类型是 anthropic + + const client = ApiClientFactory.create(provider) + + // 应该创建 AihubmixAPIClient 而不是 AnthropicAPIClient + expect(AihubmixAPIClient).toHaveBeenCalledWith(provider) + expect(AnthropicAPIClient).not.toHaveBeenCalled() + expect(client).toBeDefined() + }) + }) + + describe('isOpenAIProvider', () => { + it('should return true for openai type', () => { + const provider = createTestProvider('openai', 'openai') + expect(isOpenAIProvider(provider)).toBe(true) + }) + + it('should return true for azure-openai type', () => { + const provider = createTestProvider('azure-openai', 'azure-openai') + expect(isOpenAIProvider(provider)).toBe(true) + }) + + it('should return true for unknown type (fallback to OpenAI)', () => { + const provider = createTestProvider('unknown', 'unknown') + expect(isOpenAIProvider(provider)).toBe(true) + }) + + it('should return false for vertexai type', () => { + const provider = createTestProvider('vertex', 'vertexai') + expect(isOpenAIProvider(provider)).toBe(false) + }) + + it('should return false for anthropic type', () => { + const provider = createTestProvider('anthropic', 'anthropic') + expect(isOpenAIProvider(provider)).toBe(false) + }) + + it('should return false for gemini type', () => { + const provider = createTestProvider('gemini', 'gemini') + expect(isOpenAIProvider(provider)).toBe(false) + }) + }) +}) diff --git a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx index d2d1f9b136..e177cc014c 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx @@ -1,5 +1,4 @@ import { CheckOutlined, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons' -import { isOpenAIProvider } from '@renderer/aiCore/clients/ApiClientFactory' import OpenAIAlert from '@renderer/components/Alert/OpenAIAlert' import { StreamlineGoodHealthAndWellBeing } from '@renderer/components/Icons/SVGIcon' import { HStack } from '@renderer/components/Layout' @@ -12,7 +11,13 @@ import i18n from '@renderer/i18n' import { checkApi } from '@renderer/services/ApiService' import { checkModelsHealth, getModelCheckSummary } from '@renderer/services/HealthCheckService' import { isProviderSupportAuth } from '@renderer/services/ProviderService' -import { formatApiHost, formatApiKeys, getFancyProviderName, splitApiKeyString } from '@renderer/utils' +import { + formatApiHost, + formatApiKeys, + getFancyProviderName, + isOpenAIProvider, + splitApiKeyString +} from '@renderer/utils' import { formatErrorMessage } from '@renderer/utils/error' import { lightbulbVariants } from '@renderer/utils/motionVariants' import { Button, Divider, Flex, Input, Space, Switch, Tooltip } from 'antd' diff --git a/src/renderer/src/utils/index.ts b/src/renderer/src/utils/index.ts index a97b82d8cf..4ca6b2e1e5 100644 --- a/src/renderer/src/utils/index.ts +++ b/src/renderer/src/utils/index.ts @@ -1,7 +1,6 @@ import Logger from '@renderer/config/logger' -import { Model } from '@renderer/types' -import { ModalFuncProps } from 'antd/es/modal/interface' -// @ts-ignore next-line` +import { Model, Provider } from '@renderer/types' +import { ModalFuncProps } from 'antd' import { v4 as uuidv4 } from 'uuid' /** @@ -227,6 +226,15 @@ export function getMcpConfigSampleFromReadme(readme: string): Record