refactor(aiCore): streamline type exports and enhance provider registration

Removed unused type exports from the aiCore module and consolidated type definitions for better clarity. Updated provider registration tests to reflect new configurations and improved error handling for non-existent providers. Enhanced the overall structure of the provider management system, ensuring better type safety and consistency across the codebase.
This commit is contained in:
MyPrototypeWhat 2025-08-28 15:26:34 +08:00
parent 4b7023f855
commit d10ba04047
15 changed files with 458 additions and 564 deletions

View File

@ -34,35 +34,50 @@ vi.mock('@ai-sdk/xai', () => ({
}))
import {
AiProviderRegistry,
cleanup,
getAllDynamicMappings,
getAllProviders,
getAllValidProviderIds,
getDynamicProviders,
getProvider,
getProviderMapping,
isDynamicProvider,
isProviderSupported,
registerDynamicProvider,
registerMultipleProviders,
clearAllProviders,
createAndRegisterProvider,
createProvider,
getAllProviderConfigAliases,
getAllProviderConfigs,
getInitializedProviders,
getLanguageModel,
getProviderConfig,
getProviderConfigByAlias,
getSupportedProviders,
hasInitializedProviders,
hasProviderConfig,
hasProviderConfigByAlias,
isProviderConfigAlias,
ProviderInitializationError,
providerRegistry,
registerMultipleProviderConfigs,
registerProvider,
validateProviderIdRegistry
registerProviderConfig,
resolveProviderConfigId
} from '../registry'
import type { DynamicProviderRegistration, ProviderConfig } from '../schemas'
import type { ProviderConfig } from '../schemas'
describe('AiProviderRegistry 功能测试', () => {
describe('Provider Registry 功能测试', () => {
beforeEach(() => {
// 清理状态
cleanup()
})
describe('基础功能', () => {
it('能够获取所有 providers', () => {
const providers = getAllProviders()
it('能够获取支持的 providers 列表', () => {
const providers = getSupportedProviders()
expect(Array.isArray(providers)).toBe(true)
expect(providers.length).toBeGreaterThan(0)
// 检查返回的数据结构
providers.forEach((provider) => {
expect(provider).toHaveProperty('id')
expect(provider).toHaveProperty('name')
expect(typeof provider.id).toBe('string')
expect(typeof provider.name).toBe('string')
})
// 包含基础 providers
const providerIds = providers.map((p) => p.id)
expect(providerIds).toContain('openai')
@ -70,74 +85,57 @@ describe('AiProviderRegistry 功能测试', () => {
expect(providerIds).toContain('google')
})
it('能够检查 provider 支持状态', () => {
expect(isProviderSupported('openai')).toBe(true)
expect(isProviderSupported('anthropic')).toBe(true)
expect(isProviderSupported('google')).toBe(true)
expect(isProviderSupported('non-existent')).toBe(true) // validateProviderId 通过
expect(isProviderSupported('')).toBe(false)
it('能够获取已初始化的 providers', () => {
// 初始状态下没有已初始化的 providers
expect(getInitializedProviders()).toEqual([])
expect(hasInitializedProviders()).toBe(false)
})
it('能够获取有效的 provider IDs', () => {
const allIds = getAllValidProviderIds()
expect(Array.isArray(allIds)).toBe(true)
expect(allIds).toContain('openai')
expect(allIds).toContain('anthropic')
it('能够访问全局注册管理器', () => {
expect(providerRegistry).toBeDefined()
expect(typeof providerRegistry.clear).toBe('function')
expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
expect(typeof providerRegistry.hasProviders).toBe('function')
})
it('能够根据 ID 获取特定的 provider', () => {
// 获取存在的 provider
const openaiProvider = getProvider('openai')
expect(openaiProvider).toBeDefined()
expect(openaiProvider?.id).toBe('openai')
expect(openaiProvider?.name).toBe('OpenAI')
// 获取不存在的 providerfallback到openai-compatible
const nonExistentProvider = getProvider('non-existent')
expect(nonExistentProvider).toBeDefined()
expect(nonExistentProvider?.id).toBe('openai-compatible')
})
it('能够验证 provider ID', () => {
expect(validateProviderIdRegistry('valid-id')).toBe(true)
expect(validateProviderIdRegistry('another-valid-id')).toBe(true)
expect(validateProviderIdRegistry('')).toBe(false)
// 注意:单个空格字符被认为是有效的,因为它不是空字符串
// 如果需要更严格的验证schemas 包含更多验证规则
expect(validateProviderIdRegistry(' ')).toBe(true)
it('能够获取语言模型', () => {
// 在没有注册 provider 的情况下,这个函数可能会抛出错误或返回 undefined
expect(() => getLanguageModel('non-existent')).not.toThrow()
})
})
describe('动态 Provider 注册', () => {
it('能够注册动态 provider', () => {
const config = {
describe('Provider 配置注册', () => {
it('能够注册自定义 provider 配置', () => {
const config: ProviderConfig = {
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(() => ({ name: 'custom' })),
supportsImageGeneration: false
}
const success = registerDynamicProvider(config)
const success = registerProviderConfig(config)
expect(success).toBe(true)
expect(isDynamicProvider('custom-provider')).toBe(true)
expect(isProviderSupported('custom-provider')).toBe(true)
const allIds = getAllValidProviderIds()
expect(allIds).toContain('custom-provider')
expect(hasProviderConfig('custom-provider')).toBe(true)
expect(getProviderConfig('custom-provider')).toEqual(config)
})
it('拒绝与基础 provider 冲突的配置', () => {
const config = {
id: 'openai',
name: 'Duplicate OpenAI',
creator: vi.fn(() => ({ name: 'duplicate' })),
supportsImageGeneration: false
it('能够注册带别名的 provider 配置', () => {
const config: ProviderConfig = {
id: 'custom-provider-with-aliases',
name: 'Custom Provider with Aliases',
creator: vi.fn(() => ({ name: 'custom-aliased' })),
supportsImageGeneration: false,
aliases: ['alias-1', 'alias-2']
}
const success = registerDynamicProvider(config)
expect(success).toBe(false)
expect(isDynamicProvider('openai')).toBe(false)
const success = registerProviderConfig(config)
expect(success).toBe(true)
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
expect(getProviderConfigByAlias('alias-1')).toEqual(config)
expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
})
it('拒绝无效的配置', () => {
@ -147,12 +145,12 @@ describe('AiProviderRegistry 功能测试', () => {
// 缺少 name, creator 等
}
const success = registerDynamicProvider(invalidConfig as any)
const success = registerProviderConfig(invalidConfig as any)
expect(success).toBe(false)
})
it('能够批量注册动态 providers', () => {
const configs: DynamicProviderRegistration[] = [
it('能够批量注册 provider 配置', () => {
const configs: ProviderConfig[] = [
{
id: 'provider-1',
name: 'Provider 1',
@ -166,146 +164,146 @@ describe('AiProviderRegistry 功能测试', () => {
supportsImageGeneration: true
},
{
id: 'openai', // 这个失败,因为与基础 provider 冲突
id: '', // 无效配置
name: 'Invalid Provider',
creator: vi.fn(() => ({ name: 'invalid' })),
supportsImageGeneration: false
}
} as any
]
const successCount = registerMultipleProviders(configs)
const successCount = registerMultipleProviderConfigs(configs)
expect(successCount).toBe(2) // 只有前两个成功
expect(isDynamicProvider('provider-1')).toBe(true)
expect(isDynamicProvider('provider-2')).toBe(true)
expect(isDynamicProvider('openai')).toBe(false) // 基础 provider不是动态的
expect(hasProviderConfig('provider-1')).toBe(true)
expect(hasProviderConfig('provider-2')).toBe(true)
expect(hasProviderConfig('')).toBe(false)
})
it('支持带映射关系的动态 provider', () => {
const configWithMappings: DynamicProviderRegistration = {
id: 'custom-provider-with-mappings',
name: 'Custom Provider with Mappings',
creator: vi.fn(() => ({ name: 'custom-mapped' })),
it('能够获取所有配置和别名信息', () => {
// 注册一些配置
registerProviderConfig({
id: 'test-provider',
name: 'Test Provider',
creator: vi.fn(),
supportsImageGeneration: false,
mappings: {
'custom-alias-1': 'custom-provider-with-mappings',
'custom-alias-2': 'custom-provider-with-mappings'
}
aliases: ['test-alias']
})
const allConfigs = getAllProviderConfigs()
expect(Array.isArray(allConfigs)).toBe(true)
expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
const aliases = getAllProviderConfigAliases()
expect(aliases['test-alias']).toBe('test-provider')
expect(isProviderConfigAlias('test-alias')).toBe(true)
})
})
describe('Provider 创建和注册', () => {
it('能够创建 provider 实例', async () => {
const config: ProviderConfig = {
id: 'test-create-provider',
name: 'Test Create Provider',
creator: vi.fn(() => ({ name: 'test-created' })),
supportsImageGeneration: false
}
const success = registerDynamicProvider(configWithMappings)
// 先注册配置
registerProviderConfig(config)
// 创建 provider 实例
const provider = await createProvider('test-create-provider', { apiKey: 'test' })
expect(provider).toBeDefined()
expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
})
it('能够注册 provider 到全局管理器', () => {
const mockProvider = { name: 'mock-provider' }
const config: ProviderConfig = {
id: 'test-register-provider',
name: 'Test Register Provider',
creator: vi.fn(() => mockProvider),
supportsImageGeneration: false
}
// 先注册配置
registerProviderConfig(config)
// 注册 provider 到全局管理器
const success = registerProvider('test-register-provider', mockProvider)
expect(success).toBe(true)
// 验证映射关系
expect(getProviderMapping('custom-alias-1')).toBe('custom-provider-with-mappings')
expect(getProviderMapping('custom-alias-2')).toBe('custom-provider-with-mappings')
expect(getProviderMapping('custom-provider-with-mappings')).toBe('custom-provider-with-mappings')
// 验证注册成功
const registeredProviders = getInitializedProviders()
expect(registeredProviders).toContain('test-register-provider')
expect(hasInitializedProviders()).toBe(true)
})
it('能够一步完成创建和注册', async () => {
const config: ProviderConfig = {
id: 'test-create-and-register',
name: 'Test Create and Register',
creator: vi.fn(() => ({ name: 'test-both' })),
supportsImageGeneration: false
}
// 先注册配置
registerProviderConfig(config)
// 一步完成创建和注册
const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
expect(success).toBe(true)
// 验证注册成功
const registeredProviders = getInitializedProviders()
expect(registeredProviders).toContain('test-create-and-register')
})
})
describe('Registry 管理', () => {
it('能够清理动态 providers', () => {
// 注册动态 provider
registerDynamicProvider({
it('能够清理所有配置和注册的 providers', () => {
// 注册一些配置
registerProviderConfig({
id: 'temp-provider',
name: 'Temp Provider',
creator: vi.fn(() => ({ name: 'temp' })),
supportsImageGeneration: false
})
expect(isDynamicProvider('temp-provider')).toBe(true)
expect(hasProviderConfig('temp-provider')).toBe(true)
// 清理
cleanup()
expect(isDynamicProvider('temp-provider')).toBe(false)
expect(isProviderSupported('openai')).toBe(true) // 基础 providers 仍存在
expect(hasProviderConfig('temp-provider')).toBe(false)
// 但基础配置应该重新加载
expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化
})
it('保持单例模式', () => {
const instance1 = AiProviderRegistry.getInstance()
const instance2 = AiProviderRegistry.getInstance()
expect(instance1).toBe(instance2)
it('能够单独清理已注册的 providers', () => {
// 清理所有 providers
clearAllProviders()
expect(getInitializedProviders()).toEqual([])
expect(hasInitializedProviders()).toBe(false)
})
it('能够注册基础 provider', () => {
const customConfig: ProviderConfig = {
id: 'custom-base-provider',
name: 'Custom Base Provider',
creator: vi.fn(() => ({ name: 'custom-base' })),
supportsImageGeneration: false
}
// 注册基础 provider 不抛出错误
expect(() => registerProvider(customConfig)).not.toThrow()
// 验证注册成功
const registeredProvider = getProvider('custom-base-provider')
expect(registeredProvider).toBeDefined()
expect(registeredProvider?.id).toBe('custom-base-provider')
expect(registeredProvider?.name).toBe('Custom Base Provider')
})
it('能够获取动态 providers 列表', () => {
// 初始状态没有动态 providers
expect(getDynamicProviders()).toEqual([])
// 注册一些动态 providers
registerDynamicProvider({
id: 'dynamic-1',
name: 'Dynamic 1',
creator: vi.fn(() => ({ name: 'dynamic-1' })),
supportsImageGeneration: false
})
registerDynamicProvider({
id: 'dynamic-2',
name: 'Dynamic 2',
creator: vi.fn(() => ({ name: 'dynamic-2' })),
supportsImageGeneration: true
})
const dynamicProviders = getDynamicProviders()
expect(Array.isArray(dynamicProviders)).toBe(true)
expect(dynamicProviders).toContain('dynamic-1')
expect(dynamicProviders).toContain('dynamic-2')
expect(dynamicProviders.length).toBe(2)
})
it('能够获取所有动态映射', () => {
// 初始状态没有动态映射
expect(getAllDynamicMappings()).toEqual({})
// 注册带映射的动态 provider
registerDynamicProvider({
id: 'mapped-provider',
name: 'Mapped Provider',
creator: vi.fn(() => ({ name: 'mapped' })),
supportsImageGeneration: false,
mappings: {
'alias-1': 'mapped-provider',
'alias-2': 'mapped-provider',
'custom-name': 'mapped-provider'
}
})
const allMappings = getAllDynamicMappings()
expect(allMappings).toEqual({
'alias-1': 'mapped-provider',
'alias-2': 'mapped-provider',
'custom-name': 'mapped-provider'
})
it('ProviderInitializationError 错误类工作正常', () => {
const error = new ProviderInitializationError('Test error', 'test-provider')
expect(error.message).toBe('Test error')
expect(error.providerId).toBe('test-provider')
expect(error.name).toBe('ProviderInitializationError')
})
})
describe('错误处理', () => {
it('优雅处理空配置', () => {
const success = registerDynamicProvider(null as any)
const success = registerProviderConfig(null as any)
expect(success).toBe(false)
})
it('优雅处理未定义配置', () => {
const success = registerDynamicProvider(undefined as any)
const success = registerProviderConfig(undefined as any)
expect(success).toBe(false)
})
@ -317,27 +315,31 @@ describe('AiProviderRegistry 功能测试', () => {
supportsImageGeneration: false
}
const success = registerDynamicProvider(config)
const success = registerProviderConfig(config)
expect(success).toBe(false)
})
it('处理注册基础 provider 时的无效 ID', () => {
const invalidConfig: ProviderConfig = {
id: '', // 无效 ID
name: 'Invalid Provider',
creator: vi.fn(() => ({ name: 'invalid' })),
supportsImageGeneration: false
}
expect(() => registerProvider(invalidConfig)).toThrow('Invalid provider ID:')
it('处理创建不存在配置的 provider', async () => {
await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
'ProviderConfig not found for id: non-existent-provider'
)
})
it('处理获取不存在映射时的情况', () => {
expect(getProviderMapping('non-existent-mapping')).toBeUndefined()
it('处理注册不存在配置的 provider', () => {
const mockProvider = { name: 'mock' }
const success = registerProvider('non-existent-provider', mockProvider)
expect(success).toBe(false)
})
it('处理获取不存在配置的情况', () => {
expect(getProviderConfig('non-existent')).toBeUndefined()
expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
expect(hasProviderConfig('non-existent')).toBe(false)
expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
})
it('处理批量注册时的部分失败', () => {
const mixedConfigs: DynamicProviderRegistration[] = [
const mixedConfigs: ProviderConfig[] = [
{
id: 'valid-provider-1',
name: 'Valid Provider 1',
@ -349,7 +351,7 @@ describe('AiProviderRegistry 功能测试', () => {
name: 'Invalid Provider',
creator: vi.fn(() => ({ name: 'invalid' })),
supportsImageGeneration: false
},
} as any,
{
id: 'valid-provider-2',
name: 'Valid Provider 2',
@ -358,153 +360,168 @@ describe('AiProviderRegistry 功能测试', () => {
}
]
const successCount = registerMultipleProviders(mixedConfigs)
const successCount = registerMultipleProviderConfigs(mixedConfigs)
expect(successCount).toBe(2) // 只有两个有效配置成功
expect(isDynamicProvider('valid-provider-1')).toBe(true)
expect(isDynamicProvider('valid-provider-2')).toBe(true)
expect(getDynamicProviders()).not.toContain('')
expect(hasProviderConfig('valid-provider-1')).toBe(true)
expect(hasProviderConfig('valid-provider-2')).toBe(true)
expect(hasProviderConfig('')).toBe(false)
})
it('处理动态导入失败的情况', async () => {
const config: ProviderConfig = {
id: 'import-test-provider',
name: 'Import Test Provider',
import: vi.fn().mockRejectedValue(new Error('Import failed')),
creatorFunctionName: 'createTest',
supportsImageGeneration: false
}
registerProviderConfig(config)
await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
})
})
describe('集成测试', () => {
it('正确处理复杂的注册、映射和清理场景', () => {
it('正确处理复杂的配置、创建、注册和清理场景', async () => {
// 初始状态验证
const initialProviders = getAllProviders()
const initialIds = getAllValidProviderIds()
expect(initialProviders.length).toBeGreaterThan(0)
expect(getDynamicProviders()).toEqual([])
expect(getAllDynamicMappings()).toEqual({})
const initialConfigs = getAllProviderConfigs()
expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置
expect(getInitializedProviders()).toEqual([])
// 注册多个带映射的动态 providers
const configs: DynamicProviderRegistration[] = [
// 注册多个带别名的 provider 配置
const configs: ProviderConfig[] = [
{
id: 'integration-provider-1',
name: 'Integration Provider 1',
creator: vi.fn(() => ({ name: 'integration-1' })),
supportsImageGeneration: false,
mappings: {
'alias-1': 'integration-provider-1',
'short-name-1': 'integration-provider-1'
}
aliases: ['alias-1', 'short-name-1']
},
{
id: 'integration-provider-2',
name: 'Integration Provider 2',
creator: vi.fn(() => ({ name: 'integration-2' })),
supportsImageGeneration: true,
mappings: {
'alias-2': 'integration-provider-2',
'short-name-2': 'integration-provider-2'
}
aliases: ['alias-2', 'short-name-2']
}
]
const successCount = registerMultipleProviders(configs)
const successCount = registerMultipleProviderConfigs(configs)
expect(successCount).toBe(2)
// 验证注册后的状态
const afterRegisterProviders = getAllProviders()
const afterRegisterIds = getAllValidProviderIds()
expect(afterRegisterProviders.length).toBe(initialProviders.length + 2)
expect(afterRegisterIds.length).toBeGreaterThanOrEqual(initialIds.length + 2)
// 验证配置注册成功
expect(hasProviderConfig('integration-provider-1')).toBe(true)
expect(hasProviderConfig('integration-provider-2')).toBe(true)
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
// 验证动态 providers
const dynamicProviders = getDynamicProviders()
expect(dynamicProviders).toContain('integration-provider-1')
expect(dynamicProviders).toContain('integration-provider-2')
// 验证别名映射
const aliases = getAllProviderConfigAliases()
expect(aliases['alias-1']).toBe('integration-provider-1')
expect(aliases['alias-2']).toBe('integration-provider-2')
// 验证映射
const mappings = getAllDynamicMappings()
expect(mappings['alias-1']).toBe('integration-provider-1')
expect(mappings['alias-2']).toBe('integration-provider-2')
expect(mappings['short-name-1']).toBe('integration-provider-1')
expect(mappings['short-name-2']).toBe('integration-provider-2')
// 创建和注册 providers
const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
expect(success1).toBe(true)
expect(success2).toBe(true)
// 验证通过映射能够获取 provider
expect(getProviderMapping('alias-1')).toBe('integration-provider-1')
expect(getProviderMapping('integration-provider-1')).toBe('integration-provider-1')
// 验证注册成功
const registeredProviders = getInitializedProviders()
expect(registeredProviders).toContain('integration-provider-1')
expect(registeredProviders).toContain('integration-provider-2')
expect(hasInitializedProviders()).toBe(true)
// 清理
cleanup()
// 验证清理后的状态
const afterCleanupProviders = getAllProviders()
const afterCleanupIds = getAllValidProviderIds()
expect(afterCleanupProviders.length).toBe(initialProviders.length)
expect(afterCleanupIds.length).toBe(initialIds.length)
expect(getDynamicProviders()).toEqual([])
expect(getAllDynamicMappings()).toEqual({})
expect(getInitializedProviders()).toEqual([])
expect(hasProviderConfig('integration-provider-1')).toBe(false)
expect(hasProviderConfig('integration-provider-2')).toBe(false)
expect(getAllProviderConfigAliases()).toEqual({})
// 基础配置应该重新加载
expect(hasProviderConfig('openai')).toBe(true)
})
it('正确处理 provider 的优先级和 fallback 机制', () => {
// 验证 getProvider 的 fallback 机制
const existingProvider = getProvider('openai')
expect(existingProvider?.id).toBe('openai')
const nonExistentProvider = getProvider('definitely-non-existent')
expect(nonExistentProvider?.id).toBe('openai-compatible') // fallback
// 注册自定义 provider 后能直接获取
registerDynamicProvider({
id: 'priority-test-provider',
name: 'Priority Test Provider',
creator: vi.fn(() => ({ name: 'priority-test' })),
it('正确处理动态导入配置的 provider', async () => {
const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
const dynamicImportConfig: ProviderConfig = {
id: 'dynamic-import-provider',
name: 'Dynamic Import Provider',
import: vi.fn().mockResolvedValue(mockModule),
creatorFunctionName: 'createCustomProvider',
supportsImageGeneration: false
})
}
const customProvider = getProvider('priority-test-provider')
expect(customProvider?.id).toBe('priority-test-provider')
expect(customProvider?.name).toBe('Priority Test Provider')
// 注册配置
const configSuccess = registerProviderConfig(dynamicImportConfig)
expect(configSuccess).toBe(true)
// 创建和注册 provider
const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
expect(registerSuccess).toBe(true)
// 验证导入函数被调用
expect(dynamicImportConfig.import).toHaveBeenCalled()
expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
// 验证注册成功
expect(getInitializedProviders()).toContain('dynamic-import-provider')
})
it('正确处理大量动态 providers 的注册和管理', () => {
const largeConfigList: DynamicProviderRegistration[] = []
it('正确处理大量配置的注册和管理', () => {
const largeConfigList: ProviderConfig[] = []
// 生成100个动态 providers
for (let i = 0; i < 100; i++) {
// 生成50个配置
for (let i = 0; i < 50; i++) {
largeConfigList.push({
id: `bulk-provider-${i}`,
name: `Bulk Provider ${i}`,
creator: vi.fn(() => ({ name: `bulk-${i}` })),
supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成
mappings: {
[`alias-${i}`]: `bulk-provider-${i}`,
[`short-${i}`]: `bulk-provider-${i}`
}
aliases: [`alias-${i}`, `short-${i}`]
})
}
const successCount = registerMultipleProviders(largeConfigList)
expect(successCount).toBe(100)
const successCount = registerMultipleProviderConfigs(largeConfigList)
expect(successCount).toBe(50)
// 验证所有 providers 都被正确注册
const dynamicProviders = getDynamicProviders()
expect(dynamicProviders.length).toBe(100)
// 验证所有配置都被正确注册
const allConfigs = getAllProviderConfigs()
expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
// 验证映射数量
const mappings = getAllDynamicMappings()
expect(Object.keys(mappings).length).toBe(200) // 每个 provider 有2个映射
// 验证别名数量
const aliases = getAllProviderConfigAliases()
const bulkAliases = Object.keys(aliases).filter(
(alias) => alias.startsWith('alias-') || alias.startsWith('short-')
)
expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名
// 随机验证几个 providers
expect(isDynamicProvider('bulk-provider-0')).toBe(true)
expect(isDynamicProvider('bulk-provider-50')).toBe(true)
expect(isDynamicProvider('bulk-provider-99')).toBe(true)
// 随机验证几个配置
expect(hasProviderConfig('bulk-provider-0')).toBe(true)
expect(hasProviderConfig('bulk-provider-25')).toBe(true)
expect(hasProviderConfig('bulk-provider-49')).toBe(true)
// 验证映射工作正常
expect(getProviderMapping('alias-25')).toBe('bulk-provider-25')
expect(getProviderMapping('short-75')).toBe('bulk-provider-75')
// 验证别名工作正常
expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
expect(isProviderConfigAlias('short-30')).toBe(true)
// 清理能正确处理大量数据
cleanup()
expect(getDynamicProviders()).toEqual([])
expect(getAllDynamicMappings()).toEqual({})
const cleanupAliases = getAllProviderConfigAliases()
expect(
Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
).toEqual([])
})
})
describe('边界测试', () => {
it('处理包含特殊字符的 provider IDs', () => {
const specialCharsConfigs: DynamicProviderRegistration[] = [
const specialCharsConfigs: ProviderConfig[] = [
{
id: 'provider-with-dashes',
name: 'Provider With Dashes',
@ -525,26 +542,29 @@ describe('AiProviderRegistry 功能测试', () => {
}
]
const successCount = registerMultipleProviders(specialCharsConfigs)
const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
expect(successCount).toBeGreaterThan(0) // 至少有一些成功
// 验证支持的特殊字符格式
if (isDynamicProvider('provider-with-dashes')) {
expect(getProvider('provider-with-dashes')).toBeDefined()
if (hasProviderConfig('provider-with-dashes')) {
expect(getProviderConfig('provider-with-dashes')).toBeDefined()
}
if (isDynamicProvider('provider_with_underscores')) {
expect(getProvider('provider_with_underscores')).toBeDefined()
if (hasProviderConfig('provider_with_underscores')) {
expect(getProviderConfig('provider_with_underscores')).toBeDefined()
}
})
it('处理空的批量注册', () => {
const successCount = registerMultipleProviders([])
const successCount = registerMultipleProviderConfigs([])
expect(successCount).toBe(0)
expect(getDynamicProviders()).toEqual([])
// 确保没有额外的配置被添加
const configsBefore = getAllProviderConfigs().length
expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置
})
it('处理重复的 provider 注册', () => {
const config: DynamicProviderRegistration = {
it('处理重复的配置注册', () => {
const config: ProviderConfig = {
id: 'duplicate-test-provider',
name: 'Duplicate Test Provider',
creator: vi.fn(() => ({ name: 'duplicate' })),
@ -552,17 +572,61 @@ describe('AiProviderRegistry 功能测试', () => {
}
// 第一次注册成功
expect(registerDynamicProvider(config)).toBe(true)
expect(isDynamicProvider('duplicate-test-provider')).toBe(true)
expect(registerProviderConfig(config)).toBe(true)
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
// 重复注册相同的 provider
expect(registerDynamicProvider(config)).toBe(true) // 允许覆盖
expect(isDynamicProvider('duplicate-test-provider')).toBe(true)
// 重复注册相同的配置(允许覆盖)
const updatedConfig: ProviderConfig = {
...config,
name: 'Updated Duplicate Test Provider'
}
expect(registerProviderConfig(updatedConfig)).toBe(true)
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
// 验证只有一个实例
const dynamicProviders = getDynamicProviders()
const duplicateCount = dynamicProviders.filter((id) => id === 'duplicate-test-provider').length
expect(duplicateCount).toBe(1)
// 验证配置被更新
const retrievedConfig = getProviderConfig('duplicate-test-provider')
expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
})
it('处理极长的 ID 和名称', () => {
const longId = 'very-long-provider-id-' + 'x'.repeat(100)
const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
const config: ProviderConfig = {
id: longId,
name: longName,
creator: vi.fn(() => ({ name: 'long-test' })),
supportsImageGeneration: false
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
expect(hasProviderConfig(longId)).toBe(true)
const retrievedConfig = getProviderConfig(longId)
expect(retrievedConfig?.name).toBe(longName)
})
it('处理大量别名的配置', () => {
const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
const config: ProviderConfig = {
id: 'provider-with-many-aliases',
name: 'Provider With Many Aliases',
creator: vi.fn(() => ({ name: 'many-aliases' })),
supportsImageGeneration: false,
aliases: manyAliases
}
const success = registerProviderConfig(config)
expect(success).toBe(true)
// 验证所有别名都能正确解析
manyAliases.forEach((alias) => {
expect(hasProviderConfigByAlias(alias)).toBe(true)
expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
expect(isProviderConfigAlias(alias)).toBe(true)
})
})
})
})

View File

@ -5,18 +5,11 @@ import {
baseProviderIds,
baseProviderIdSchema,
baseProviders,
type DynamicProviderId,
dynamicProviderIdSchema,
dynamicProviderRegistrationSchema,
getBaseProviderConfig,
isBaseProviderId,
isValidDynamicProviderId,
type CustomProviderId,
customProviderIdSchema,
providerConfigSchema,
type ProviderId,
providerIdSchema,
validateDynamicProviderRegistration,
validateProviderConfig,
validateProviderId
providerIdSchema
} from '../schemas'
describe('Provider Schemas', () => {
@ -90,22 +83,22 @@ describe('Provider Schemas', () => {
})
})
describe('dynamicProviderIdSchema', () => {
it('接受有效的动态 provider IDs', () => {
describe('customProviderIdSchema', () => {
it('接受有效的自定义 provider IDs', () => {
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
validIds.forEach((id) => {
expect(dynamicProviderIdSchema.safeParse(id).success).toBe(true)
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
})
})
it('拒绝与基础 provider IDs 冲突的 IDs', () => {
baseProviderIds.forEach((id) => {
expect(dynamicProviderIdSchema.safeParse(id).success).toBe(false)
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
})
})
it('拒绝空字符串', () => {
expect(dynamicProviderIdSchema.safeParse('').success).toBe(false)
expect(customProviderIdSchema.safeParse('').success).toBe(false)
})
})
@ -116,9 +109,9 @@ describe('Provider Schemas', () => {
})
})
it('接受有效的动态 provider IDs', () => {
const validDynamicIds = ['custom-provider', 'my-ai-service']
validDynamicIds.forEach((id) => {
it('接受有效的自定义 provider IDs', () => {
const validCustomIds = ['custom-provider', 'my-ai-service']
validCustomIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
})
@ -134,8 +127,8 @@ describe('Provider Schemas', () => {
describe('providerConfigSchema', () => {
it('验证带有 creator 的有效配置', () => {
const validConfig = {
id: 'openai',
name: 'OpenAI',
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(),
supportsImageGeneration: true
}
@ -175,12 +168,21 @@ describe('Provider Schemas', () => {
}
})
it('拒绝使用基础 provider ID 的配置', () => {
const invalidConfig = {
id: 'openai', // 基础 provider ID
name: 'Should Fail',
creator: vi.fn()
}
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
})
it('拒绝缺少必需字段的配置', () => {
const invalidConfigs = [
{ name: 'Missing ID', creator: vi.fn() },
{ id: 'missing-name', creator: vi.fn() },
{ id: '', name: 'Empty ID', creator: vi.fn() },
{ id: 'valid', name: '', creator: vi.fn() }
{ id: 'valid-custom', name: '', creator: vi.fn() }
]
invalidConfigs.forEach((config) => {
@ -189,186 +191,55 @@ describe('Provider Schemas', () => {
})
})
describe('dynamicProviderRegistrationSchema', () => {
it('验证有效的动态 provider 注册配置', () => {
describe('Schema 验证功能', () => {
it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
})
expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
})
it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
const customIds = ['custom-provider', 'my-service', 'company-llm']
customIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
})
// 拒绝基础 provider IDs
baseProviderIds.forEach((id) => {
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
})
})
it('providerIdSchema 接受基础和自定义 provider IDs', () => {
// 基础 IDs
baseProviderIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
// 自定义 IDs
const customIds = ['custom-provider', 'my-service']
customIds.forEach((id) => {
expect(providerIdSchema.safeParse(id).success).toBe(true)
})
})
it('providerConfigSchema 验证完整的 provider 配置', () => {
const validConfig = {
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(),
supportsImageGeneration: true,
mappings: { model1: 'mapped-model1' }
}
expect(dynamicProviderRegistrationSchema.safeParse(validConfig).success).toBe(true)
})
it('拒绝使用基础 provider ID 的配置', () => {
const invalidConfig = {
id: 'openai',
name: 'Should Fail',
creator: vi.fn()
}
expect(dynamicProviderRegistrationSchema.safeParse(invalidConfig).success).toBe(false)
})
it('要求 creator 或 import 配置', () => {
const configWithoutCreator = {
id: 'custom-provider',
name: 'Custom Provider'
}
expect(dynamicProviderRegistrationSchema.safeParse(configWithoutCreator).success).toBe(false)
})
})
describe('validateProviderId', () => {
it('验证基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(validateProviderId(id)).toBe(true)
})
})
it('验证有效的动态 provider IDs', () => {
const validDynamicIds = ['custom-provider', 'my-service', 'company-llm']
validDynamicIds.forEach((id) => {
expect(validateProviderId(id)).toBe(true)
})
})
it('拒绝无效的 IDs', () => {
const invalidIds = [undefined as any, null as any, 123 as any]
invalidIds.forEach((id) => {
expect(validateProviderId(id)).toBe(false)
})
// 空字符串和只有空格的字符串会被当作有效的动态 provider ID
expect(validateProviderId('')).toBe(false)
})
})
describe('isBaseProviderId', () => {
it('正确识别基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(isBaseProviderId(id)).toBe(true)
})
})
it('拒绝动态 provider IDs', () => {
const dynamicIds = ['custom-provider', 'my-service']
dynamicIds.forEach((id) => {
expect(isBaseProviderId(id)).toBe(false)
})
})
it('拒绝无效的 IDs', () => {
const invalidIds = ['', 'invalid', undefined as any]
invalidIds.forEach((id) => {
expect(isBaseProviderId(id)).toBe(false)
})
})
})
describe('isValidDynamicProviderId', () => {
it('接受有效的动态 provider IDs', () => {
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
validIds.forEach((id) => {
expect(isValidDynamicProviderId(id)).toBe(true)
})
})
it('拒绝基础 provider IDs', () => {
baseProviderIds.forEach((id) => {
expect(isValidDynamicProviderId(id)).toBe(false)
})
})
it('拒绝无效的 IDs', () => {
const invalidIds = [undefined as any, null as any]
invalidIds.forEach((id) => {
expect(isValidDynamicProviderId(id)).toBe(false)
})
// 空字符串会被 schema 拒绝
expect(isValidDynamicProviderId('')).toBe(false)
// 只有空格的字符串是有效的动态 provider ID但不推荐使用
expect(isValidDynamicProviderId(' ')).toBe(true)
})
})
describe('validateProviderConfig', () => {
it('返回有效配置', () => {
const validConfig = {
id: 'openai',
name: 'OpenAI',
creator: vi.fn(),
supportsImageGeneration: true
}
const result = validateProviderConfig(validConfig)
expect(result).not.toBeNull()
expect(result?.id).toBe('openai')
expect(result?.name).toBe('OpenAI')
})
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
it('对无效配置返回 null', () => {
const invalidConfig = {
id: '',
name: 'Invalid'
id: 'openai', // 不允许基础 provider ID
name: 'OpenAI',
creator: vi.fn()
}
const result = validateProviderConfig(invalidConfig)
expect(result).toBeNull()
})
it('处理完全无效的输入', () => {
const invalidInputs = [undefined, null, 'string', 123, []]
invalidInputs.forEach((input) => {
const result = validateProviderConfig(input)
expect(result).toBeNull()
})
})
})
describe('validateDynamicProviderRegistration', () => {
it('返回有效的动态 provider 注册配置', () => {
const validConfig = {
id: 'custom-provider',
name: 'Custom Provider',
creator: vi.fn(),
mappings: { model1: 'mapped-model1' }
}
const result = validateDynamicProviderRegistration(validConfig)
expect(result).not.toBeNull()
expect(result?.id).toBe('custom-provider')
expect(result?.name).toBe('Custom Provider')
})
it('对无效配置返回 null', () => {
const invalidConfig = {
id: 'openai',
name: 'Should Fail'
}
const result = validateDynamicProviderRegistration(invalidConfig)
expect(result).toBeNull()
})
})
describe('getBaseProviderConfig', () => {
it('返回有效基础 provider ID 的配置', () => {
const config = getBaseProviderConfig('openai')
expect(config).toBeDefined()
expect(config?.id).toBe('openai')
expect(config?.name).toBe('OpenAI')
expect(config?.creator).toBeDefined()
})
it('对无效 ID 返回 undefined', () => {
const config = getBaseProviderConfig('invalid' as BaseProviderId)
expect(config).toBeUndefined()
})
it('返回所有基础 providers 的配置', () => {
baseProviderIds.forEach((id) => {
const config = getBaseProviderConfig(id)
expect(config).toBeDefined()
expect(config?.id).toBe(id)
})
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
})
})
@ -378,16 +249,16 @@ describe('Provider Schemas', () => {
expect(baseProviderIds).toContain(id)
})
it('DynamicProviderId 类型是字符串', () => {
const id: DynamicProviderId = 'custom-provider'
it('CustomProviderId 类型是字符串', () => {
const id: CustomProviderId = 'custom-provider'
expect(typeof id).toBe('string')
})
it('ProviderId 类型支持基础和动态 IDs', () => {
it('ProviderId 类型支持基础和自定义 IDs', () => {
const baseId: ProviderId = 'openai'
const dynamicId: ProviderId = 'custom-provider'
const customId: ProviderId = 'custom-provider'
expect(typeof baseId).toBe('string')
expect(typeof dynamicId).toBe('string')
expect(typeof customId).toBe('string')
})
})
})

View File

@ -2,8 +2,8 @@ import { ImageModelV2 } from '@ai-sdk/provider'
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createImageModel } from '../../models/ModelCreator'
import { type AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { ImageGenerationError } from '../errors'
import { RuntimeExecutor } from '../executor'
@ -19,8 +19,10 @@ vi.mock('ai', () => ({
}
}))
vi.mock('../../models/ModelCreator', () => ({
createImageModel: vi.fn()
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
imageModel: vi.fn()
}
}))
describe('RuntimeExecutor.generateImage', () => {
@ -67,8 +69,11 @@ describe('RuntimeExecutor.generateImage', () => {
}
// Setup mocks
vi.mocked(createImageModel).mockResolvedValue(mockImageModel)
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
// Reset mock implementation in case it was changed by previous tests
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => mockImageModel)
})
describe('Basic functionality', () => {
@ -77,7 +82,7 @@ describe('RuntimeExecutor.generateImage', () => {
prompt: 'A futuristic cityscape at sunset'
})
expect(createImageModel).toHaveBeenCalledWith('openai', 'dall-e-3', { apiKey: 'test-key' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai:dall-e-3')
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
@ -92,7 +97,7 @@ describe('RuntimeExecutor.generateImage', () => {
prompt: 'A beautiful landscape'
})
// Note: createImageModel may still be called due to resolveImageModel logic
// Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A beautiful landscape'
@ -335,8 +340,10 @@ describe('RuntimeExecutor.generateImage', () => {
describe('Error handling', () => {
it('should handle model creation errors', async () => {
const modelError = new Error('Failed to create image model')
vi.mocked(createImageModel).mockRejectedValue(modelError)
const modelError = new Error('Failed to get image model')
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
throw modelError
})
await expect(
executor.generateImage('invalid-model', {
@ -431,7 +438,7 @@ describe('RuntimeExecutor.generateImage', () => {
prompt: 'A landscape'
})
expect(createImageModel).toHaveBeenCalledWith('google', 'imagen-3.0-generate-002', { apiKey: 'google-key' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google:imagen-3.0-generate-002')
})
it('should support xAI Grok image models', async () => {
@ -443,7 +450,7 @@ describe('RuntimeExecutor.generateImage', () => {
prompt: 'A futuristic robot'
})
expect(createImageModel).toHaveBeenCalledWith('xai', 'grok-2-image', { apiKey: 'xai-key' })
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai:grok-2-image')
})
})

View File

@ -24,57 +24,11 @@ export { createContext, definePlugin, PluginManager } from './core/plugins'
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
export { PluginEngine } from './core/runtime/pluginEngine'
// ==================== 类型定义 ====================
export type { GenerateObjectParams, GenerateTextParams, StreamObjectParams, StreamTextParams } from './types'
// ==================== AI SDK 常用类型导出 ====================
// 直接导出 AI SDK 的常用类型,方便使用
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
export type { ToolCall } from '@ai-sdk/provider-utils'
export type { ReasoningPart } from '@ai-sdk/provider-utils'
export type {
AssistantModelMessage,
FilePart,
// 通用类型
FinishReason,
GenerateObjectResult,
// 生成相关类型
GenerateTextResult,
ImagePart,
InferToolInput,
InferToolOutput,
InvalidToolInputError,
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
// 消息相关类型
ModelMessage,
// 错误类型
NoSuchToolError,
ProviderMetadata,
StreamTextResult,
SystemModelMessage,
TextPart,
// 流相关类型
TextStreamPart,
// 工具相关类型
Tool,
ToolCallPart,
ToolModelMessage,
ToolResultPart,
ToolSet,
TypedToolCall,
TypedToolError,
TypedToolResult,
UserModelMessage
} from 'ai'
export {
defaultSettingsMiddleware,
extractReasoningMiddleware,
simulateStreamingMiddleware,
smoothStream,
stepCountIs
} from 'ai'
// 重新导出 Agent
export { Experimental_Agent as Agent } from 'ai'
// ==================== 选项 ====================
export {

View File

@ -1,9 +1,2 @@
import { generateObject, generateText, streamObject, streamText } from 'ai'
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
// 重新导出插件类型
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'

View File

@ -3,10 +3,10 @@
* AI SDK fullStream Cherry Studio chunk
*/
import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import type { TextStreamPart, ToolSet } from 'ai'
import { ToolCallChunkHandler } from './handleToolCallChunk'

View File

@ -4,10 +4,10 @@
* API使
*/
import { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
// import type {
// AnthropicSearchOutput,
// WebSearchPluginConfig

View File

@ -8,7 +8,7 @@
* 3.
*/
import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core'
import { createExecutor, generateImage } from '@cherrystudio/ai-core'
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
@ -23,6 +23,7 @@ import { CompletionsResult } from './legacy/middleware/schemas'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
import { buildPlugins } from './plugins/PluginBuilder'
import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor'
import type { StreamTextParams } from './types'
const logger = loggerService.withContext('ModernAiProvider')

View File

@ -1,11 +1,7 @@
import {
extractReasoningMiddleware,
LanguageModelV2Middleware,
simulateStreamingMiddleware
} from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import type { MCPTool, Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai'
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
@ -34,7 +30,7 @@ export interface AiSdkMiddlewareConfig {
*/
export interface NamedAiSdkMiddleware {
name: string
middleware: LanguageModelV2Middleware
middleware: LanguageModelMiddleware
}
/**
@ -83,7 +79,7 @@ export class AiSdkMiddlewareBuilder {
/**
*
*/
public build(): LanguageModelV2Middleware[] {
public build(): LanguageModelMiddleware[] {
return this.middlewares.map((m) => m.middleware)
}
@ -114,7 +110,7 @@ export class AiSdkMiddlewareBuilder {
* AI SDK中间件的工厂函数
*
*/
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV2Middleware[] {
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] {
const builder = new AiSdkMiddlewareBuilder()
// 1. 根据provider添加特定中间件

View File

@ -1,4 +1,5 @@
import { definePlugin, TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
import { definePlugin } from '@cherrystudio/ai-core'
import type { TextStreamPart, ToolSet } from 'ai'
export default definePlugin({
name: 'reasoningTimePlugin',

View File

@ -6,8 +6,7 @@
* 2. transformParams: 根据意图分析结果动态添加对应的工具
* 3. onRequestEnd: 自动记忆存储
*/
import type { AiRequestContext, ModelMessage } from '@cherrystudio/ai-core'
import { definePlugin } from '@cherrystudio/ai-core'
import { type AiRequestContext, definePlugin } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
// import { generateObject } from '@cherrystudio/ai-core'
import {
@ -20,6 +19,7 @@ import store from '@renderer/store'
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
import type { Assistant } from '@renderer/types'
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
import type { ModelMessage } from 'ai'
import { isEmpty } from 'lodash'
import { MemoryProcessor } from '../../services/MemoryProcessor'

View File

@ -3,7 +3,6 @@
* apiClient
*/
import { stepCountIs, type StreamTextParams } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
@ -31,8 +30,10 @@ import {
} from '@renderer/utils/messageUtils/find'
import { defaultTimeout } from '@shared/config/constant'
import type { AssistantModelMessage, FilePart, ImagePart, ModelMessage, TextPart, UserModelMessage } from 'ai'
import { stepCountIs } from 'ai'
import { getAiSdkProviderId } from './provider/factory'
import type { StreamTextParams } from './types'
// import { webSearchTool } from './tools/WebSearchTool'
// import { jsonSchemaToZod } from 'json-schema-to-zod'
import { setupToolsConfig } from './utils/mcp'

View File

@ -0,0 +1,6 @@
import { generateObject, generateText, streamObject, streamText } from 'ai'
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>

View File

@ -1,12 +1,12 @@
/**
* API调用函数
*/
import { StreamTextParams } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import AiProvider from '@renderer/aiCore'
import { CompletionsParams } from '@renderer/aiCore/legacy/middleware/schemas'
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/AiSdkMiddlewareBuilder'
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
import type { StreamTextParams } from '@renderer/aiCore/types'
import { isDedicatedImageGenerationModel, isEmbeddingModel, isQwenMTModel } from '@renderer/config/models'
import { LANG_DETECT_PROMPT } from '@renderer/config/prompts'
import { getStoreSetting } from '@renderer/hooks/useSettings'

View File

@ -1,5 +1,5 @@
import { StreamTextParams } from '@cherrystudio/ai-core'
import { convertMessagesToSdkMessages } from '@renderer/aiCore/transformParameters'
import type { StreamTextParams } from '@renderer/aiCore/types'
import { Assistant, Message } from '@renderer/types'
import { filterAdjacentUserMessaegs, filterLastAssistantMessage } from '@renderer/utils/messageUtils/filters'
import { isEmpty, takeRight } from 'lodash'