diff --git a/packages/aiCore/src/core/providers/__tests__/registry-functionality.test.ts b/packages/aiCore/src/core/providers/__tests__/registry-functionality.test.ts index cabbda15f1..ee31af2dd8 100644 --- a/packages/aiCore/src/core/providers/__tests__/registry-functionality.test.ts +++ b/packages/aiCore/src/core/providers/__tests__/registry-functionality.test.ts @@ -34,35 +34,50 @@ vi.mock('@ai-sdk/xai', () => ({ })) import { - AiProviderRegistry, cleanup, - getAllDynamicMappings, - getAllProviders, - getAllValidProviderIds, - getDynamicProviders, - getProvider, - getProviderMapping, - isDynamicProvider, - isProviderSupported, - registerDynamicProvider, - registerMultipleProviders, + clearAllProviders, + createAndRegisterProvider, + createProvider, + getAllProviderConfigAliases, + getAllProviderConfigs, + getInitializedProviders, + getLanguageModel, + getProviderConfig, + getProviderConfigByAlias, + getSupportedProviders, + hasInitializedProviders, + hasProviderConfig, + hasProviderConfigByAlias, + isProviderConfigAlias, + ProviderInitializationError, + providerRegistry, + registerMultipleProviderConfigs, registerProvider, - validateProviderIdRegistry + registerProviderConfig, + resolveProviderConfigId } from '../registry' -import type { DynamicProviderRegistration, ProviderConfig } from '../schemas' +import type { ProviderConfig } from '../schemas' -describe('AiProviderRegistry 功能测试', () => { +describe('Provider Registry 功能测试', () => { beforeEach(() => { // 清理状态 cleanup() }) describe('基础功能', () => { - it('能够获取所有 providers', () => { - const providers = getAllProviders() + it('能够获取支持的 providers 列表', () => { + const providers = getSupportedProviders() expect(Array.isArray(providers)).toBe(true) expect(providers.length).toBeGreaterThan(0) + // 检查返回的数据结构 + providers.forEach((provider) => { + expect(provider).toHaveProperty('id') + expect(provider).toHaveProperty('name') + expect(typeof provider.id).toBe('string') + expect(typeof provider.name).toBe('string') + }) + // 包含基础 providers const providerIds = providers.map((p) => p.id) expect(providerIds).toContain('openai') @@ -70,74 +85,57 @@ describe('AiProviderRegistry 功能测试', () => { expect(providerIds).toContain('google') }) - it('能够检查 provider 支持状态', () => { - expect(isProviderSupported('openai')).toBe(true) - expect(isProviderSupported('anthropic')).toBe(true) - expect(isProviderSupported('google')).toBe(true) - expect(isProviderSupported('non-existent')).toBe(true) // validateProviderId 通过 - expect(isProviderSupported('')).toBe(false) + it('能够获取已初始化的 providers', () => { + // 初始状态下没有已初始化的 providers + expect(getInitializedProviders()).toEqual([]) + expect(hasInitializedProviders()).toBe(false) }) - it('能够获取有效的 provider IDs', () => { - const allIds = getAllValidProviderIds() - expect(Array.isArray(allIds)).toBe(true) - expect(allIds).toContain('openai') - expect(allIds).toContain('anthropic') + it('能够访问全局注册管理器', () => { + expect(providerRegistry).toBeDefined() + expect(typeof providerRegistry.clear).toBe('function') + expect(typeof providerRegistry.getRegisteredProviders).toBe('function') + expect(typeof providerRegistry.hasProviders).toBe('function') }) - it('能够根据 ID 获取特定的 provider', () => { - // 获取存在的 provider - const openaiProvider = getProvider('openai') - expect(openaiProvider).toBeDefined() - expect(openaiProvider?.id).toBe('openai') - expect(openaiProvider?.name).toBe('OpenAI') - - // 获取不存在的 provider,fallback到openai-compatible - const nonExistentProvider = getProvider('non-existent') - expect(nonExistentProvider).toBeDefined() - expect(nonExistentProvider?.id).toBe('openai-compatible') - }) - - it('能够验证 provider ID', () => { - expect(validateProviderIdRegistry('valid-id')).toBe(true) - expect(validateProviderIdRegistry('another-valid-id')).toBe(true) - expect(validateProviderIdRegistry('')).toBe(false) - // 注意:单个空格字符被认为是有效的,因为它不是空字符串 - // 如果需要更严格的验证,schemas 包含更多验证规则 - expect(validateProviderIdRegistry(' ')).toBe(true) + it('能够获取语言模型', () => { + // 在没有注册 provider 的情况下,这个函数可能会抛出错误或返回 undefined + expect(() => getLanguageModel('non-existent')).not.toThrow() }) }) - describe('动态 Provider 注册', () => { - it('能够注册动态 provider', () => { - const config = { + describe('Provider 配置注册', () => { + it('能够注册自定义 provider 配置', () => { + const config: ProviderConfig = { id: 'custom-provider', name: 'Custom Provider', creator: vi.fn(() => ({ name: 'custom' })), supportsImageGeneration: false } - const success = registerDynamicProvider(config) + const success = registerProviderConfig(config) expect(success).toBe(true) - expect(isDynamicProvider('custom-provider')).toBe(true) - expect(isProviderSupported('custom-provider')).toBe(true) - - const allIds = getAllValidProviderIds() - expect(allIds).toContain('custom-provider') + expect(hasProviderConfig('custom-provider')).toBe(true) + expect(getProviderConfig('custom-provider')).toEqual(config) }) - it('拒绝与基础 provider 冲突的配置', () => { - const config = { - id: 'openai', - name: 'Duplicate OpenAI', - creator: vi.fn(() => ({ name: 'duplicate' })), - supportsImageGeneration: false + it('能够注册带别名的 provider 配置', () => { + const config: ProviderConfig = { + id: 'custom-provider-with-aliases', + name: 'Custom Provider with Aliases', + creator: vi.fn(() => ({ name: 'custom-aliased' })), + supportsImageGeneration: false, + aliases: ['alias-1', 'alias-2'] } - const success = registerDynamicProvider(config) - expect(success).toBe(false) - expect(isDynamicProvider('openai')).toBe(false) + const success = registerProviderConfig(config) + expect(success).toBe(true) + + expect(hasProviderConfigByAlias('alias-1')).toBe(true) + expect(hasProviderConfigByAlias('alias-2')).toBe(true) + expect(getProviderConfigByAlias('alias-1')).toEqual(config) + expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases') }) it('拒绝无效的配置', () => { @@ -147,12 +145,12 @@ describe('AiProviderRegistry 功能测试', () => { // 缺少 name, creator 等 } - const success = registerDynamicProvider(invalidConfig as any) + const success = registerProviderConfig(invalidConfig as any) expect(success).toBe(false) }) - it('能够批量注册动态 providers', () => { - const configs: DynamicProviderRegistration[] = [ + it('能够批量注册 provider 配置', () => { + const configs: ProviderConfig[] = [ { id: 'provider-1', name: 'Provider 1', @@ -166,146 +164,146 @@ describe('AiProviderRegistry 功能测试', () => { supportsImageGeneration: true }, { - id: 'openai', // 这个失败,因为与基础 provider 冲突 + id: '', // 无效配置 name: 'Invalid Provider', creator: vi.fn(() => ({ name: 'invalid' })), supportsImageGeneration: false - } + } as any ] - const successCount = registerMultipleProviders(configs) + const successCount = registerMultipleProviderConfigs(configs) expect(successCount).toBe(2) // 只有前两个成功 - expect(isDynamicProvider('provider-1')).toBe(true) - expect(isDynamicProvider('provider-2')).toBe(true) - expect(isDynamicProvider('openai')).toBe(false) // 基础 provider,不是动态的 + expect(hasProviderConfig('provider-1')).toBe(true) + expect(hasProviderConfig('provider-2')).toBe(true) + expect(hasProviderConfig('')).toBe(false) }) - it('支持带映射关系的动态 provider', () => { - const configWithMappings: DynamicProviderRegistration = { - id: 'custom-provider-with-mappings', - name: 'Custom Provider with Mappings', - creator: vi.fn(() => ({ name: 'custom-mapped' })), + it('能够获取所有配置和别名信息', () => { + // 注册一些配置 + registerProviderConfig({ + id: 'test-provider', + name: 'Test Provider', + creator: vi.fn(), supportsImageGeneration: false, - mappings: { - 'custom-alias-1': 'custom-provider-with-mappings', - 'custom-alias-2': 'custom-provider-with-mappings' - } + aliases: ['test-alias'] + }) + + const allConfigs = getAllProviderConfigs() + expect(Array.isArray(allConfigs)).toBe(true) + expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true) + + const aliases = getAllProviderConfigAliases() + expect(aliases['test-alias']).toBe('test-provider') + expect(isProviderConfigAlias('test-alias')).toBe(true) + }) + }) + + describe('Provider 创建和注册', () => { + it('能够创建 provider 实例', async () => { + const config: ProviderConfig = { + id: 'test-create-provider', + name: 'Test Create Provider', + creator: vi.fn(() => ({ name: 'test-created' })), + supportsImageGeneration: false } - const success = registerDynamicProvider(configWithMappings) + // 先注册配置 + registerProviderConfig(config) + + // 创建 provider 实例 + const provider = await createProvider('test-create-provider', { apiKey: 'test' }) + expect(provider).toBeDefined() + expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' }) + }) + + it('能够注册 provider 到全局管理器', () => { + const mockProvider = { name: 'mock-provider' } + const config: ProviderConfig = { + id: 'test-register-provider', + name: 'Test Register Provider', + creator: vi.fn(() => mockProvider), + supportsImageGeneration: false + } + + // 先注册配置 + registerProviderConfig(config) + + // 注册 provider 到全局管理器 + const success = registerProvider('test-register-provider', mockProvider) expect(success).toBe(true) - // 验证映射关系 - expect(getProviderMapping('custom-alias-1')).toBe('custom-provider-with-mappings') - expect(getProviderMapping('custom-alias-2')).toBe('custom-provider-with-mappings') - expect(getProviderMapping('custom-provider-with-mappings')).toBe('custom-provider-with-mappings') + // 验证注册成功 + const registeredProviders = getInitializedProviders() + expect(registeredProviders).toContain('test-register-provider') + expect(hasInitializedProviders()).toBe(true) + }) + + it('能够一步完成创建和注册', async () => { + const config: ProviderConfig = { + id: 'test-create-and-register', + name: 'Test Create and Register', + creator: vi.fn(() => ({ name: 'test-both' })), + supportsImageGeneration: false + } + + // 先注册配置 + registerProviderConfig(config) + + // 一步完成创建和注册 + const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' }) + expect(success).toBe(true) + + // 验证注册成功 + const registeredProviders = getInitializedProviders() + expect(registeredProviders).toContain('test-create-and-register') }) }) describe('Registry 管理', () => { - it('能够清理动态 providers', () => { - // 注册动态 provider - registerDynamicProvider({ + it('能够清理所有配置和注册的 providers', () => { + // 注册一些配置 + registerProviderConfig({ id: 'temp-provider', name: 'Temp Provider', creator: vi.fn(() => ({ name: 'temp' })), supportsImageGeneration: false }) - expect(isDynamicProvider('temp-provider')).toBe(true) + expect(hasProviderConfig('temp-provider')).toBe(true) // 清理 cleanup() - expect(isDynamicProvider('temp-provider')).toBe(false) - expect(isProviderSupported('openai')).toBe(true) // 基础 providers 仍存在 + expect(hasProviderConfig('temp-provider')).toBe(false) + // 但基础配置应该重新加载 + expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化 }) - it('保持单例模式', () => { - const instance1 = AiProviderRegistry.getInstance() - const instance2 = AiProviderRegistry.getInstance() - expect(instance1).toBe(instance2) + it('能够单独清理已注册的 providers', () => { + // 清理所有 providers + clearAllProviders() + + expect(getInitializedProviders()).toEqual([]) + expect(hasInitializedProviders()).toBe(false) }) - it('能够注册基础 provider', () => { - const customConfig: ProviderConfig = { - id: 'custom-base-provider', - name: 'Custom Base Provider', - creator: vi.fn(() => ({ name: 'custom-base' })), - supportsImageGeneration: false - } - - // 注册基础 provider 不抛出错误 - expect(() => registerProvider(customConfig)).not.toThrow() - - // 验证注册成功 - const registeredProvider = getProvider('custom-base-provider') - expect(registeredProvider).toBeDefined() - expect(registeredProvider?.id).toBe('custom-base-provider') - expect(registeredProvider?.name).toBe('Custom Base Provider') - }) - - it('能够获取动态 providers 列表', () => { - // 初始状态没有动态 providers - expect(getDynamicProviders()).toEqual([]) - - // 注册一些动态 providers - registerDynamicProvider({ - id: 'dynamic-1', - name: 'Dynamic 1', - creator: vi.fn(() => ({ name: 'dynamic-1' })), - supportsImageGeneration: false - }) - - registerDynamicProvider({ - id: 'dynamic-2', - name: 'Dynamic 2', - creator: vi.fn(() => ({ name: 'dynamic-2' })), - supportsImageGeneration: true - }) - - const dynamicProviders = getDynamicProviders() - expect(Array.isArray(dynamicProviders)).toBe(true) - expect(dynamicProviders).toContain('dynamic-1') - expect(dynamicProviders).toContain('dynamic-2') - expect(dynamicProviders.length).toBe(2) - }) - - it('能够获取所有动态映射', () => { - // 初始状态没有动态映射 - expect(getAllDynamicMappings()).toEqual({}) - - // 注册带映射的动态 provider - registerDynamicProvider({ - id: 'mapped-provider', - name: 'Mapped Provider', - creator: vi.fn(() => ({ name: 'mapped' })), - supportsImageGeneration: false, - mappings: { - 'alias-1': 'mapped-provider', - 'alias-2': 'mapped-provider', - 'custom-name': 'mapped-provider' - } - }) - - const allMappings = getAllDynamicMappings() - expect(allMappings).toEqual({ - 'alias-1': 'mapped-provider', - 'alias-2': 'mapped-provider', - 'custom-name': 'mapped-provider' - }) + it('ProviderInitializationError 错误类工作正常', () => { + const error = new ProviderInitializationError('Test error', 'test-provider') + expect(error.message).toBe('Test error') + expect(error.providerId).toBe('test-provider') + expect(error.name).toBe('ProviderInitializationError') }) }) describe('错误处理', () => { it('优雅处理空配置', () => { - const success = registerDynamicProvider(null as any) + const success = registerProviderConfig(null as any) expect(success).toBe(false) }) it('优雅处理未定义配置', () => { - const success = registerDynamicProvider(undefined as any) + const success = registerProviderConfig(undefined as any) expect(success).toBe(false) }) @@ -317,27 +315,31 @@ describe('AiProviderRegistry 功能测试', () => { supportsImageGeneration: false } - const success = registerDynamicProvider(config) + const success = registerProviderConfig(config) expect(success).toBe(false) }) - it('处理注册基础 provider 时的无效 ID', () => { - const invalidConfig: ProviderConfig = { - id: '', // 无效 ID - name: 'Invalid Provider', - creator: vi.fn(() => ({ name: 'invalid' })), - supportsImageGeneration: false - } - - expect(() => registerProvider(invalidConfig)).toThrow('Invalid provider ID:') + it('处理创建不存在配置的 provider', async () => { + await expect(createProvider('non-existent-provider', {})).rejects.toThrow( + 'ProviderConfig not found for id: non-existent-provider' + ) }) - it('处理获取不存在映射时的情况', () => { - expect(getProviderMapping('non-existent-mapping')).toBeUndefined() + it('处理注册不存在配置的 provider', () => { + const mockProvider = { name: 'mock' } + const success = registerProvider('non-existent-provider', mockProvider) + expect(success).toBe(false) + }) + + it('处理获取不存在配置的情况', () => { + expect(getProviderConfig('non-existent')).toBeUndefined() + expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined() + expect(hasProviderConfig('non-existent')).toBe(false) + expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false) }) it('处理批量注册时的部分失败', () => { - const mixedConfigs: DynamicProviderRegistration[] = [ + const mixedConfigs: ProviderConfig[] = [ { id: 'valid-provider-1', name: 'Valid Provider 1', @@ -349,7 +351,7 @@ describe('AiProviderRegistry 功能测试', () => { name: 'Invalid Provider', creator: vi.fn(() => ({ name: 'invalid' })), supportsImageGeneration: false - }, + } as any, { id: 'valid-provider-2', name: 'Valid Provider 2', @@ -358,153 +360,168 @@ describe('AiProviderRegistry 功能测试', () => { } ] - const successCount = registerMultipleProviders(mixedConfigs) + const successCount = registerMultipleProviderConfigs(mixedConfigs) expect(successCount).toBe(2) // 只有两个有效配置成功 - expect(isDynamicProvider('valid-provider-1')).toBe(true) - expect(isDynamicProvider('valid-provider-2')).toBe(true) - expect(getDynamicProviders()).not.toContain('') + expect(hasProviderConfig('valid-provider-1')).toBe(true) + expect(hasProviderConfig('valid-provider-2')).toBe(true) + expect(hasProviderConfig('')).toBe(false) + }) + + it('处理动态导入失败的情况', async () => { + const config: ProviderConfig = { + id: 'import-test-provider', + name: 'Import Test Provider', + import: vi.fn().mockRejectedValue(new Error('Import failed')), + creatorFunctionName: 'createTest', + supportsImageGeneration: false + } + + registerProviderConfig(config) + + await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed') }) }) describe('集成测试', () => { - it('正确处理复杂的注册、映射和清理场景', () => { + it('正确处理复杂的配置、创建、注册和清理场景', async () => { // 初始状态验证 - const initialProviders = getAllProviders() - const initialIds = getAllValidProviderIds() - expect(initialProviders.length).toBeGreaterThan(0) - expect(getDynamicProviders()).toEqual([]) - expect(getAllDynamicMappings()).toEqual({}) + const initialConfigs = getAllProviderConfigs() + expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置 + expect(getInitializedProviders()).toEqual([]) - // 注册多个带映射的动态 providers - const configs: DynamicProviderRegistration[] = [ + // 注册多个带别名的 provider 配置 + const configs: ProviderConfig[] = [ { id: 'integration-provider-1', name: 'Integration Provider 1', creator: vi.fn(() => ({ name: 'integration-1' })), supportsImageGeneration: false, - mappings: { - 'alias-1': 'integration-provider-1', - 'short-name-1': 'integration-provider-1' - } + aliases: ['alias-1', 'short-name-1'] }, { id: 'integration-provider-2', name: 'Integration Provider 2', creator: vi.fn(() => ({ name: 'integration-2' })), supportsImageGeneration: true, - mappings: { - 'alias-2': 'integration-provider-2', - 'short-name-2': 'integration-provider-2' - } + aliases: ['alias-2', 'short-name-2'] } ] - const successCount = registerMultipleProviders(configs) + const successCount = registerMultipleProviderConfigs(configs) expect(successCount).toBe(2) - // 验证注册后的状态 - const afterRegisterProviders = getAllProviders() - const afterRegisterIds = getAllValidProviderIds() - expect(afterRegisterProviders.length).toBe(initialProviders.length + 2) - expect(afterRegisterIds.length).toBeGreaterThanOrEqual(initialIds.length + 2) + // 验证配置注册成功 + expect(hasProviderConfig('integration-provider-1')).toBe(true) + expect(hasProviderConfig('integration-provider-2')).toBe(true) + expect(hasProviderConfigByAlias('alias-1')).toBe(true) + expect(hasProviderConfigByAlias('alias-2')).toBe(true) - // 验证动态 providers - const dynamicProviders = getDynamicProviders() - expect(dynamicProviders).toContain('integration-provider-1') - expect(dynamicProviders).toContain('integration-provider-2') + // 验证别名映射 + const aliases = getAllProviderConfigAliases() + expect(aliases['alias-1']).toBe('integration-provider-1') + expect(aliases['alias-2']).toBe('integration-provider-2') - // 验证映射 - const mappings = getAllDynamicMappings() - expect(mappings['alias-1']).toBe('integration-provider-1') - expect(mappings['alias-2']).toBe('integration-provider-2') - expect(mappings['short-name-1']).toBe('integration-provider-1') - expect(mappings['short-name-2']).toBe('integration-provider-2') + // 创建和注册 providers + const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' }) + const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' }) + expect(success1).toBe(true) + expect(success2).toBe(true) - // 验证通过映射能够获取 provider - expect(getProviderMapping('alias-1')).toBe('integration-provider-1') - expect(getProviderMapping('integration-provider-1')).toBe('integration-provider-1') + // 验证注册成功 + const registeredProviders = getInitializedProviders() + expect(registeredProviders).toContain('integration-provider-1') + expect(registeredProviders).toContain('integration-provider-2') + expect(hasInitializedProviders()).toBe(true) // 清理 cleanup() // 验证清理后的状态 - const afterCleanupProviders = getAllProviders() - const afterCleanupIds = getAllValidProviderIds() - expect(afterCleanupProviders.length).toBe(initialProviders.length) - expect(afterCleanupIds.length).toBe(initialIds.length) - expect(getDynamicProviders()).toEqual([]) - expect(getAllDynamicMappings()).toEqual({}) + expect(getInitializedProviders()).toEqual([]) + expect(hasProviderConfig('integration-provider-1')).toBe(false) + expect(hasProviderConfig('integration-provider-2')).toBe(false) + expect(getAllProviderConfigAliases()).toEqual({}) + + // 基础配置应该重新加载 + expect(hasProviderConfig('openai')).toBe(true) }) - it('正确处理 provider 的优先级和 fallback 机制', () => { - // 验证 getProvider 的 fallback 机制 - const existingProvider = getProvider('openai') - expect(existingProvider?.id).toBe('openai') - - const nonExistentProvider = getProvider('definitely-non-existent') - expect(nonExistentProvider?.id).toBe('openai-compatible') // fallback - - // 注册自定义 provider 后能直接获取 - registerDynamicProvider({ - id: 'priority-test-provider', - name: 'Priority Test Provider', - creator: vi.fn(() => ({ name: 'priority-test' })), + it('正确处理动态导入配置的 provider', async () => { + const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) } + const dynamicImportConfig: ProviderConfig = { + id: 'dynamic-import-provider', + name: 'Dynamic Import Provider', + import: vi.fn().mockResolvedValue(mockModule), + creatorFunctionName: 'createCustomProvider', supportsImageGeneration: false - }) + } - const customProvider = getProvider('priority-test-provider') - expect(customProvider?.id).toBe('priority-test-provider') - expect(customProvider?.name).toBe('Priority Test Provider') + // 注册配置 + const configSuccess = registerProviderConfig(dynamicImportConfig) + expect(configSuccess).toBe(true) + + // 创建和注册 provider + const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' }) + expect(registerSuccess).toBe(true) + + // 验证导入函数被调用 + expect(dynamicImportConfig.import).toHaveBeenCalled() + expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' }) + + // 验证注册成功 + expect(getInitializedProviders()).toContain('dynamic-import-provider') }) - it('正确处理大量动态 providers 的注册和管理', () => { - const largeConfigList: DynamicProviderRegistration[] = [] + it('正确处理大量配置的注册和管理', () => { + const largeConfigList: ProviderConfig[] = [] - // 生成100个动态 providers - for (let i = 0; i < 100; i++) { + // 生成50个配置 + for (let i = 0; i < 50; i++) { largeConfigList.push({ id: `bulk-provider-${i}`, name: `Bulk Provider ${i}`, creator: vi.fn(() => ({ name: `bulk-${i}` })), supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成 - mappings: { - [`alias-${i}`]: `bulk-provider-${i}`, - [`short-${i}`]: `bulk-provider-${i}` - } + aliases: [`alias-${i}`, `short-${i}`] }) } - const successCount = registerMultipleProviders(largeConfigList) - expect(successCount).toBe(100) + const successCount = registerMultipleProviderConfigs(largeConfigList) + expect(successCount).toBe(50) - // 验证所有 providers 都被正确注册 - const dynamicProviders = getDynamicProviders() - expect(dynamicProviders.length).toBe(100) + // 验证所有配置都被正确注册 + const allConfigs = getAllProviderConfigs() + expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50) - // 验证映射数量 - const mappings = getAllDynamicMappings() - expect(Object.keys(mappings).length).toBe(200) // 每个 provider 有2个映射 + // 验证别名数量 + const aliases = getAllProviderConfigAliases() + const bulkAliases = Object.keys(aliases).filter( + (alias) => alias.startsWith('alias-') || alias.startsWith('short-') + ) + expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名 - // 随机验证几个 providers - expect(isDynamicProvider('bulk-provider-0')).toBe(true) - expect(isDynamicProvider('bulk-provider-50')).toBe(true) - expect(isDynamicProvider('bulk-provider-99')).toBe(true) + // 随机验证几个配置 + expect(hasProviderConfig('bulk-provider-0')).toBe(true) + expect(hasProviderConfig('bulk-provider-25')).toBe(true) + expect(hasProviderConfig('bulk-provider-49')).toBe(true) - // 验证映射工作正常 - expect(getProviderMapping('alias-25')).toBe('bulk-provider-25') - expect(getProviderMapping('short-75')).toBe('bulk-provider-75') + // 验证别名工作正常 + expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25') + expect(isProviderConfigAlias('short-30')).toBe(true) // 清理能正确处理大量数据 cleanup() - expect(getDynamicProviders()).toEqual([]) - expect(getAllDynamicMappings()).toEqual({}) + const cleanupAliases = getAllProviderConfigAliases() + expect( + Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-')) + ).toEqual([]) }) }) describe('边界测试', () => { it('处理包含特殊字符的 provider IDs', () => { - const specialCharsConfigs: DynamicProviderRegistration[] = [ + const specialCharsConfigs: ProviderConfig[] = [ { id: 'provider-with-dashes', name: 'Provider With Dashes', @@ -525,26 +542,29 @@ describe('AiProviderRegistry 功能测试', () => { } ] - const successCount = registerMultipleProviders(specialCharsConfigs) + const successCount = registerMultipleProviderConfigs(specialCharsConfigs) expect(successCount).toBeGreaterThan(0) // 至少有一些成功 // 验证支持的特殊字符格式 - if (isDynamicProvider('provider-with-dashes')) { - expect(getProvider('provider-with-dashes')).toBeDefined() + if (hasProviderConfig('provider-with-dashes')) { + expect(getProviderConfig('provider-with-dashes')).toBeDefined() } - if (isDynamicProvider('provider_with_underscores')) { - expect(getProvider('provider_with_underscores')).toBeDefined() + if (hasProviderConfig('provider_with_underscores')) { + expect(getProviderConfig('provider_with_underscores')).toBeDefined() } }) it('处理空的批量注册', () => { - const successCount = registerMultipleProviders([]) + const successCount = registerMultipleProviderConfigs([]) expect(successCount).toBe(0) - expect(getDynamicProviders()).toEqual([]) + + // 确保没有额外的配置被添加 + const configsBefore = getAllProviderConfigs().length + expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置 }) - it('处理重复的 provider 注册', () => { - const config: DynamicProviderRegistration = { + it('处理重复的配置注册', () => { + const config: ProviderConfig = { id: 'duplicate-test-provider', name: 'Duplicate Test Provider', creator: vi.fn(() => ({ name: 'duplicate' })), @@ -552,17 +572,61 @@ describe('AiProviderRegistry 功能测试', () => { } // 第一次注册成功 - expect(registerDynamicProvider(config)).toBe(true) - expect(isDynamicProvider('duplicate-test-provider')).toBe(true) + expect(registerProviderConfig(config)).toBe(true) + expect(hasProviderConfig('duplicate-test-provider')).toBe(true) - // 重复注册相同的 provider - expect(registerDynamicProvider(config)).toBe(true) // 允许覆盖 - expect(isDynamicProvider('duplicate-test-provider')).toBe(true) + // 重复注册相同的配置(允许覆盖) + const updatedConfig: ProviderConfig = { + ...config, + name: 'Updated Duplicate Test Provider' + } + expect(registerProviderConfig(updatedConfig)).toBe(true) + expect(hasProviderConfig('duplicate-test-provider')).toBe(true) - // 验证只有一个实例 - const dynamicProviders = getDynamicProviders() - const duplicateCount = dynamicProviders.filter((id) => id === 'duplicate-test-provider').length - expect(duplicateCount).toBe(1) + // 验证配置被更新 + const retrievedConfig = getProviderConfig('duplicate-test-provider') + expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider') + }) + + it('处理极长的 ID 和名称', () => { + const longId = 'very-long-provider-id-' + 'x'.repeat(100) + const longName = 'Very Long Provider Name ' + 'Y'.repeat(100) + + const config: ProviderConfig = { + id: longId, + name: longName, + creator: vi.fn(() => ({ name: 'long-test' })), + supportsImageGeneration: false + } + + const success = registerProviderConfig(config) + expect(success).toBe(true) + expect(hasProviderConfig(longId)).toBe(true) + + const retrievedConfig = getProviderConfig(longId) + expect(retrievedConfig?.name).toBe(longName) + }) + + it('处理大量别名的配置', () => { + const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`) + + const config: ProviderConfig = { + id: 'provider-with-many-aliases', + name: 'Provider With Many Aliases', + creator: vi.fn(() => ({ name: 'many-aliases' })), + supportsImageGeneration: false, + aliases: manyAliases + } + + const success = registerProviderConfig(config) + expect(success).toBe(true) + + // 验证所有别名都能正确解析 + manyAliases.forEach((alias) => { + expect(hasProviderConfigByAlias(alias)).toBe(true) + expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases') + expect(isProviderConfigAlias(alias)).toBe(true) + }) }) }) }) diff --git a/packages/aiCore/src/core/providers/__tests__/schemas.test.ts b/packages/aiCore/src/core/providers/__tests__/schemas.test.ts index 3addc0708a..82b390ba05 100644 --- a/packages/aiCore/src/core/providers/__tests__/schemas.test.ts +++ b/packages/aiCore/src/core/providers/__tests__/schemas.test.ts @@ -5,18 +5,11 @@ import { baseProviderIds, baseProviderIdSchema, baseProviders, - type DynamicProviderId, - dynamicProviderIdSchema, - dynamicProviderRegistrationSchema, - getBaseProviderConfig, - isBaseProviderId, - isValidDynamicProviderId, + type CustomProviderId, + customProviderIdSchema, providerConfigSchema, type ProviderId, - providerIdSchema, - validateDynamicProviderRegistration, - validateProviderConfig, - validateProviderId + providerIdSchema } from '../schemas' describe('Provider Schemas', () => { @@ -90,22 +83,22 @@ describe('Provider Schemas', () => { }) }) - describe('dynamicProviderIdSchema', () => { - it('接受有效的动态 provider IDs', () => { + describe('customProviderIdSchema', () => { + it('接受有效的自定义 provider IDs', () => { const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2'] validIds.forEach((id) => { - expect(dynamicProviderIdSchema.safeParse(id).success).toBe(true) + expect(customProviderIdSchema.safeParse(id).success).toBe(true) }) }) it('拒绝与基础 provider IDs 冲突的 IDs', () => { baseProviderIds.forEach((id) => { - expect(dynamicProviderIdSchema.safeParse(id).success).toBe(false) + expect(customProviderIdSchema.safeParse(id).success).toBe(false) }) }) it('拒绝空字符串', () => { - expect(dynamicProviderIdSchema.safeParse('').success).toBe(false) + expect(customProviderIdSchema.safeParse('').success).toBe(false) }) }) @@ -116,9 +109,9 @@ describe('Provider Schemas', () => { }) }) - it('接受有效的动态 provider IDs', () => { - const validDynamicIds = ['custom-provider', 'my-ai-service'] - validDynamicIds.forEach((id) => { + it('接受有效的自定义 provider IDs', () => { + const validCustomIds = ['custom-provider', 'my-ai-service'] + validCustomIds.forEach((id) => { expect(providerIdSchema.safeParse(id).success).toBe(true) }) }) @@ -134,8 +127,8 @@ describe('Provider Schemas', () => { describe('providerConfigSchema', () => { it('验证带有 creator 的有效配置', () => { const validConfig = { - id: 'openai', - name: 'OpenAI', + id: 'custom-provider', + name: 'Custom Provider', creator: vi.fn(), supportsImageGeneration: true } @@ -175,12 +168,21 @@ describe('Provider Schemas', () => { } }) + it('拒绝使用基础 provider ID 的配置', () => { + const invalidConfig = { + id: 'openai', // 基础 provider ID + name: 'Should Fail', + creator: vi.fn() + } + expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false) + }) + it('拒绝缺少必需字段的配置', () => { const invalidConfigs = [ { name: 'Missing ID', creator: vi.fn() }, { id: 'missing-name', creator: vi.fn() }, { id: '', name: 'Empty ID', creator: vi.fn() }, - { id: 'valid', name: '', creator: vi.fn() } + { id: 'valid-custom', name: '', creator: vi.fn() } ] invalidConfigs.forEach((config) => { @@ -189,186 +191,55 @@ describe('Provider Schemas', () => { }) }) - describe('dynamicProviderRegistrationSchema', () => { - it('验证有效的动态 provider 注册配置', () => { + describe('Schema 验证功能', () => { + it('baseProviderIdSchema 正确验证基础 provider IDs', () => { + baseProviderIds.forEach((id) => { + expect(baseProviderIdSchema.safeParse(id).success).toBe(true) + }) + + expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false) + }) + + it('customProviderIdSchema 正确验证自定义 provider IDs', () => { + const customIds = ['custom-provider', 'my-service', 'company-llm'] + customIds.forEach((id) => { + expect(customProviderIdSchema.safeParse(id).success).toBe(true) + }) + + // 拒绝基础 provider IDs + baseProviderIds.forEach((id) => { + expect(customProviderIdSchema.safeParse(id).success).toBe(false) + }) + }) + + it('providerIdSchema 接受基础和自定义 provider IDs', () => { + // 基础 IDs + baseProviderIds.forEach((id) => { + expect(providerIdSchema.safeParse(id).success).toBe(true) + }) + + // 自定义 IDs + const customIds = ['custom-provider', 'my-service'] + customIds.forEach((id) => { + expect(providerIdSchema.safeParse(id).success).toBe(true) + }) + }) + + it('providerConfigSchema 验证完整的 provider 配置', () => { const validConfig = { id: 'custom-provider', name: 'Custom Provider', creator: vi.fn(), - supportsImageGeneration: true, - mappings: { model1: 'mapped-model1' } - } - expect(dynamicProviderRegistrationSchema.safeParse(validConfig).success).toBe(true) - }) - - it('拒绝使用基础 provider ID 的配置', () => { - const invalidConfig = { - id: 'openai', - name: 'Should Fail', - creator: vi.fn() - } - expect(dynamicProviderRegistrationSchema.safeParse(invalidConfig).success).toBe(false) - }) - - it('要求 creator 或 import 配置', () => { - const configWithoutCreator = { - id: 'custom-provider', - name: 'Custom Provider' - } - expect(dynamicProviderRegistrationSchema.safeParse(configWithoutCreator).success).toBe(false) - }) - }) - - describe('validateProviderId', () => { - it('验证基础 provider IDs', () => { - baseProviderIds.forEach((id) => { - expect(validateProviderId(id)).toBe(true) - }) - }) - - it('验证有效的动态 provider IDs', () => { - const validDynamicIds = ['custom-provider', 'my-service', 'company-llm'] - validDynamicIds.forEach((id) => { - expect(validateProviderId(id)).toBe(true) - }) - }) - - it('拒绝无效的 IDs', () => { - const invalidIds = [undefined as any, null as any, 123 as any] - invalidIds.forEach((id) => { - expect(validateProviderId(id)).toBe(false) - }) - - // 空字符串和只有空格的字符串会被当作有效的动态 provider ID - expect(validateProviderId('')).toBe(false) - }) - }) - - describe('isBaseProviderId', () => { - it('正确识别基础 provider IDs', () => { - baseProviderIds.forEach((id) => { - expect(isBaseProviderId(id)).toBe(true) - }) - }) - - it('拒绝动态 provider IDs', () => { - const dynamicIds = ['custom-provider', 'my-service'] - dynamicIds.forEach((id) => { - expect(isBaseProviderId(id)).toBe(false) - }) - }) - - it('拒绝无效的 IDs', () => { - const invalidIds = ['', 'invalid', undefined as any] - invalidIds.forEach((id) => { - expect(isBaseProviderId(id)).toBe(false) - }) - }) - }) - - describe('isValidDynamicProviderId', () => { - it('接受有效的动态 provider IDs', () => { - const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2'] - validIds.forEach((id) => { - expect(isValidDynamicProviderId(id)).toBe(true) - }) - }) - - it('拒绝基础 provider IDs', () => { - baseProviderIds.forEach((id) => { - expect(isValidDynamicProviderId(id)).toBe(false) - }) - }) - - it('拒绝无效的 IDs', () => { - const invalidIds = [undefined as any, null as any] - invalidIds.forEach((id) => { - expect(isValidDynamicProviderId(id)).toBe(false) - }) - - // 空字符串会被 schema 拒绝 - expect(isValidDynamicProviderId('')).toBe(false) - // 只有空格的字符串是有效的动态 provider ID(但不推荐使用) - expect(isValidDynamicProviderId(' ')).toBe(true) - }) - }) - - describe('validateProviderConfig', () => { - it('返回有效配置', () => { - const validConfig = { - id: 'openai', - name: 'OpenAI', - creator: vi.fn(), supportsImageGeneration: true } - const result = validateProviderConfig(validConfig) - expect(result).not.toBeNull() - expect(result?.id).toBe('openai') - expect(result?.name).toBe('OpenAI') - }) + expect(providerConfigSchema.safeParse(validConfig).success).toBe(true) - it('对无效配置返回 null', () => { const invalidConfig = { - id: '', - name: 'Invalid' + id: 'openai', // 不允许基础 provider ID + name: 'OpenAI', + creator: vi.fn() } - const result = validateProviderConfig(invalidConfig) - expect(result).toBeNull() - }) - - it('处理完全无效的输入', () => { - const invalidInputs = [undefined, null, 'string', 123, []] - invalidInputs.forEach((input) => { - const result = validateProviderConfig(input) - expect(result).toBeNull() - }) - }) - }) - - describe('validateDynamicProviderRegistration', () => { - it('返回有效的动态 provider 注册配置', () => { - const validConfig = { - id: 'custom-provider', - name: 'Custom Provider', - creator: vi.fn(), - mappings: { model1: 'mapped-model1' } - } - const result = validateDynamicProviderRegistration(validConfig) - expect(result).not.toBeNull() - expect(result?.id).toBe('custom-provider') - expect(result?.name).toBe('Custom Provider') - }) - - it('对无效配置返回 null', () => { - const invalidConfig = { - id: 'openai', - name: 'Should Fail' - } - const result = validateDynamicProviderRegistration(invalidConfig) - expect(result).toBeNull() - }) - }) - - describe('getBaseProviderConfig', () => { - it('返回有效基础 provider ID 的配置', () => { - const config = getBaseProviderConfig('openai') - expect(config).toBeDefined() - expect(config?.id).toBe('openai') - expect(config?.name).toBe('OpenAI') - expect(config?.creator).toBeDefined() - }) - - it('对无效 ID 返回 undefined', () => { - const config = getBaseProviderConfig('invalid' as BaseProviderId) - expect(config).toBeUndefined() - }) - - it('返回所有基础 providers 的配置', () => { - baseProviderIds.forEach((id) => { - const config = getBaseProviderConfig(id) - expect(config).toBeDefined() - expect(config?.id).toBe(id) - }) + expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false) }) }) @@ -378,16 +249,16 @@ describe('Provider Schemas', () => { expect(baseProviderIds).toContain(id) }) - it('DynamicProviderId 类型是字符串', () => { - const id: DynamicProviderId = 'custom-provider' + it('CustomProviderId 类型是字符串', () => { + const id: CustomProviderId = 'custom-provider' expect(typeof id).toBe('string') }) - it('ProviderId 类型支持基础和动态 IDs', () => { + it('ProviderId 类型支持基础和自定义 IDs', () => { const baseId: ProviderId = 'openai' - const dynamicId: ProviderId = 'custom-provider' + const customId: ProviderId = 'custom-provider' expect(typeof baseId).toBe('string') - expect(typeof dynamicId).toBe('string') + expect(typeof customId).toBe('string') }) }) }) diff --git a/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts index 67273fdd59..7ac8637ed3 100644 --- a/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts +++ b/packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts @@ -2,8 +2,8 @@ import { ImageModelV2 } from '@ai-sdk/provider' import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { createImageModel } from '../../models/ModelCreator' import { type AiPlugin } from '../../plugins' +import { globalRegistryManagement } from '../../providers/RegistryManagement' import { ImageGenerationError } from '../errors' import { RuntimeExecutor } from '../executor' @@ -19,8 +19,10 @@ vi.mock('ai', () => ({ } })) -vi.mock('../../models/ModelCreator', () => ({ - createImageModel: vi.fn() +vi.mock('../../providers/RegistryManagement', () => ({ + globalRegistryManagement: { + imageModel: vi.fn() + } })) describe('RuntimeExecutor.generateImage', () => { @@ -67,8 +69,11 @@ describe('RuntimeExecutor.generateImage', () => { } // Setup mocks - vi.mocked(createImageModel).mockResolvedValue(mockImageModel) + vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel) vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult) + + // Reset mock implementation in case it was changed by previous tests + vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => mockImageModel) }) describe('Basic functionality', () => { @@ -77,7 +82,7 @@ describe('RuntimeExecutor.generateImage', () => { prompt: 'A futuristic cityscape at sunset' }) - expect(createImageModel).toHaveBeenCalledWith('openai', 'dall-e-3', { apiKey: 'test-key' }) + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai:dall-e-3') expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, @@ -92,7 +97,7 @@ describe('RuntimeExecutor.generateImage', () => { prompt: 'A beautiful landscape' }) - // Note: createImageModel may still be called due to resolveImageModel logic + // Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic expect(aiGenerateImage).toHaveBeenCalledWith({ model: mockImageModel, prompt: 'A beautiful landscape' @@ -335,8 +340,10 @@ describe('RuntimeExecutor.generateImage', () => { describe('Error handling', () => { it('should handle model creation errors', async () => { - const modelError = new Error('Failed to create image model') - vi.mocked(createImageModel).mockRejectedValue(modelError) + const modelError = new Error('Failed to get image model') + vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => { + throw modelError + }) await expect( executor.generateImage('invalid-model', { @@ -431,7 +438,7 @@ describe('RuntimeExecutor.generateImage', () => { prompt: 'A landscape' }) - expect(createImageModel).toHaveBeenCalledWith('google', 'imagen-3.0-generate-002', { apiKey: 'google-key' }) + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google:imagen-3.0-generate-002') }) it('should support xAI Grok image models', async () => { @@ -443,7 +450,7 @@ describe('RuntimeExecutor.generateImage', () => { prompt: 'A futuristic robot' }) - expect(createImageModel).toHaveBeenCalledWith('xai', 'grok-2-image', { apiKey: 'xai-key' }) + expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai:grok-2-image') }) }) diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index 99e585b5c5..9db95e512c 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -24,57 +24,11 @@ export { createContext, definePlugin, PluginManager } from './core/plugins' // export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in' export { PluginEngine } from './core/runtime/pluginEngine' -// ==================== 类型定义 ==================== -export type { GenerateObjectParams, GenerateTextParams, StreamObjectParams, StreamTextParams } from './types' - // ==================== AI SDK 常用类型导出 ==================== // 直接导出 AI SDK 的常用类型,方便使用 export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider' export type { ToolCall } from '@ai-sdk/provider-utils' export type { ReasoningPart } from '@ai-sdk/provider-utils' -export type { - AssistantModelMessage, - FilePart, - // 通用类型 - FinishReason, - GenerateObjectResult, - // 生成相关类型 - GenerateTextResult, - ImagePart, - InferToolInput, - InferToolOutput, - InvalidToolInputError, - LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage - // 消息相关类型 - ModelMessage, - // 错误类型 - NoSuchToolError, - ProviderMetadata, - StreamTextResult, - SystemModelMessage, - TextPart, - // 流相关类型 - TextStreamPart, - // 工具相关类型 - Tool, - ToolCallPart, - ToolModelMessage, - ToolResultPart, - ToolSet, - TypedToolCall, - TypedToolError, - TypedToolResult, - UserModelMessage -} from 'ai' -export { - defaultSettingsMiddleware, - extractReasoningMiddleware, - simulateStreamingMiddleware, - smoothStream, - stepCountIs -} from 'ai' -// 重新导出 Agent -export { Experimental_Agent as Agent } from 'ai' // ==================== 选项 ==================== export { diff --git a/packages/aiCore/src/types.ts b/packages/aiCore/src/types.ts index c85cefa0a0..d7796a943d 100644 --- a/packages/aiCore/src/types.ts +++ b/packages/aiCore/src/types.ts @@ -1,9 +1,2 @@ -import { generateObject, generateText, streamObject, streamText } from 'ai' - -export type StreamTextParams = Omit[0], 'model'> -export type GenerateTextParams = Omit[0], 'model'> -export type StreamObjectParams = Omit[0], 'model'> -export type GenerateObjectParams = Omit[0], 'model'> - // 重新导出插件类型 export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types' diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts index 825960850c..0430c750d2 100644 --- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -3,10 +3,10 @@ * 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式 */ -import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types' import { Chunk, ChunkType } from '@renderer/types/chunk' +import type { TextStreamPart, ToolSet } from 'ai' import { ToolCallChunkHandler } from './handleToolCallChunk' diff --git a/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts index 78a0a15338..ebe3f4e5c0 100644 --- a/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts +++ b/src/renderer/src/aiCore/chunk/handleToolCallChunk.ts @@ -4,10 +4,10 @@ * 提供工具调用相关的处理API,每个交互使用一个新的实例 */ -import { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types' import { Chunk, ChunkType } from '@renderer/types/chunk' +import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai' // import type { // AnthropicSearchOutput, // WebSearchPluginConfig diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index f95e0b2265..b0fbb6f333 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -8,7 +8,7 @@ * 3. 暂时保持接口兼容性 */ -import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core' +import { createExecutor, generateImage } from '@cherrystudio/ai-core' import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' import { isNotSupportedImageSizeModel } from '@renderer/config/models' @@ -23,6 +23,7 @@ import { CompletionsResult } from './legacy/middleware/schemas' import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder' import { buildPlugins } from './plugins/PluginBuilder' import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor' +import type { StreamTextParams } from './types' const logger = loggerService.withContext('ModernAiProvider') diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts index 1b9454e2de..faef00051f 100644 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -1,11 +1,7 @@ -import { - extractReasoningMiddleware, - LanguageModelV2Middleware, - simulateStreamingMiddleware -} from '@cherrystudio/ai-core' import { loggerService } from '@logger' import type { MCPTool, Model, Provider } from '@renderer/types' import type { Chunk } from '@renderer/types/chunk' +import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai' const logger = loggerService.withContext('AiSdkMiddlewareBuilder') @@ -34,7 +30,7 @@ export interface AiSdkMiddlewareConfig { */ export interface NamedAiSdkMiddleware { name: string - middleware: LanguageModelV2Middleware + middleware: LanguageModelMiddleware } /** @@ -83,7 +79,7 @@ export class AiSdkMiddlewareBuilder { /** * 构建最终的中间件数组 */ - public build(): LanguageModelV2Middleware[] { + public build(): LanguageModelMiddleware[] { return this.middlewares.map((m) => m.middleware) } @@ -114,7 +110,7 @@ export class AiSdkMiddlewareBuilder { * 根据配置构建AI SDK中间件的工厂函数 * 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果 */ -export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV2Middleware[] { +export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] { const builder = new AiSdkMiddlewareBuilder() // 1. 根据provider添加特定中间件 diff --git a/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts b/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts index 8b521abc91..1fe0a177c3 100644 --- a/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts +++ b/src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts @@ -1,4 +1,5 @@ -import { definePlugin, TextStreamPart, ToolSet } from '@cherrystudio/ai-core' +import { definePlugin } from '@cherrystudio/ai-core' +import type { TextStreamPart, ToolSet } from 'ai' export default definePlugin({ name: 'reasoningTimePlugin', diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts index f9f8c91dfc..e57481eda5 100644 --- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -6,8 +6,7 @@ * 2. transformParams: 根据意图分析结果动态添加对应的工具 * 3. onRequestEnd: 自动记忆存储 */ -import type { AiRequestContext, ModelMessage } from '@cherrystudio/ai-core' -import { definePlugin } from '@cherrystudio/ai-core' +import { type AiRequestContext, definePlugin } from '@cherrystudio/ai-core' import { loggerService } from '@logger' // import { generateObject } from '@cherrystudio/ai-core' import { @@ -20,6 +19,7 @@ import store from '@renderer/store' import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory' import type { Assistant } from '@renderer/types' import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract' +import type { ModelMessage } from 'ai' import { isEmpty } from 'lodash' import { MemoryProcessor } from '../../services/MemoryProcessor' diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index 6e5f7b0d66..df791dfbf9 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -3,7 +3,6 @@ * 统一管理从各个 apiClient 提取的参数处理和转换功能 */ -import { stepCountIs, type StreamTextParams } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { @@ -31,8 +30,10 @@ import { } from '@renderer/utils/messageUtils/find' import { defaultTimeout } from '@shared/config/constant' import type { AssistantModelMessage, FilePart, ImagePart, ModelMessage, TextPart, UserModelMessage } from 'ai' +import { stepCountIs } from 'ai' import { getAiSdkProviderId } from './provider/factory' +import type { StreamTextParams } from './types' // import { webSearchTool } from './tools/WebSearchTool' // import { jsonSchemaToZod } from 'json-schema-to-zod' import { setupToolsConfig } from './utils/mcp' diff --git a/src/renderer/src/aiCore/types.ts b/src/renderer/src/aiCore/types.ts new file mode 100644 index 0000000000..14edaf5abb --- /dev/null +++ b/src/renderer/src/aiCore/types.ts @@ -0,0 +1,6 @@ +import { generateObject, generateText, streamObject, streamText } from 'ai' + +export type StreamTextParams = Omit[0], 'model'> +export type GenerateTextParams = Omit[0], 'model'> +export type StreamObjectParams = Omit[0], 'model'> +export type GenerateObjectParams = Omit[0], 'model'> diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index d972f6d104..9c11fa230c 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -1,12 +1,12 @@ /** * 职责:提供原子化的、无状态的API调用函数 */ -import { StreamTextParams } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import AiProvider from '@renderer/aiCore' import { CompletionsParams } from '@renderer/aiCore/legacy/middleware/schemas' import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder' import { buildStreamTextParams } from '@renderer/aiCore/transformParameters' +import type { StreamTextParams } from '@renderer/aiCore/types' import { isDedicatedImageGenerationModel, isEmbeddingModel, isQwenMTModel } from '@renderer/config/models' import { LANG_DETECT_PROMPT } from '@renderer/config/prompts' import { getStoreSetting } from '@renderer/hooks/useSettings' diff --git a/src/renderer/src/services/ConversationService.ts b/src/renderer/src/services/ConversationService.ts index cd2532910c..b119d4abfc 100644 --- a/src/renderer/src/services/ConversationService.ts +++ b/src/renderer/src/services/ConversationService.ts @@ -1,5 +1,5 @@ -import { StreamTextParams } from '@cherrystudio/ai-core' import { convertMessagesToSdkMessages } from '@renderer/aiCore/transformParameters' +import type { StreamTextParams } from '@renderer/aiCore/types' import { Assistant, Message } from '@renderer/types' import { filterAdjacentUserMessaegs, filterLastAssistantMessage } from '@renderer/utils/messageUtils/filters' import { isEmpty, takeRight } from 'lodash'