From c5cb443de0956f762572fe9f4a888d3078dff28c Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Fri, 20 Jun 2025 16:19:21 +0800 Subject: [PATCH] feat: enhance AI SDK documentation and client functionality - Added detailed usage examples for the native provider registry in the README.md, demonstrating how to create and utilize custom provider registries. - Updated ApiClientFactory to enforce type safety for model instances. - Refactored PluginEnabledAiClient methods to support both built-in logic and custom registry usage for text and object generation, improving flexibility and usability. --- packages/aiCore/README.md | 113 ++++++++++++++++++ .../aiCore/src/clients/ApiClientFactory.ts | 2 +- .../src/clients/PluginEnabledAiClient.ts | 92 ++++++++++---- 3 files changed, 181 insertions(+), 26 deletions(-) diff --git a/packages/aiCore/README.md b/packages/aiCore/README.md index 7e10c998f5..e400bd85d0 100644 --- a/packages/aiCore/README.md +++ b/packages/aiCore/README.md @@ -104,6 +104,119 @@ const googleClient = await createAiSdkClient('google', { apiKey: 'google-key' }) const xaiClient = await createAiSdkClient('xai', { apiKey: 'xai-key' }) ``` +### 使用 AI SDK 原生 Provider 注册表 + +> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry + +除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。 + +#### 基本用法示例 + +```typescript +import { createClient } from '@cherrystudio/ai-core' +import { createProviderRegistry } from 'ai' +import { createOpenAI } from '@ai-sdk/openai' +import { anthropic } from '@ai-sdk/anthropic' + +// 1. 创建 AI SDK 原生注册表 +export const registry = createProviderRegistry({ + // register provider with prefix and default setup: + anthropic, + + // register provider with prefix and custom setup: + openai: createOpenAI({ + apiKey: process.env.OPENAI_API_KEY + }) +}) + +// 2. 创建client,'openai'可以传空或者传providerId(内建的provider) +const client = PluginEnabledAiClient.create('openai', { + apiKey: process.env.OPENAI_API_KEY +}) + +// 3. 方式1:使用内建逻辑(传统方式) +const result1 = await client.streamText('gpt-4', { + messages: [{ role: 'user', content: 'Hello with built-in logic!' }] +}) + +// 4. 方式2:使用自定义注册表(灵活方式) +const result2 = await client.streamText({ + model: registry.languageModel('openai:gpt-4'), + messages: [{ role: 'user', content: 'Hello with custom registry!' }] +}) + +// 5. 支持的重载方法 +await client.generateObject({ + model: registry.languageModel('openai:gpt-4'), + schema: z.object({ name: z.string() }), + messages: [{ role: 'user', content: 'Generate a user' }] +}) + +await client.streamObject({ + model: registry.languageModel('anthropic:claude-3-opus-20240229'), + schema: z.object({ items: z.array(z.string()) }), + messages: [{ role: 'user', content: 'Generate a list' }] +}) +``` + +#### 与插件系统配合使用 + +更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用: + +```typescript +import { PluginEnabledAiClient } from '@cherrystudio/ai-core' +import { createProviderRegistry } from 'ai' +import { createOpenAI } from '@ai-sdk/openai' +import { anthropic } from '@ai-sdk/anthropic' + +// 1. 创建带插件的客户端 +const client = PluginEnabledAiClient.create( + 'openai', + { + apiKey: process.env.OPENAI_API_KEY + }, + [LoggingPlugin, RetryPlugin] +) + +// 2. 创建自定义注册表 +const registry = createProviderRegistry({ + openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }), + anthropic: anthropic({ apiKey: process.env.ANTHROPIC_API_KEY }) +}) + +// 3. 方式1:使用内建逻辑 + 完整插件系统 +await client.streamText('gpt-4', { + messages: [{ role: 'user', content: 'Hello with plugins!' }] +}) + +// 4. 方式2:使用自定义注册表 + 有限插件支持 +await client.streamText({ + model: registry.languageModel('anthropic:claude-3-opus-20240229'), + messages: [{ role: 'user', content: 'Hello from Claude!' }] +}) + +// 5. 支持的方法 +await client.generateObject({ + model: registry.languageModel('openai:gpt-4'), + schema: z.object({ name: z.string() }), + messages: [{ role: 'user', content: 'Generate a user' }] +}) + +await client.streamObject({ + model: registry.languageModel('openai:gpt-4'), + schema: z.object({ items: z.array(z.string()) }), + messages: [{ role: 'user', content: 'Generate a list' }] +}) +``` + +#### 混合使用的优势 + +- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表 +- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API +- **渐进式**:可以逐步迁移现有代码,无需一次性重构 +- **插件支持**:自定义注册表仍可享受 Cherry Studio 插件系统的部分功能 +- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性 + ## License MIT diff --git a/packages/aiCore/src/clients/ApiClientFactory.ts b/packages/aiCore/src/clients/ApiClientFactory.ts index 6c5e67edf6..edcb80046e 100644 --- a/packages/aiCore/src/clients/ApiClientFactory.ts +++ b/packages/aiCore/src/clients/ApiClientFactory.ts @@ -82,7 +82,7 @@ export class ApiClientFactory { // 返回模型实例 if (typeof provider === 'function') { - let model = provider(modelId) + let model: LanguageModelV1 = provider(modelId) // 应用 AI SDK 中间件 if (middlewares && middlewares.length > 0) { diff --git a/packages/aiCore/src/clients/PluginEnabledAiClient.ts b/packages/aiCore/src/clients/PluginEnabledAiClient.ts index 1f719f38f4..38895f7cd2 100644 --- a/packages/aiCore/src/clients/PluginEnabledAiClient.ts +++ b/packages/aiCore/src/clients/PluginEnabledAiClient.ts @@ -194,29 +194,40 @@ export class PluginEnabledAiClient { } /** - * 流式文本生成 - 集成插件系统 + * 流式文本生成 */ async streamText( modelId: string, params: Omit[0], 'model'> + ): Promise> + async streamText(params: Parameters[0]): Promise> + async streamText( + modelIdOrParams: string | Parameters[0], + params?: Omit[0], 'model'> ): Promise> { - return this.executeStreamWithPlugins( - 'streamText', - modelId, - params, - async (finalModelId, transformedParams, streamTransforms) => { - const model = await this.getModelWithMiddlewares(finalModelId) - return await streamText({ - model, - ...transformedParams, - experimental_transform: streamTransforms.length > 0 ? streamTransforms : undefined - }) - } - ) + if (typeof modelIdOrParams === 'string') { + // 传统方式:使用内建逻辑 + return this.executeStreamWithPlugins( + 'streamText', + modelIdOrParams, + params!, + async (finalModelId, transformedParams, streamTransforms) => { + const model = await this.getModelWithMiddlewares(finalModelId) + return await streamText({ + model, + ...transformedParams, + experimental_transform: streamTransforms.length > 0 ? streamTransforms : undefined + }) + } + ) + } else { + // 外部 registry 方式:直接使用用户提供的 model + return await streamText(modelIdOrParams) + } } /** - * 生成文本 - 集成插件系统 + * 生成文本 * 可能不需要了,因为内置模拟非流中间件 */ async generateText( @@ -230,29 +241,60 @@ export class PluginEnabledAiClient { } /** - * 生成结构化对象 - 集成插件系统 + * 生成结构化对象 */ async generateObject( modelId: string, params: Omit[0], 'model'> + ): Promise> + async generateObject(params: Parameters[0]): Promise> + async generateObject( + modelIdOrParams: string | Parameters[0], + params?: Omit[0], 'model'> ): Promise> { - return this.executeWithPlugins('generateObject', modelId, params, async (finalModelId, transformedParams) => { - const model = await this.getModelWithMiddlewares(finalModelId) - return await generateObject({ model, ...transformedParams }) - }) + if (typeof modelIdOrParams === 'string') { + // 传统方式:使用内建逻辑 + return this.executeWithPlugins( + 'generateObject', + modelIdOrParams, + params!, + async (finalModelId, transformedParams) => { + const model = await this.getModelWithMiddlewares(finalModelId) + return await generateObject({ model, ...transformedParams }) + } + ) + } else { + // 外部 registry 方式:直接使用用户提供的 model + return await generateObject(modelIdOrParams) + } } /** - * 流式生成结构化对象 - 集成插件系统 - * 注意:streamObject 目前不支持流转换器,所以使用普通的插件处理 + * 流式生成结构化对象 */ async streamObject( modelId: string, params: Omit[0], 'model'> + ): Promise> + async streamObject(params: Parameters[0]): Promise> + async streamObject( + modelIdOrParams: string | Parameters[0], + params?: Omit[0], 'model'> ): Promise> { - return this.executeWithPlugins('streamObject', modelId, params, async (finalModelId, transformedParams) => { - return await this.baseClient.streamObject(finalModelId, transformedParams) - }) + if (typeof modelIdOrParams === 'string') { + // 传统方式:使用内建逻辑 + return this.executeWithPlugins( + 'streamObject', + modelIdOrParams, + params!, + async (finalModelId, transformedParams) => { + return await this.baseClient.streamObject(finalModelId, transformedParams) + } + ) + } else { + // 外部 registry 方式:直接使用用户提供的 model + return await streamObject(modelIdOrParams) + } } /**