feat(provider): enhance provider registration and validation system

- Introduced a new Zod-based schema for provider validation, improving type safety and consistency.
- Added support for dynamic provider IDs and enhanced the provider registration process.
- Updated the AiProviderRegistry to utilize the new validation functions, ensuring robust provider management.
- Added tests for the provider registry to validate dynamic imports and functionality.
- Updated yarn.lock to reflect the latest dependency versions.
This commit is contained in:
MyPrototypeWhat 2025-08-18 19:41:34 +08:00
parent 356443babf
commit c9c0616c91
10 changed files with 471 additions and 116 deletions

View File

@ -0,0 +1,2 @@
// 模拟 Vite SSR helper避免 Node 环境找不到时报错
;(globalThis as any).__vite_ssr_exportName__ = (name: string, value: any) => value

View File

@ -0,0 +1,131 @@
/**
* registry -
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'
// 模拟 AI SDK - 使用简单版本
vi.mock('@ai-sdk/openai', () => ({
createOpenAI: vi.fn(() => ({ name: 'openai-mock' }))
}))
vi.mock('@ai-sdk/anthropic', () => ({
createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' }))
}))
vi.mock('@ai-sdk/azure', () => ({
createAzure: vi.fn(() => ({ name: 'azure-mock' }))
}))
vi.mock('@ai-sdk/deepseek', () => ({
createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' }))
}))
vi.mock('@ai-sdk/google', () => ({
createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' }))
}))
vi.mock('@ai-sdk/openai-compatible', () => ({
createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' }))
}))
vi.mock('@ai-sdk/xai', () => ({
createXai: vi.fn(() => ({ name: 'xai-mock' }))
}))
describe('Real Registry Test', () => {
beforeEach(() => {
// 清理模块缓存,强制重新加载
vi.resetModules()
})
it('应该能够通过动态导入访问真正的 registry', async () => {
console.log('🔍 Real test - Testing dynamic import...')
try {
// 使用动态导入,每次都重新导入
const { AiProviderRegistry } = await import('../registry')
console.log('🔍 Real test - AiProviderRegistry imported:', {
type: typeof AiProviderRegistry,
isClass: AiProviderRegistry?.prototype?.constructor === AiProviderRegistry
})
if (AiProviderRegistry) {
// 创建新实例,跳过单例模式
const testRegistry = Object.create(AiProviderRegistry.prototype)
// 手动调用构造函数逻辑,但跳过有问题的初始化
testRegistry.registry = new Map()
testRegistry.dynamicMappings = new Map()
testRegistry.dynamicProviders = new Set()
// 手动添加一些测试数据
testRegistry.registry.set('test-provider', {
id: 'test-provider',
name: 'Test Provider',
creator: () => ({ name: 'test' }),
supportsImageGeneration: false
})
// 测试基本功能
const allIds = testRegistry.getAllValidProviderIds?.()
console.log('🔍 Real test - getAllValidProviderIds result:', allIds)
if (allIds) {
expect(Array.isArray(allIds)).toBe(true)
expect(allIds).toContain('test-provider')
}
}
} catch (error) {
console.error('🔍 Real test - Error:', error)
throw error
}
})
it('应该能够通过模块原型访问方法', async () => {
console.log('🔍 Real test - Testing prototype access...')
try {
const registryModule = await import('../registry')
console.log('🔍 Real test - Registry module keys:', Object.keys(registryModule))
// 检查是否有任何可用的导出
const availableExports = Object.keys(registryModule).filter((key) => registryModule[key] !== undefined)
console.log('🔍 Real test - Available exports:', availableExports)
if (availableExports.length === 0) {
console.log('🔍 Real test - No exports available, trying alternative approach...')
// 尝试直接访问模块的内部结构
const moduleEntries = Object.entries(registryModule)
console.log('🔍 Real test - Module entries:', moduleEntries)
}
} catch (error) {
console.error('🔍 Real test - Prototype access error:', error)
}
})
it('应该能够通过 require 访问模块', async () => {
console.log('🔍 Real test - Testing require access...')
try {
// 尝试使用 require 而不是 import
const path = require('path')
const moduleId = path.resolve(__dirname, '../registry.ts')
console.log('🔍 Real test - Module ID:', moduleId)
// 检查模块是否在缓存中
const cached = require.cache[moduleId]
console.log('🔍 Real test - Module cached:', !!cached)
if (cached) {
console.log('🔍 Real test - Cached exports:', Object.keys(cached.exports || {}))
}
} catch (error) {
console.error('🔍 Real test - Require access error:', error)
}
})
})

View File

@ -3,13 +3,39 @@
*/
// Provider 注册表
export { aiProviderRegistry, getAllProviders, getProvider, isProviderSupported, registerProvider } from './registry'
export {
aiProviderRegistry,
getAllProviders,
getAllValidProviderIds,
getProvider,
isProviderSupported,
registerProvider,
validateProviderIdRegistry
} from './registry'
// Provider 创建
export { createImageProvider, createProvider, ProviderCreationError, validateProviderConfig } from './creator'
// 类型定义
export type { ProviderConfig, ProviderError, ProviderId, ProviderSettingsMap } from './types'
export type {
BaseProviderId,
DynamicProviderId,
DynamicProviderRegistration,
ProviderConfig,
ProviderError,
ProviderId,
ProviderSettingsMap
} from './types'
// Zod Schemas 和验证
export {
baseProviderIds,
baseProviders,
isBaseProviderId,
isValidDynamicProviderId,
validateDynamicProviderRegistration,
validateProviderId
} from './schemas'
// 工厂和配置
export * from './factory'

View File

@ -1,17 +1,18 @@
/**
* AI Provider
* +
* - 使 schemas
* -
* - Provider
*/
import { createAnthropic } from '@ai-sdk/anthropic'
import { createAzure } from '@ai-sdk/azure'
import { createDeepSeek } from '@ai-sdk/deepseek'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import { createXai } from '@ai-sdk/xai'
import { type ProviderConfig } from './types'
import {
baseProviders,
type DynamicProviderRegistration,
type ProviderConfig,
type ProviderId,
validateDynamicProviderRegistration,
validateProviderId
} from './schemas'
export class AiProviderRegistry {
private static instance: AiProviderRegistry
@ -33,64 +34,11 @@ export class AiProviderRegistry {
/**
* Providers
* AI SDK 官方文档: https://v5.ai-sdk.dev/providers/ai-sdk-providers
* 使 schemas baseProviders
*/
private initializeProviders(): void {
const providers: ProviderConfig[] = [
{
id: 'openai',
name: 'OpenAI',
creator: createOpenAI,
supportsImageGeneration: true
},
{
id: 'openai-responses',
name: 'OpenAI Responses',
creator: (options: OpenAIProviderSettings) => {
return createOpenAI(options).responses
},
supportsImageGeneration: true
},
{
id: 'openai-compatible',
name: 'OpenAI Compatible',
creator: createOpenAICompatible,
supportsImageGeneration: true
},
{
id: 'anthropic',
name: 'Anthropic',
creator: createAnthropic,
supportsImageGeneration: false
},
{
id: 'google',
name: 'Google Generative AI',
creator: createGoogleGenerativeAI,
supportsImageGeneration: true
},
{
id: 'xai',
name: 'xAI (Grok)',
creator: createXai,
supportsImageGeneration: true
},
{
id: 'azure',
name: 'Azure OpenAI',
creator: createAzure,
supportsImageGeneration: true
},
{
id: 'deepseek',
name: 'DeepSeek',
creator: createDeepSeek,
supportsImageGeneration: false
}
]
providers.forEach((config) => {
this.registry.set(config.id, config)
baseProviders.forEach((config) => {
this.registry.set(config.id, config as ProviderConfig)
})
}
@ -112,16 +60,22 @@ export class AiProviderRegistry {
* Provider
*/
public isSupported(id: string): boolean {
return this.registry.has(id)
// 首先检查是否在注册表中
if (this.registry.has(id)) {
return true
}
// 然后检查是否是有效的 provider ID可能是新的动态 provider
return validateProviderId(id)
}
/**
* Provider
*/
public registerProvider(config: ProviderConfig): void {
// 验证:必须提供 creator 或 (import + creatorFunctionName)
if (!config.creator && !(config.import && config.creatorFunctionName)) {
throw new Error('Must provide either creator function or import configuration')
// 使用 schemas 的验证函数
if (!validateProviderId(config.id)) {
throw new Error(`Invalid provider ID: ${config.id}`)
}
// 验证:不能同时提供两种方式
@ -135,27 +89,24 @@ export class AiProviderRegistry {
/**
* Provider并支持映射关系
*/
public registerDynamicProvider(
config: ProviderConfig & {
mappings?: Record<string, string>
}
): boolean {
public registerDynamicProvider(config: DynamicProviderRegistration): boolean {
try {
// 验证配置
if (!config.id || config.id.trim() === '') {
console.error('Provider ID cannot be empty')
// 使用 schemas 的验证函数
const validatedConfig = validateDynamicProviderRegistration(config)
if (!validatedConfig) {
console.error('Invalid dynamic provider configuration')
return false
}
// 注册provider
this.registerProvider(config)
this.registerProvider(validatedConfig)
// 记录为动态provider
this.dynamicProviders.add(config.id)
this.dynamicProviders.add(validatedConfig.id)
// 添加映射关系(如果提供)
if (config.mappings) {
Object.entries(config.mappings).forEach(([key, value]) => {
if (validatedConfig.mappings) {
Object.entries(validatedConfig.mappings).forEach(([key, value]) => {
this.dynamicMappings.set(key, value)
})
}
@ -170,11 +121,7 @@ export class AiProviderRegistry {
/**
* Providers
*/
public registerMultipleProviders(
configs: (ProviderConfig & {
mappings?: Record<string, string>
})[]
): number {
public registerMultipleProviders(configs: DynamicProviderRegistration[]): number {
let successCount = 0
configs.forEach((config) => {
if (this.registerDynamicProvider(config)) {
@ -213,10 +160,28 @@ export class AiProviderRegistry {
}
/**
*
* Provider IDs
*/
public getAllValidProviderIds(): string[] {
return [...Array.from(this.registry.keys()), ...this.dynamicProviders]
}
/**
* Provider ID
*/
public validateProviderId(id: string): boolean {
return validateProviderId(id)
}
/**
* -
*/
public cleanup(): void {
this.registry.clear()
this.dynamicProviders.clear()
this.dynamicMappings.clear()
// 重新初始化基础 providers
this.initializeProviders()
}
}
@ -228,16 +193,19 @@ export const getProvider = (id: string) => aiProviderRegistry.getProvider(id)
export const getAllProviders = () => aiProviderRegistry.getAllProviders()
export const isProviderSupported = (id: string) => aiProviderRegistry.isSupported(id)
export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config)
export const validateProviderIdRegistry = (id: string) => aiProviderRegistry.validateProviderId(id)
export const getAllValidProviderIds = () => aiProviderRegistry.getAllValidProviderIds()
// 动态注册相关便捷函数
export const registerDynamicProvider = (config: ProviderConfig & { mappings?: Record<string, string> }) =>
export const registerDynamicProvider = (config: DynamicProviderRegistration) =>
aiProviderRegistry.registerDynamicProvider(config)
export const registerMultipleProviders = (configs: (ProviderConfig & { mappings?: Record<string, string> })[]) =>
export const registerMultipleProviders = (configs: DynamicProviderRegistration[]) =>
aiProviderRegistry.registerMultipleProviders(configs)
export const getProviderMapping = (providerId: string) => aiProviderRegistry.getProviderMapping(providerId)
export const isDynamicProvider = (providerId: string) => aiProviderRegistry.isDynamicProvider(providerId)
export const getAllDynamicMappings = () => aiProviderRegistry.getAllDynamicMappings()
export const getDynamicProviders = () => aiProviderRegistry.getDynamicProviders()
export const cleanup = () => aiProviderRegistry.cleanup()
// 兼容现有实现的导出
// export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()
// 导出类型
export type { DynamicProviderRegistration, ProviderConfig, ProviderId }

View File

@ -0,0 +1,197 @@
/**
* Zod Provider
* -
* - Provider
* -
*/
import { createAnthropic } from '@ai-sdk/anthropic'
import { createAzure } from '@ai-sdk/azure'
import { createDeepSeek } from '@ai-sdk/deepseek'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import { createXai } from '@ai-sdk/xai'
import * as z from 'zod'
/**
* Providers
*
*/
export const baseProviders = [
{
id: 'openai',
name: 'OpenAI',
creator: createOpenAI,
supportsImageGeneration: true
},
{
id: 'openai-responses',
name: 'OpenAI Responses',
creator: (options: OpenAIProviderSettings) => createOpenAI(options).responses,
supportsImageGeneration: true
},
{
id: 'openai-compatible',
name: 'OpenAI Compatible',
creator: createOpenAICompatible,
supportsImageGeneration: true
},
{
id: 'anthropic',
name: 'Anthropic',
creator: createAnthropic,
supportsImageGeneration: false
},
{
id: 'google',
name: 'Google Generative AI',
creator: createGoogleGenerativeAI,
supportsImageGeneration: true
},
{
id: 'xai',
name: 'xAI (Grok)',
creator: createXai,
supportsImageGeneration: true
},
{
id: 'azure',
name: 'Azure OpenAI',
creator: createAzure,
supportsImageGeneration: true
},
{
id: 'deepseek',
name: 'DeepSeek',
creator: createDeepSeek,
supportsImageGeneration: false
}
] as const
/**
* Provider IDs
* baseProviders
*/
export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as readonly [string, ...string[]]
/**
* Provider ID Schema
*/
export const baseProviderIdSchema = z.enum(baseProviderIds)
/**
* Provider ID Schema
* provider IDs
*/
export const dynamicProviderIdSchema = z
.string()
.min(1)
.refine((id) => !baseProviderIds.includes(id as any), {
message: 'Dynamic provider ID cannot conflict with base provider IDs'
})
/**
* Provider ID Schema
* providers +
*/
export const providerIdSchema = z.union([baseProviderIdSchema, dynamicProviderIdSchema])
/**
* Provider Schema
*/
export const providerConfigSchema = z
.object({
id: providerIdSchema,
name: z.string().min(1),
creator: z.function().optional(),
import: z.function().optional(),
creatorFunctionName: z.string().optional(),
supportsImageGeneration: z.boolean().default(false),
imageCreator: z.function().optional(),
validateOptions: z.function().optional()
})
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
message: 'Must provide either creator function or import configuration'
})
/**
* Provider Schema
*/
export const dynamicProviderRegistrationSchema = z
.object({
id: dynamicProviderIdSchema,
name: z.string().min(1),
creator: z.function().optional(),
import: z.function().optional(),
creatorFunctionName: z.string().optional(),
supportsImageGeneration: z.boolean().default(false),
imageCreator: z.function().optional(),
validateOptions: z.function().optional(),
mappings: z.record(z.string()).optional()
})
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
message: 'Must provide either creator function or import configuration'
})
// ===== 类型推导 =====
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
export type DynamicProviderId = z.infer<typeof dynamicProviderIdSchema>
export type ProviderId = z.infer<typeof providerIdSchema>
export type ProviderConfig = z.infer<typeof providerConfigSchema>
export type DynamicProviderRegistration = z.infer<typeof dynamicProviderRegistrationSchema>
// ===== 纯验证函数 =====
/**
* Provider ID
*/
export function validateProviderId(id: string): boolean {
return providerIdSchema.safeParse(id).success
}
/**
* Provider ID
*/
export function isBaseProviderId(id: string): id is BaseProviderId {
return baseProviderIdSchema.safeParse(id).success
}
/**
* Provider ID
*/
export function isValidDynamicProviderId(id: string): boolean {
return dynamicProviderIdSchema.safeParse(id).success
}
/**
* Provider
*/
export function validateProviderConfig(config: unknown): ProviderConfig | null {
const result = providerConfigSchema.safeParse(config)
if (result.success) {
return result.data
}
console.error('Invalid provider config:', result.error.errors)
return null
}
/**
* Provider
*/
export function validateDynamicProviderRegistration(config: unknown): DynamicProviderRegistration | null {
const result = dynamicProviderRegistrationSchema.safeParse(config)
if (result.success) {
return result.data
}
console.error('Invalid dynamic provider registration:', result.error.errors)
return null
}
/**
* Provider
*/
export function getBaseProviderConfig(id: BaseProviderId): (typeof baseProviders)[number] | undefined {
return baseProviders.find((p) => p.id === id)
}

View File

@ -6,6 +6,15 @@ import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import { type XaiProviderSettings } from '@ai-sdk/xai'
// 导入基于 Zod 的 ProviderId 类型
import {
type BaseProviderId,
type DynamicProviderId,
type DynamicProviderRegistration,
type ProviderConfig,
type ProviderId as ZodProviderId
} from './schemas'
export interface ExtensibleProviderSettingsMap {
// 基础的静态providers
openai: OpenAIProviderSettings
@ -32,24 +41,24 @@ export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProvide
*/
// Provider 配置接口 - 支持灵活的创建方式
export interface ProviderConfig {
id: string
name: string
// export interface ProviderConfig {
// id: string
// name: string
// 方式一:直接提供 creator 函数(推荐用于自定义)
creator?: (options: any) => any
// // 方式一:直接提供 creator 函数(推荐用于自定义)
// creator?: (options: any) => any
// 方式二:动态导入 + 函数名(用于包导入)
import?: () => Promise<any>
creatorFunctionName?: string
// // 方式二:动态导入 + 函数名(用于包导入)
// import?: () => Promise<any>
// creatorFunctionName?: string
// 图片生成支持
supportsImageGeneration?: boolean
imageCreator?: (options: any) => any
// // 图片生成支持
// supportsImageGeneration?: boolean
// imageCreator?: (options: any) => any
// 可选的验证函数
validateOptions?: (options: any) => boolean
}
// // 可选的验证函数
// validateOptions?: (options: any) => boolean
// }
// 错误类型
export class ProviderError extends Error {
@ -64,8 +73,11 @@ export class ProviderError extends Error {
}
}
// 动态ProviderId类型 - 支持运行时扩展
export type ProviderId = keyof ExtensibleProviderSettingsMap | (string & {})
// 动态ProviderId类型 - 基于 Zod Schema支持运行时扩展和验证
export type ProviderId = ZodProviderId
// 重新导出相关类型
export type { BaseProviderId, DynamicProviderId, DynamicProviderRegistration, ProviderConfig }
// Provider类型注册工具
export interface ProviderTypeRegistrar {

View File

@ -123,6 +123,7 @@ export {
export {
getAllDynamicMappings,
getAllProviders,
getAllValidProviderIds,
getDynamicProviders,
getProvider,
getProviderMapping,
@ -131,9 +132,21 @@ export {
// 动态注册功能
registerDynamicProvider,
registerMultipleProviders,
registerProvider
registerProvider,
// Zod 验证相关
validateProviderIdRegistry
} from './core/providers/registry'
// ==================== Zod Schema 和验证 ====================
export {
type BaseProviderId,
baseProviderIds,
type DynamicProviderId,
type DynamicProviderRegistration,
validateDynamicProviderRegistration,
validateProviderId
} from './core/providers'
// ==================== Provider 配置工厂 ====================
export {
type BaseProviderConfig,

View File

@ -2,12 +2,17 @@ import { defineConfig } from 'vitest/config'
export default defineConfig({
test: {
globals: true,
environment: 'node',
globals: true
setupFiles: ['./setupVitest.ts'],
include: ['src/**/*.{test,spec}.{ts,tsx}', 'src/**/__tests__/**/*.{test,spec}.{ts,tsx}']
},
resolve: {
alias: {
'@': './src'
}
},
esbuild: {
target: 'node18'
}
})

View File

@ -2,9 +2,10 @@ import { aiSdk, Tool } from '@cherrystudio/ai-core'
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
import { MCPTool, MCPToolResponse } from '@renderer/types'
import { callMCPTool } from '@renderer/utils/mcp-tools'
import { tool } from 'ai'
import { JSONSchema7 } from 'json-schema'
const { tool } = aiSdk
// Setup tools configuration based on provided parameters
export function setupToolsConfig(mcpTools?: MCPTool[]): Record<string, Tool> | undefined {
let tools: Record<string, Tool> = {}

View File

@ -5014,9 +5014,9 @@ __metadata:
linkType: hard
"@opentelemetry/semantic-conventions@npm:^1.29.0":
version: 1.34.0
resolution: "@opentelemetry/semantic-conventions@npm:1.34.0"
checksum: 10c0/a51a32a5cf5c803bd2125a680d0abacbff632f3b255d0fe52379dac191114a0e8d72a34f9c46c5483ccfe91c4061c309f3cf61a19d11347e2a69779e82cfefd0
version: 1.36.0
resolution: "@opentelemetry/semantic-conventions@npm:1.36.0"
checksum: 10c0/edc8a6fe3ec4fc0c67ba3a92b86fb3dcc78fe1eb4f19838d8013c3232b9868540a034dd25cfe0afdd5eae752c5f0e9f42272ff46da144a2d5b35c644478e1c62
languageName: node
linkType: hard