mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
refactor(aiCore): streamline type exports and enhance provider registration
Removed unused type exports from the aiCore module and consolidated type definitions for better clarity. Updated provider registration tests to reflect new configurations and improved error handling for non-existent providers. Enhanced the overall structure of the provider management system, ensuring better type safety and consistency across the codebase.
This commit is contained in:
parent
4b7023f855
commit
d10ba04047
@ -34,35 +34,50 @@ vi.mock('@ai-sdk/xai', () => ({
|
||||
}))
|
||||
|
||||
import {
|
||||
AiProviderRegistry,
|
||||
cleanup,
|
||||
getAllDynamicMappings,
|
||||
getAllProviders,
|
||||
getAllValidProviderIds,
|
||||
getDynamicProviders,
|
||||
getProvider,
|
||||
getProviderMapping,
|
||||
isDynamicProvider,
|
||||
isProviderSupported,
|
||||
registerDynamicProvider,
|
||||
registerMultipleProviders,
|
||||
clearAllProviders,
|
||||
createAndRegisterProvider,
|
||||
createProvider,
|
||||
getAllProviderConfigAliases,
|
||||
getAllProviderConfigs,
|
||||
getInitializedProviders,
|
||||
getLanguageModel,
|
||||
getProviderConfig,
|
||||
getProviderConfigByAlias,
|
||||
getSupportedProviders,
|
||||
hasInitializedProviders,
|
||||
hasProviderConfig,
|
||||
hasProviderConfigByAlias,
|
||||
isProviderConfigAlias,
|
||||
ProviderInitializationError,
|
||||
providerRegistry,
|
||||
registerMultipleProviderConfigs,
|
||||
registerProvider,
|
||||
validateProviderIdRegistry
|
||||
registerProviderConfig,
|
||||
resolveProviderConfigId
|
||||
} from '../registry'
|
||||
import type { DynamicProviderRegistration, ProviderConfig } from '../schemas'
|
||||
import type { ProviderConfig } from '../schemas'
|
||||
|
||||
describe('AiProviderRegistry 功能测试', () => {
|
||||
describe('Provider Registry 功能测试', () => {
|
||||
beforeEach(() => {
|
||||
// 清理状态
|
||||
cleanup()
|
||||
})
|
||||
|
||||
describe('基础功能', () => {
|
||||
it('能够获取所有 providers', () => {
|
||||
const providers = getAllProviders()
|
||||
it('能够获取支持的 providers 列表', () => {
|
||||
const providers = getSupportedProviders()
|
||||
expect(Array.isArray(providers)).toBe(true)
|
||||
expect(providers.length).toBeGreaterThan(0)
|
||||
|
||||
// 检查返回的数据结构
|
||||
providers.forEach((provider) => {
|
||||
expect(provider).toHaveProperty('id')
|
||||
expect(provider).toHaveProperty('name')
|
||||
expect(typeof provider.id).toBe('string')
|
||||
expect(typeof provider.name).toBe('string')
|
||||
})
|
||||
|
||||
// 包含基础 providers
|
||||
const providerIds = providers.map((p) => p.id)
|
||||
expect(providerIds).toContain('openai')
|
||||
@ -70,74 +85,57 @@ describe('AiProviderRegistry 功能测试', () => {
|
||||
expect(providerIds).toContain('google')
|
||||
})
|
||||
|
||||
it('能够检查 provider 支持状态', () => {
|
||||
expect(isProviderSupported('openai')).toBe(true)
|
||||
expect(isProviderSupported('anthropic')).toBe(true)
|
||||
expect(isProviderSupported('google')).toBe(true)
|
||||
expect(isProviderSupported('non-existent')).toBe(true) // validateProviderId 通过
|
||||
expect(isProviderSupported('')).toBe(false)
|
||||
it('能够获取已初始化的 providers', () => {
|
||||
// 初始状态下没有已初始化的 providers
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('能够获取有效的 provider IDs', () => {
|
||||
const allIds = getAllValidProviderIds()
|
||||
expect(Array.isArray(allIds)).toBe(true)
|
||||
expect(allIds).toContain('openai')
|
||||
expect(allIds).toContain('anthropic')
|
||||
it('能够访问全局注册管理器', () => {
|
||||
expect(providerRegistry).toBeDefined()
|
||||
expect(typeof providerRegistry.clear).toBe('function')
|
||||
expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
|
||||
expect(typeof providerRegistry.hasProviders).toBe('function')
|
||||
})
|
||||
|
||||
it('能够根据 ID 获取特定的 provider', () => {
|
||||
// 获取存在的 provider
|
||||
const openaiProvider = getProvider('openai')
|
||||
expect(openaiProvider).toBeDefined()
|
||||
expect(openaiProvider?.id).toBe('openai')
|
||||
expect(openaiProvider?.name).toBe('OpenAI')
|
||||
|
||||
// 获取不存在的 provider,fallback到openai-compatible
|
||||
const nonExistentProvider = getProvider('non-existent')
|
||||
expect(nonExistentProvider).toBeDefined()
|
||||
expect(nonExistentProvider?.id).toBe('openai-compatible')
|
||||
})
|
||||
|
||||
it('能够验证 provider ID', () => {
|
||||
expect(validateProviderIdRegistry('valid-id')).toBe(true)
|
||||
expect(validateProviderIdRegistry('another-valid-id')).toBe(true)
|
||||
expect(validateProviderIdRegistry('')).toBe(false)
|
||||
// 注意:单个空格字符被认为是有效的,因为它不是空字符串
|
||||
// 如果需要更严格的验证,schemas 包含更多验证规则
|
||||
expect(validateProviderIdRegistry(' ')).toBe(true)
|
||||
it('能够获取语言模型', () => {
|
||||
// 在没有注册 provider 的情况下,这个函数可能会抛出错误或返回 undefined
|
||||
expect(() => getLanguageModel('non-existent')).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('动态 Provider 注册', () => {
|
||||
it('能够注册动态 provider', () => {
|
||||
const config = {
|
||||
describe('Provider 配置注册', () => {
|
||||
it('能够注册自定义 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(() => ({ name: 'custom' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerDynamicProvider(config)
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(isDynamicProvider('custom-provider')).toBe(true)
|
||||
expect(isProviderSupported('custom-provider')).toBe(true)
|
||||
|
||||
const allIds = getAllValidProviderIds()
|
||||
expect(allIds).toContain('custom-provider')
|
||||
expect(hasProviderConfig('custom-provider')).toBe(true)
|
||||
expect(getProviderConfig('custom-provider')).toEqual(config)
|
||||
})
|
||||
|
||||
it('拒绝与基础 provider 冲突的配置', () => {
|
||||
const config = {
|
||||
id: 'openai',
|
||||
name: 'Duplicate OpenAI',
|
||||
creator: vi.fn(() => ({ name: 'duplicate' })),
|
||||
supportsImageGeneration: false
|
||||
it('能够注册带别名的 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider-with-aliases',
|
||||
name: 'Custom Provider with Aliases',
|
||||
creator: vi.fn(() => ({ name: 'custom-aliased' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['alias-1', 'alias-2']
|
||||
}
|
||||
|
||||
const success = registerDynamicProvider(config)
|
||||
expect(success).toBe(false)
|
||||
expect(isDynamicProvider('openai')).toBe(false)
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
expect(getProviderConfigByAlias('alias-1')).toEqual(config)
|
||||
expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
|
||||
})
|
||||
|
||||
it('拒绝无效的配置', () => {
|
||||
@ -147,12 +145,12 @@ describe('AiProviderRegistry 功能测试', () => {
|
||||
// 缺少 name, creator 等
|
||||
}
|
||||
|
||||
const success = registerDynamicProvider(invalidConfig as any)
|
||||
const success = registerProviderConfig(invalidConfig as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('能够批量注册动态 providers', () => {
|
||||
const configs: DynamicProviderRegistration[] = [
|
||||
it('能够批量注册 provider 配置', () => {
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-1',
|
||||
name: 'Provider 1',
|
||||
@ -166,146 +164,146 @@ describe('AiProviderRegistry 功能测试', () => {
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai', // 这个失败,因为与基础 provider 冲突
|
||||
id: '', // 无效配置
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
} as any
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviders(configs)
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2) // 只有前两个成功
|
||||
|
||||
expect(isDynamicProvider('provider-1')).toBe(true)
|
||||
expect(isDynamicProvider('provider-2')).toBe(true)
|
||||
expect(isDynamicProvider('openai')).toBe(false) // 基础 provider,不是动态的
|
||||
expect(hasProviderConfig('provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('支持带映射关系的动态 provider', () => {
|
||||
const configWithMappings: DynamicProviderRegistration = {
|
||||
id: 'custom-provider-with-mappings',
|
||||
name: 'Custom Provider with Mappings',
|
||||
creator: vi.fn(() => ({ name: 'custom-mapped' })),
|
||||
it('能够获取所有配置和别名信息', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: false,
|
||||
mappings: {
|
||||
'custom-alias-1': 'custom-provider-with-mappings',
|
||||
'custom-alias-2': 'custom-provider-with-mappings'
|
||||
}
|
||||
aliases: ['test-alias']
|
||||
})
|
||||
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(Array.isArray(allConfigs)).toBe(true)
|
||||
expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
|
||||
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['test-alias']).toBe('test-provider')
|
||||
expect(isProviderConfigAlias('test-alias')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider 创建和注册', () => {
|
||||
it('能够创建 provider 实例', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-provider',
|
||||
name: 'Test Create Provider',
|
||||
creator: vi.fn(() => ({ name: 'test-created' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerDynamicProvider(configWithMappings)
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 创建 provider 实例
|
||||
const provider = await createProvider('test-create-provider', { apiKey: 'test' })
|
||||
expect(provider).toBeDefined()
|
||||
expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
})
|
||||
|
||||
it('能够注册 provider 到全局管理器', () => {
|
||||
const mockProvider = { name: 'mock-provider' }
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-register-provider',
|
||||
name: 'Test Register Provider',
|
||||
creator: vi.fn(() => mockProvider),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 注册 provider 到全局管理器
|
||||
const success = registerProvider('test-register-provider', mockProvider)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证映射关系
|
||||
expect(getProviderMapping('custom-alias-1')).toBe('custom-provider-with-mappings')
|
||||
expect(getProviderMapping('custom-alias-2')).toBe('custom-provider-with-mappings')
|
||||
expect(getProviderMapping('custom-provider-with-mappings')).toBe('custom-provider-with-mappings')
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-register-provider')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
})
|
||||
|
||||
it('能够一步完成创建和注册', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-and-register',
|
||||
name: 'Test Create and Register',
|
||||
creator: vi.fn(() => ({ name: 'test-both' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 一步完成创建和注册
|
||||
const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-create-and-register')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry 管理', () => {
|
||||
it('能够清理动态 providers', () => {
|
||||
// 注册动态 provider
|
||||
registerDynamicProvider({
|
||||
it('能够清理所有配置和注册的 providers', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'temp-provider',
|
||||
name: 'Temp Provider',
|
||||
creator: vi.fn(() => ({ name: 'temp' })),
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
|
||||
expect(isDynamicProvider('temp-provider')).toBe(true)
|
||||
expect(hasProviderConfig('temp-provider')).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
expect(isDynamicProvider('temp-provider')).toBe(false)
|
||||
expect(isProviderSupported('openai')).toBe(true) // 基础 providers 仍存在
|
||||
expect(hasProviderConfig('temp-provider')).toBe(false)
|
||||
// 但基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化
|
||||
})
|
||||
|
||||
it('保持单例模式', () => {
|
||||
const instance1 = AiProviderRegistry.getInstance()
|
||||
const instance2 = AiProviderRegistry.getInstance()
|
||||
expect(instance1).toBe(instance2)
|
||||
it('能够单独清理已注册的 providers', () => {
|
||||
// 清理所有 providers
|
||||
clearAllProviders()
|
||||
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('能够注册基础 provider', () => {
|
||||
const customConfig: ProviderConfig = {
|
||||
id: 'custom-base-provider',
|
||||
name: 'Custom Base Provider',
|
||||
creator: vi.fn(() => ({ name: 'custom-base' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 注册基础 provider 不抛出错误
|
||||
expect(() => registerProvider(customConfig)).not.toThrow()
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProvider = getProvider('custom-base-provider')
|
||||
expect(registeredProvider).toBeDefined()
|
||||
expect(registeredProvider?.id).toBe('custom-base-provider')
|
||||
expect(registeredProvider?.name).toBe('Custom Base Provider')
|
||||
})
|
||||
|
||||
it('能够获取动态 providers 列表', () => {
|
||||
// 初始状态没有动态 providers
|
||||
expect(getDynamicProviders()).toEqual([])
|
||||
|
||||
// 注册一些动态 providers
|
||||
registerDynamicProvider({
|
||||
id: 'dynamic-1',
|
||||
name: 'Dynamic 1',
|
||||
creator: vi.fn(() => ({ name: 'dynamic-1' })),
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
|
||||
registerDynamicProvider({
|
||||
id: 'dynamic-2',
|
||||
name: 'Dynamic 2',
|
||||
creator: vi.fn(() => ({ name: 'dynamic-2' })),
|
||||
supportsImageGeneration: true
|
||||
})
|
||||
|
||||
const dynamicProviders = getDynamicProviders()
|
||||
expect(Array.isArray(dynamicProviders)).toBe(true)
|
||||
expect(dynamicProviders).toContain('dynamic-1')
|
||||
expect(dynamicProviders).toContain('dynamic-2')
|
||||
expect(dynamicProviders.length).toBe(2)
|
||||
})
|
||||
|
||||
it('能够获取所有动态映射', () => {
|
||||
// 初始状态没有动态映射
|
||||
expect(getAllDynamicMappings()).toEqual({})
|
||||
|
||||
// 注册带映射的动态 provider
|
||||
registerDynamicProvider({
|
||||
id: 'mapped-provider',
|
||||
name: 'Mapped Provider',
|
||||
creator: vi.fn(() => ({ name: 'mapped' })),
|
||||
supportsImageGeneration: false,
|
||||
mappings: {
|
||||
'alias-1': 'mapped-provider',
|
||||
'alias-2': 'mapped-provider',
|
||||
'custom-name': 'mapped-provider'
|
||||
}
|
||||
})
|
||||
|
||||
const allMappings = getAllDynamicMappings()
|
||||
expect(allMappings).toEqual({
|
||||
'alias-1': 'mapped-provider',
|
||||
'alias-2': 'mapped-provider',
|
||||
'custom-name': 'mapped-provider'
|
||||
})
|
||||
it('ProviderInitializationError 错误类工作正常', () => {
|
||||
const error = new ProviderInitializationError('Test error', 'test-provider')
|
||||
expect(error.message).toBe('Test error')
|
||||
expect(error.providerId).toBe('test-provider')
|
||||
expect(error.name).toBe('ProviderInitializationError')
|
||||
})
|
||||
})
|
||||
|
||||
describe('错误处理', () => {
|
||||
it('优雅处理空配置', () => {
|
||||
const success = registerDynamicProvider(null as any)
|
||||
const success = registerProviderConfig(null as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('优雅处理未定义配置', () => {
|
||||
const success = registerDynamicProvider(undefined as any)
|
||||
const success = registerProviderConfig(undefined as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
@ -317,27 +315,31 @@ describe('AiProviderRegistry 功能测试', () => {
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerDynamicProvider(config)
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理注册基础 provider 时的无效 ID', () => {
|
||||
const invalidConfig: ProviderConfig = {
|
||||
id: '', // 无效 ID
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
expect(() => registerProvider(invalidConfig)).toThrow('Invalid provider ID:')
|
||||
it('处理创建不存在配置的 provider', async () => {
|
||||
await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
|
||||
'ProviderConfig not found for id: non-existent-provider'
|
||||
)
|
||||
})
|
||||
|
||||
it('处理获取不存在映射时的情况', () => {
|
||||
expect(getProviderMapping('non-existent-mapping')).toBeUndefined()
|
||||
it('处理注册不存在配置的 provider', () => {
|
||||
const mockProvider = { name: 'mock' }
|
||||
const success = registerProvider('non-existent-provider', mockProvider)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理获取不存在配置的情况', () => {
|
||||
expect(getProviderConfig('non-existent')).toBeUndefined()
|
||||
expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
|
||||
expect(hasProviderConfig('non-existent')).toBe(false)
|
||||
expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理批量注册时的部分失败', () => {
|
||||
const mixedConfigs: DynamicProviderRegistration[] = [
|
||||
const mixedConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'valid-provider-1',
|
||||
name: 'Valid Provider 1',
|
||||
@ -349,7 +351,7 @@ describe('AiProviderRegistry 功能测试', () => {
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
} as any,
|
||||
{
|
||||
id: 'valid-provider-2',
|
||||
name: 'Valid Provider 2',
|
||||
@ -358,153 +360,168 @@ describe('AiProviderRegistry 功能测试', () => {
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviders(mixedConfigs)
|
||||
const successCount = registerMultipleProviderConfigs(mixedConfigs)
|
||||
expect(successCount).toBe(2) // 只有两个有效配置成功
|
||||
|
||||
expect(isDynamicProvider('valid-provider-1')).toBe(true)
|
||||
expect(isDynamicProvider('valid-provider-2')).toBe(true)
|
||||
expect(getDynamicProviders()).not.toContain('')
|
||||
expect(hasProviderConfig('valid-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('valid-provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理动态导入失败的情况', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'import-test-provider',
|
||||
name: 'Import Test Provider',
|
||||
import: vi.fn().mockRejectedValue(new Error('Import failed')),
|
||||
creatorFunctionName: 'createTest',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
registerProviderConfig(config)
|
||||
|
||||
await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('集成测试', () => {
|
||||
it('正确处理复杂的注册、映射和清理场景', () => {
|
||||
it('正确处理复杂的配置、创建、注册和清理场景', async () => {
|
||||
// 初始状态验证
|
||||
const initialProviders = getAllProviders()
|
||||
const initialIds = getAllValidProviderIds()
|
||||
expect(initialProviders.length).toBeGreaterThan(0)
|
||||
expect(getDynamicProviders()).toEqual([])
|
||||
expect(getAllDynamicMappings()).toEqual({})
|
||||
const initialConfigs = getAllProviderConfigs()
|
||||
expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
|
||||
// 注册多个带映射的动态 providers
|
||||
const configs: DynamicProviderRegistration[] = [
|
||||
// 注册多个带别名的 provider 配置
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'integration-provider-1',
|
||||
name: 'Integration Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'integration-1' })),
|
||||
supportsImageGeneration: false,
|
||||
mappings: {
|
||||
'alias-1': 'integration-provider-1',
|
||||
'short-name-1': 'integration-provider-1'
|
||||
}
|
||||
aliases: ['alias-1', 'short-name-1']
|
||||
},
|
||||
{
|
||||
id: 'integration-provider-2',
|
||||
name: 'Integration Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'integration-2' })),
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
'alias-2': 'integration-provider-2',
|
||||
'short-name-2': 'integration-provider-2'
|
||||
}
|
||||
aliases: ['alias-2', 'short-name-2']
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviders(configs)
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2)
|
||||
|
||||
// 验证注册后的状态
|
||||
const afterRegisterProviders = getAllProviders()
|
||||
const afterRegisterIds = getAllValidProviderIds()
|
||||
expect(afterRegisterProviders.length).toBe(initialProviders.length + 2)
|
||||
expect(afterRegisterIds.length).toBeGreaterThanOrEqual(initialIds.length + 2)
|
||||
// 验证配置注册成功
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
|
||||
// 验证动态 providers
|
||||
const dynamicProviders = getDynamicProviders()
|
||||
expect(dynamicProviders).toContain('integration-provider-1')
|
||||
expect(dynamicProviders).toContain('integration-provider-2')
|
||||
// 验证别名映射
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['alias-1']).toBe('integration-provider-1')
|
||||
expect(aliases['alias-2']).toBe('integration-provider-2')
|
||||
|
||||
// 验证映射
|
||||
const mappings = getAllDynamicMappings()
|
||||
expect(mappings['alias-1']).toBe('integration-provider-1')
|
||||
expect(mappings['alias-2']).toBe('integration-provider-2')
|
||||
expect(mappings['short-name-1']).toBe('integration-provider-1')
|
||||
expect(mappings['short-name-2']).toBe('integration-provider-2')
|
||||
// 创建和注册 providers
|
||||
const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
|
||||
const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
|
||||
expect(success1).toBe(true)
|
||||
expect(success2).toBe(true)
|
||||
|
||||
// 验证通过映射能够获取 provider
|
||||
expect(getProviderMapping('alias-1')).toBe('integration-provider-1')
|
||||
expect(getProviderMapping('integration-provider-1')).toBe('integration-provider-1')
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('integration-provider-1')
|
||||
expect(registeredProviders).toContain('integration-provider-2')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
// 验证清理后的状态
|
||||
const afterCleanupProviders = getAllProviders()
|
||||
const afterCleanupIds = getAllValidProviderIds()
|
||||
expect(afterCleanupProviders.length).toBe(initialProviders.length)
|
||||
expect(afterCleanupIds.length).toBe(initialIds.length)
|
||||
expect(getDynamicProviders()).toEqual([])
|
||||
expect(getAllDynamicMappings()).toEqual({})
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(false)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(false)
|
||||
expect(getAllProviderConfigAliases()).toEqual({})
|
||||
|
||||
// 基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true)
|
||||
})
|
||||
|
||||
it('正确处理 provider 的优先级和 fallback 机制', () => {
|
||||
// 验证 getProvider 的 fallback 机制
|
||||
const existingProvider = getProvider('openai')
|
||||
expect(existingProvider?.id).toBe('openai')
|
||||
|
||||
const nonExistentProvider = getProvider('definitely-non-existent')
|
||||
expect(nonExistentProvider?.id).toBe('openai-compatible') // fallback
|
||||
|
||||
// 注册自定义 provider 后能直接获取
|
||||
registerDynamicProvider({
|
||||
id: 'priority-test-provider',
|
||||
name: 'Priority Test Provider',
|
||||
creator: vi.fn(() => ({ name: 'priority-test' })),
|
||||
it('正确处理动态导入配置的 provider', async () => {
|
||||
const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
|
||||
const dynamicImportConfig: ProviderConfig = {
|
||||
id: 'dynamic-import-provider',
|
||||
name: 'Dynamic Import Provider',
|
||||
import: vi.fn().mockResolvedValue(mockModule),
|
||||
creatorFunctionName: 'createCustomProvider',
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
}
|
||||
|
||||
const customProvider = getProvider('priority-test-provider')
|
||||
expect(customProvider?.id).toBe('priority-test-provider')
|
||||
expect(customProvider?.name).toBe('Priority Test Provider')
|
||||
// 注册配置
|
||||
const configSuccess = registerProviderConfig(dynamicImportConfig)
|
||||
expect(configSuccess).toBe(true)
|
||||
|
||||
// 创建和注册 provider
|
||||
const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
|
||||
expect(registerSuccess).toBe(true)
|
||||
|
||||
// 验证导入函数被调用
|
||||
expect(dynamicImportConfig.import).toHaveBeenCalled()
|
||||
expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
|
||||
// 验证注册成功
|
||||
expect(getInitializedProviders()).toContain('dynamic-import-provider')
|
||||
})
|
||||
|
||||
it('正确处理大量动态 providers 的注册和管理', () => {
|
||||
const largeConfigList: DynamicProviderRegistration[] = []
|
||||
it('正确处理大量配置的注册和管理', () => {
|
||||
const largeConfigList: ProviderConfig[] = []
|
||||
|
||||
// 生成100个动态 providers
|
||||
for (let i = 0; i < 100; i++) {
|
||||
// 生成50个配置
|
||||
for (let i = 0; i < 50; i++) {
|
||||
largeConfigList.push({
|
||||
id: `bulk-provider-${i}`,
|
||||
name: `Bulk Provider ${i}`,
|
||||
creator: vi.fn(() => ({ name: `bulk-${i}` })),
|
||||
supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成
|
||||
mappings: {
|
||||
[`alias-${i}`]: `bulk-provider-${i}`,
|
||||
[`short-${i}`]: `bulk-provider-${i}`
|
||||
}
|
||||
aliases: [`alias-${i}`, `short-${i}`]
|
||||
})
|
||||
}
|
||||
|
||||
const successCount = registerMultipleProviders(largeConfigList)
|
||||
expect(successCount).toBe(100)
|
||||
const successCount = registerMultipleProviderConfigs(largeConfigList)
|
||||
expect(successCount).toBe(50)
|
||||
|
||||
// 验证所有 providers 都被正确注册
|
||||
const dynamicProviders = getDynamicProviders()
|
||||
expect(dynamicProviders.length).toBe(100)
|
||||
// 验证所有配置都被正确注册
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
|
||||
|
||||
// 验证映射数量
|
||||
const mappings = getAllDynamicMappings()
|
||||
expect(Object.keys(mappings).length).toBe(200) // 每个 provider 有2个映射
|
||||
// 验证别名数量
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
const bulkAliases = Object.keys(aliases).filter(
|
||||
(alias) => alias.startsWith('alias-') || alias.startsWith('short-')
|
||||
)
|
||||
expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名
|
||||
|
||||
// 随机验证几个 providers
|
||||
expect(isDynamicProvider('bulk-provider-0')).toBe(true)
|
||||
expect(isDynamicProvider('bulk-provider-50')).toBe(true)
|
||||
expect(isDynamicProvider('bulk-provider-99')).toBe(true)
|
||||
// 随机验证几个配置
|
||||
expect(hasProviderConfig('bulk-provider-0')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-25')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-49')).toBe(true)
|
||||
|
||||
// 验证映射工作正常
|
||||
expect(getProviderMapping('alias-25')).toBe('bulk-provider-25')
|
||||
expect(getProviderMapping('short-75')).toBe('bulk-provider-75')
|
||||
// 验证别名工作正常
|
||||
expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
|
||||
expect(isProviderConfigAlias('short-30')).toBe(true)
|
||||
|
||||
// 清理能正确处理大量数据
|
||||
cleanup()
|
||||
expect(getDynamicProviders()).toEqual([])
|
||||
expect(getAllDynamicMappings()).toEqual({})
|
||||
const cleanupAliases = getAllProviderConfigAliases()
|
||||
expect(
|
||||
Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
|
||||
).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('边界测试', () => {
|
||||
it('处理包含特殊字符的 provider IDs', () => {
|
||||
const specialCharsConfigs: DynamicProviderRegistration[] = [
|
||||
const specialCharsConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-with-dashes',
|
||||
name: 'Provider With Dashes',
|
||||
@ -525,26 +542,29 @@ describe('AiProviderRegistry 功能测试', () => {
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviders(specialCharsConfigs)
|
||||
const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
|
||||
expect(successCount).toBeGreaterThan(0) // 至少有一些成功
|
||||
|
||||
// 验证支持的特殊字符格式
|
||||
if (isDynamicProvider('provider-with-dashes')) {
|
||||
expect(getProvider('provider-with-dashes')).toBeDefined()
|
||||
if (hasProviderConfig('provider-with-dashes')) {
|
||||
expect(getProviderConfig('provider-with-dashes')).toBeDefined()
|
||||
}
|
||||
if (isDynamicProvider('provider_with_underscores')) {
|
||||
expect(getProvider('provider_with_underscores')).toBeDefined()
|
||||
if (hasProviderConfig('provider_with_underscores')) {
|
||||
expect(getProviderConfig('provider_with_underscores')).toBeDefined()
|
||||
}
|
||||
})
|
||||
|
||||
it('处理空的批量注册', () => {
|
||||
const successCount = registerMultipleProviders([])
|
||||
const successCount = registerMultipleProviderConfigs([])
|
||||
expect(successCount).toBe(0)
|
||||
expect(getDynamicProviders()).toEqual([])
|
||||
|
||||
// 确保没有额外的配置被添加
|
||||
const configsBefore = getAllProviderConfigs().length
|
||||
expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置
|
||||
})
|
||||
|
||||
it('处理重复的 provider 注册', () => {
|
||||
const config: DynamicProviderRegistration = {
|
||||
it('处理重复的配置注册', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'duplicate-test-provider',
|
||||
name: 'Duplicate Test Provider',
|
||||
creator: vi.fn(() => ({ name: 'duplicate' })),
|
||||
@ -552,17 +572,61 @@ describe('AiProviderRegistry 功能测试', () => {
|
||||
}
|
||||
|
||||
// 第一次注册成功
|
||||
expect(registerDynamicProvider(config)).toBe(true)
|
||||
expect(isDynamicProvider('duplicate-test-provider')).toBe(true)
|
||||
expect(registerProviderConfig(config)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 重复注册相同的 provider
|
||||
expect(registerDynamicProvider(config)).toBe(true) // 允许覆盖
|
||||
expect(isDynamicProvider('duplicate-test-provider')).toBe(true)
|
||||
// 重复注册相同的配置(允许覆盖)
|
||||
const updatedConfig: ProviderConfig = {
|
||||
...config,
|
||||
name: 'Updated Duplicate Test Provider'
|
||||
}
|
||||
expect(registerProviderConfig(updatedConfig)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 验证只有一个实例
|
||||
const dynamicProviders = getDynamicProviders()
|
||||
const duplicateCount = dynamicProviders.filter((id) => id === 'duplicate-test-provider').length
|
||||
expect(duplicateCount).toBe(1)
|
||||
// 验证配置被更新
|
||||
const retrievedConfig = getProviderConfig('duplicate-test-provider')
|
||||
expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
|
||||
})
|
||||
|
||||
it('处理极长的 ID 和名称', () => {
|
||||
const longId = 'very-long-provider-id-' + 'x'.repeat(100)
|
||||
const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: longId,
|
||||
name: longName,
|
||||
creator: vi.fn(() => ({ name: 'long-test' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
expect(hasProviderConfig(longId)).toBe(true)
|
||||
|
||||
const retrievedConfig = getProviderConfig(longId)
|
||||
expect(retrievedConfig?.name).toBe(longName)
|
||||
})
|
||||
|
||||
it('处理大量别名的配置', () => {
|
||||
const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: 'provider-with-many-aliases',
|
||||
name: 'Provider With Many Aliases',
|
||||
creator: vi.fn(() => ({ name: 'many-aliases' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: manyAliases
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证所有别名都能正确解析
|
||||
manyAliases.forEach((alias) => {
|
||||
expect(hasProviderConfigByAlias(alias)).toBe(true)
|
||||
expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
|
||||
expect(isProviderConfigAlias(alias)).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -5,18 +5,11 @@ import {
|
||||
baseProviderIds,
|
||||
baseProviderIdSchema,
|
||||
baseProviders,
|
||||
type DynamicProviderId,
|
||||
dynamicProviderIdSchema,
|
||||
dynamicProviderRegistrationSchema,
|
||||
getBaseProviderConfig,
|
||||
isBaseProviderId,
|
||||
isValidDynamicProviderId,
|
||||
type CustomProviderId,
|
||||
customProviderIdSchema,
|
||||
providerConfigSchema,
|
||||
type ProviderId,
|
||||
providerIdSchema,
|
||||
validateDynamicProviderRegistration,
|
||||
validateProviderConfig,
|
||||
validateProviderId
|
||||
providerIdSchema
|
||||
} from '../schemas'
|
||||
|
||||
describe('Provider Schemas', () => {
|
||||
@ -90,22 +83,22 @@ describe('Provider Schemas', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('dynamicProviderIdSchema', () => {
|
||||
it('接受有效的动态 provider IDs', () => {
|
||||
describe('customProviderIdSchema', () => {
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
|
||||
validIds.forEach((id) => {
|
||||
expect(dynamicProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝与基础 provider IDs 冲突的 IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(dynamicProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝空字符串', () => {
|
||||
expect(dynamicProviderIdSchema.safeParse('').success).toBe(false)
|
||||
expect(customProviderIdSchema.safeParse('').success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@ -116,9 +109,9 @@ describe('Provider Schemas', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('接受有效的动态 provider IDs', () => {
|
||||
const validDynamicIds = ['custom-provider', 'my-ai-service']
|
||||
validDynamicIds.forEach((id) => {
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validCustomIds = ['custom-provider', 'my-ai-service']
|
||||
validCustomIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
@ -134,8 +127,8 @@ describe('Provider Schemas', () => {
|
||||
describe('providerConfigSchema', () => {
|
||||
it('验证带有 creator 的有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
@ -175,12 +168,21 @@ describe('Provider Schemas', () => {
|
||||
}
|
||||
})
|
||||
|
||||
it('拒绝使用基础 provider ID 的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'openai', // 基础 provider ID
|
||||
name: 'Should Fail',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('拒绝缺少必需字段的配置', () => {
|
||||
const invalidConfigs = [
|
||||
{ name: 'Missing ID', creator: vi.fn() },
|
||||
{ id: 'missing-name', creator: vi.fn() },
|
||||
{ id: '', name: 'Empty ID', creator: vi.fn() },
|
||||
{ id: 'valid', name: '', creator: vi.fn() }
|
||||
{ id: 'valid-custom', name: '', creator: vi.fn() }
|
||||
]
|
||||
|
||||
invalidConfigs.forEach((config) => {
|
||||
@ -189,186 +191,55 @@ describe('Provider Schemas', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('dynamicProviderRegistrationSchema', () => {
|
||||
it('验证有效的动态 provider 注册配置', () => {
|
||||
describe('Schema 验证功能', () => {
|
||||
it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
|
||||
})
|
||||
|
||||
it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
|
||||
const customIds = ['custom-provider', 'my-service', 'company-llm']
|
||||
customIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 拒绝基础 provider IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerIdSchema 接受基础和自定义 provider IDs', () => {
|
||||
// 基础 IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 自定义 IDs
|
||||
const customIds = ['custom-provider', 'my-service']
|
||||
customIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerConfigSchema 验证完整的 provider 配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true,
|
||||
mappings: { model1: 'mapped-model1' }
|
||||
}
|
||||
expect(dynamicProviderRegistrationSchema.safeParse(validConfig).success).toBe(true)
|
||||
})
|
||||
|
||||
it('拒绝使用基础 provider ID 的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'openai',
|
||||
name: 'Should Fail',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(dynamicProviderRegistrationSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('要求 creator 或 import 配置', () => {
|
||||
const configWithoutCreator = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider'
|
||||
}
|
||||
expect(dynamicProviderRegistrationSchema.safeParse(configWithoutCreator).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('validateProviderId', () => {
|
||||
it('验证基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(validateProviderId(id)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('验证有效的动态 provider IDs', () => {
|
||||
const validDynamicIds = ['custom-provider', 'my-service', 'company-llm']
|
||||
validDynamicIds.forEach((id) => {
|
||||
expect(validateProviderId(id)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的 IDs', () => {
|
||||
const invalidIds = [undefined as any, null as any, 123 as any]
|
||||
invalidIds.forEach((id) => {
|
||||
expect(validateProviderId(id)).toBe(false)
|
||||
})
|
||||
|
||||
// 空字符串和只有空格的字符串会被当作有效的动态 provider ID
|
||||
expect(validateProviderId('')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isBaseProviderId', () => {
|
||||
it('正确识别基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(isBaseProviderId(id)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝动态 provider IDs', () => {
|
||||
const dynamicIds = ['custom-provider', 'my-service']
|
||||
dynamicIds.forEach((id) => {
|
||||
expect(isBaseProviderId(id)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的 IDs', () => {
|
||||
const invalidIds = ['', 'invalid', undefined as any]
|
||||
invalidIds.forEach((id) => {
|
||||
expect(isBaseProviderId(id)).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('isValidDynamicProviderId', () => {
|
||||
it('接受有效的动态 provider IDs', () => {
|
||||
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
|
||||
validIds.forEach((id) => {
|
||||
expect(isValidDynamicProviderId(id)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(isValidDynamicProviderId(id)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的 IDs', () => {
|
||||
const invalidIds = [undefined as any, null as any]
|
||||
invalidIds.forEach((id) => {
|
||||
expect(isValidDynamicProviderId(id)).toBe(false)
|
||||
})
|
||||
|
||||
// 空字符串会被 schema 拒绝
|
||||
expect(isValidDynamicProviderId('')).toBe(false)
|
||||
// 只有空格的字符串是有效的动态 provider ID(但不推荐使用)
|
||||
expect(isValidDynamicProviderId(' ')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('validateProviderConfig', () => {
|
||||
it('返回有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
const result = validateProviderConfig(validConfig)
|
||||
expect(result).not.toBeNull()
|
||||
expect(result?.id).toBe('openai')
|
||||
expect(result?.name).toBe('OpenAI')
|
||||
})
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
|
||||
it('对无效配置返回 null', () => {
|
||||
const invalidConfig = {
|
||||
id: '',
|
||||
name: 'Invalid'
|
||||
id: 'openai', // 不允许基础 provider ID
|
||||
name: 'OpenAI',
|
||||
creator: vi.fn()
|
||||
}
|
||||
const result = validateProviderConfig(invalidConfig)
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it('处理完全无效的输入', () => {
|
||||
const invalidInputs = [undefined, null, 'string', 123, []]
|
||||
invalidInputs.forEach((input) => {
|
||||
const result = validateProviderConfig(input)
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('validateDynamicProviderRegistration', () => {
|
||||
it('返回有效的动态 provider 注册配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
mappings: { model1: 'mapped-model1' }
|
||||
}
|
||||
const result = validateDynamicProviderRegistration(validConfig)
|
||||
expect(result).not.toBeNull()
|
||||
expect(result?.id).toBe('custom-provider')
|
||||
expect(result?.name).toBe('Custom Provider')
|
||||
})
|
||||
|
||||
it('对无效配置返回 null', () => {
|
||||
const invalidConfig = {
|
||||
id: 'openai',
|
||||
name: 'Should Fail'
|
||||
}
|
||||
const result = validateDynamicProviderRegistration(invalidConfig)
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getBaseProviderConfig', () => {
|
||||
it('返回有效基础 provider ID 的配置', () => {
|
||||
const config = getBaseProviderConfig('openai')
|
||||
expect(config).toBeDefined()
|
||||
expect(config?.id).toBe('openai')
|
||||
expect(config?.name).toBe('OpenAI')
|
||||
expect(config?.creator).toBeDefined()
|
||||
})
|
||||
|
||||
it('对无效 ID 返回 undefined', () => {
|
||||
const config = getBaseProviderConfig('invalid' as BaseProviderId)
|
||||
expect(config).toBeUndefined()
|
||||
})
|
||||
|
||||
it('返回所有基础 providers 的配置', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
const config = getBaseProviderConfig(id)
|
||||
expect(config).toBeDefined()
|
||||
expect(config?.id).toBe(id)
|
||||
})
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@ -378,16 +249,16 @@ describe('Provider Schemas', () => {
|
||||
expect(baseProviderIds).toContain(id)
|
||||
})
|
||||
|
||||
it('DynamicProviderId 类型是字符串', () => {
|
||||
const id: DynamicProviderId = 'custom-provider'
|
||||
it('CustomProviderId 类型是字符串', () => {
|
||||
const id: CustomProviderId = 'custom-provider'
|
||||
expect(typeof id).toBe('string')
|
||||
})
|
||||
|
||||
it('ProviderId 类型支持基础和动态 IDs', () => {
|
||||
it('ProviderId 类型支持基础和自定义 IDs', () => {
|
||||
const baseId: ProviderId = 'openai'
|
||||
const dynamicId: ProviderId = 'custom-provider'
|
||||
const customId: ProviderId = 'custom-provider'
|
||||
expect(typeof baseId).toBe('string')
|
||||
expect(typeof dynamicId).toBe('string')
|
||||
expect(typeof customId).toBe('string')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -2,8 +2,8 @@ import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { createImageModel } from '../../models/ModelCreator'
|
||||
import { type AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { ImageGenerationError } from '../errors'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
@ -19,8 +19,10 @@ vi.mock('ai', () => ({
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../models/ModelCreator', () => ({
|
||||
createImageModel: vi.fn()
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
imageModel: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor.generateImage', () => {
|
||||
@ -67,8 +69,11 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
}
|
||||
|
||||
// Setup mocks
|
||||
vi.mocked(createImageModel).mockResolvedValue(mockImageModel)
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
|
||||
|
||||
// Reset mock implementation in case it was changed by previous tests
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => mockImageModel)
|
||||
})
|
||||
|
||||
describe('Basic functionality', () => {
|
||||
@ -77,7 +82,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
prompt: 'A futuristic cityscape at sunset'
|
||||
})
|
||||
|
||||
expect(createImageModel).toHaveBeenCalledWith('openai', 'dall-e-3', { apiKey: 'test-key' })
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai:dall-e-3')
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
@ -92,7 +97,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
prompt: 'A beautiful landscape'
|
||||
})
|
||||
|
||||
// Note: createImageModel may still be called due to resolveImageModel logic
|
||||
// Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful landscape'
|
||||
@ -335,8 +340,10 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
|
||||
describe('Error handling', () => {
|
||||
it('should handle model creation errors', async () => {
|
||||
const modelError = new Error('Failed to create image model')
|
||||
vi.mocked(createImageModel).mockRejectedValue(modelError)
|
||||
const modelError = new Error('Failed to get image model')
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
throw modelError
|
||||
})
|
||||
|
||||
await expect(
|
||||
executor.generateImage('invalid-model', {
|
||||
@ -431,7 +438,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
prompt: 'A landscape'
|
||||
})
|
||||
|
||||
expect(createImageModel).toHaveBeenCalledWith('google', 'imagen-3.0-generate-002', { apiKey: 'google-key' })
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google:imagen-3.0-generate-002')
|
||||
})
|
||||
|
||||
it('should support xAI Grok image models', async () => {
|
||||
@ -443,7 +450,7 @@ describe('RuntimeExecutor.generateImage', () => {
|
||||
prompt: 'A futuristic robot'
|
||||
})
|
||||
|
||||
expect(createImageModel).toHaveBeenCalledWith('xai', 'grok-2-image', { apiKey: 'xai-key' })
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai:grok-2-image')
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -24,57 +24,11 @@ export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||
|
||||
// ==================== 类型定义 ====================
|
||||
export type { GenerateObjectParams, GenerateTextParams, StreamObjectParams, StreamTextParams } from './types'
|
||||
|
||||
// ==================== AI SDK 常用类型导出 ====================
|
||||
// 直接导出 AI SDK 的常用类型,方便使用
|
||||
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
||||
export type { ToolCall } from '@ai-sdk/provider-utils'
|
||||
export type { ReasoningPart } from '@ai-sdk/provider-utils'
|
||||
export type {
|
||||
AssistantModelMessage,
|
||||
FilePart,
|
||||
// 通用类型
|
||||
FinishReason,
|
||||
GenerateObjectResult,
|
||||
// 生成相关类型
|
||||
GenerateTextResult,
|
||||
ImagePart,
|
||||
InferToolInput,
|
||||
InferToolOutput,
|
||||
InvalidToolInputError,
|
||||
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
|
||||
// 消息相关类型
|
||||
ModelMessage,
|
||||
// 错误类型
|
||||
NoSuchToolError,
|
||||
ProviderMetadata,
|
||||
StreamTextResult,
|
||||
SystemModelMessage,
|
||||
TextPart,
|
||||
// 流相关类型
|
||||
TextStreamPart,
|
||||
// 工具相关类型
|
||||
Tool,
|
||||
ToolCallPart,
|
||||
ToolModelMessage,
|
||||
ToolResultPart,
|
||||
ToolSet,
|
||||
TypedToolCall,
|
||||
TypedToolError,
|
||||
TypedToolResult,
|
||||
UserModelMessage
|
||||
} from 'ai'
|
||||
export {
|
||||
defaultSettingsMiddleware,
|
||||
extractReasoningMiddleware,
|
||||
simulateStreamingMiddleware,
|
||||
smoothStream,
|
||||
stepCountIs
|
||||
} from 'ai'
|
||||
// 重新导出 Agent
|
||||
export { Experimental_Agent as Agent } from 'ai'
|
||||
|
||||
// ==================== 选项 ====================
|
||||
export {
|
||||
|
||||
@ -1,9 +1,2 @@
|
||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
|
||||
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
|
||||
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
|
||||
// 重新导出插件类型
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
||||
|
||||
@ -3,10 +3,10 @@
|
||||
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
|
||||
*/
|
||||
|
||||
import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { ToolCallChunkHandler } from './handleToolCallChunk'
|
||||
|
||||
|
||||
@ -4,10 +4,10 @@
|
||||
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||
*/
|
||||
|
||||
import { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
|
||||
// import type {
|
||||
// AnthropicSearchOutput,
|
||||
// WebSearchPluginConfig
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
* 3. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { createExecutor, generateImage } from '@cherrystudio/ai-core'
|
||||
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
@ -23,6 +23,7 @@ import { CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor'
|
||||
import type { StreamTextParams } from './types'
|
||||
|
||||
const logger = loggerService.withContext('ModernAiProvider')
|
||||
|
||||
|
||||
@ -1,11 +1,7 @@
|
||||
import {
|
||||
extractReasoningMiddleware,
|
||||
LanguageModelV2Middleware,
|
||||
simulateStreamingMiddleware
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import type { MCPTool, Model, Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai'
|
||||
|
||||
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
|
||||
|
||||
@ -34,7 +30,7 @@ export interface AiSdkMiddlewareConfig {
|
||||
*/
|
||||
export interface NamedAiSdkMiddleware {
|
||||
name: string
|
||||
middleware: LanguageModelV2Middleware
|
||||
middleware: LanguageModelMiddleware
|
||||
}
|
||||
|
||||
/**
|
||||
@ -83,7 +79,7 @@ export class AiSdkMiddlewareBuilder {
|
||||
/**
|
||||
* 构建最终的中间件数组
|
||||
*/
|
||||
public build(): LanguageModelV2Middleware[] {
|
||||
public build(): LanguageModelMiddleware[] {
|
||||
return this.middlewares.map((m) => m.middleware)
|
||||
}
|
||||
|
||||
@ -114,7 +110,7 @@ export class AiSdkMiddlewareBuilder {
|
||||
* 根据配置构建AI SDK中间件的工厂函数
|
||||
* 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果
|
||||
*/
|
||||
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV2Middleware[] {
|
||||
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] {
|
||||
const builder = new AiSdkMiddlewareBuilder()
|
||||
|
||||
// 1. 根据provider添加特定中间件
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { definePlugin, TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
export default definePlugin({
|
||||
name: 'reasoningTimePlugin',
|
||||
|
||||
@ -6,8 +6,7 @@
|
||||
* 2. transformParams: 根据意图分析结果动态添加对应的工具
|
||||
* 3. onRequestEnd: 自动记忆存储
|
||||
*/
|
||||
import type { AiRequestContext, ModelMessage } from '@cherrystudio/ai-core'
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { type AiRequestContext, definePlugin } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
// import { generateObject } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
@ -20,6 +19,7 @@ import store from '@renderer/store'
|
||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
|
||||
import type { ModelMessage } from 'ai'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
* 统一管理从各个 apiClient 提取的参数处理和转换功能
|
||||
*/
|
||||
|
||||
import { stepCountIs, type StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
@ -31,8 +30,10 @@ import {
|
||||
} from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import type { AssistantModelMessage, FilePart, ImagePart, ModelMessage, TextPart, UserModelMessage } from 'ai'
|
||||
import { stepCountIs } from 'ai'
|
||||
|
||||
import { getAiSdkProviderId } from './provider/factory'
|
||||
import type { StreamTextParams } from './types'
|
||||
// import { webSearchTool } from './tools/WebSearchTool'
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { setupToolsConfig } from './utils/mcp'
|
||||
|
||||
6
src/renderer/src/aiCore/types.ts
Normal file
6
src/renderer/src/aiCore/types.ts
Normal file
@ -0,0 +1,6 @@
|
||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
|
||||
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
|
||||
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
@ -1,12 +1,12 @@
|
||||
/**
|
||||
* 职责:提供原子化的、无状态的API调用函数
|
||||
*/
|
||||
import { StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { CompletionsParams } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
|
||||
import type { StreamTextParams } from '@renderer/aiCore/types'
|
||||
import { isDedicatedImageGenerationModel, isEmbeddingModel, isQwenMTModel } from '@renderer/config/models'
|
||||
import { LANG_DETECT_PROMPT } from '@renderer/config/prompts'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { convertMessagesToSdkMessages } from '@renderer/aiCore/transformParameters'
|
||||
import type { StreamTextParams } from '@renderer/aiCore/types'
|
||||
import { Assistant, Message } from '@renderer/types'
|
||||
import { filterAdjacentUserMessaegs, filterLastAssistantMessage } from '@renderer/utils/messageUtils/filters'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user