feat: enhance AI SDK documentation and client functionality

- Added detailed usage examples for the native provider registry in the README.md, demonstrating how to create and utilize custom provider registries.
- Updated ApiClientFactory to enforce type safety for model instances.
- Refactored PluginEnabledAiClient methods to support both built-in logic and custom registry usage for text and object generation, improving flexibility and usability.
This commit is contained in:
MyPrototypeWhat 2025-06-20 16:19:21 +08:00
parent 9318d9ffeb
commit c5cb443de0
3 changed files with 181 additions and 26 deletions

View File

@ -104,6 +104,119 @@ const googleClient = await createAiSdkClient('google', { apiKey: 'google-key' })
const xaiClient = await createAiSdkClient('xai', { apiKey: 'xai-key' })
```
### 使用 AI SDK 原生 Provider 注册表
> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry
除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。
#### 基本用法示例
```typescript
import { createClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建 AI SDK 原生注册表
export const registry = createProviderRegistry({
// register provider with prefix and default setup:
anthropic,
// register provider with prefix and custom setup:
openai: createOpenAI({
apiKey: process.env.OPENAI_API_KEY
})
})
// 2. 创建client,'openai'可以传空或者传providerId(内建的provider)
const client = PluginEnabledAiClient.create('openai', {
apiKey: process.env.OPENAI_API_KEY
})
// 3. 方式1使用内建逻辑传统方式
const result1 = await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with built-in logic!' }]
})
// 4. 方式2使用自定义注册表灵活方式
const result2 = await client.streamText({
model: registry.languageModel('openai:gpt-4'),
messages: [{ role: 'user', content: 'Hello with custom registry!' }]
})
// 5. 支持的重载方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 与插件系统配合使用
更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用:
```typescript
import { PluginEnabledAiClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建带插件的客户端
const client = PluginEnabledAiClient.create(
'openai',
{
apiKey: process.env.OPENAI_API_KEY
},
[LoggingPlugin, RetryPlugin]
)
// 2. 创建自定义注册表
const registry = createProviderRegistry({
openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }),
anthropic: anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })
})
// 3. 方式1使用内建逻辑 + 完整插件系统
await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with plugins!' }]
})
// 4. 方式2使用自定义注册表 + 有限插件支持
await client.streamText({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
messages: [{ role: 'user', content: 'Hello from Claude!' }]
})
// 5. 支持的方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 混合使用的优势
- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表
- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API
- **渐进式**:可以逐步迁移现有代码,无需一次性重构
- **插件支持**:自定义注册表仍可享受 Cherry Studio 插件系统的部分功能
- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性
## License
MIT

View File

@ -82,7 +82,7 @@ export class ApiClientFactory {
// 返回模型实例
if (typeof provider === 'function') {
let model = provider(modelId)
let model: LanguageModelV1 = provider(modelId)
// 应用 AI SDK 中间件
if (middlewares && middlewares.length > 0) {

View File

@ -194,29 +194,40 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
}
/**
* -
*
*/
async streamText(
modelId: string,
params: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>>
async streamText(params: Parameters<typeof streamText>[0]): Promise<ReturnType<typeof streamText>>
async streamText(
modelIdOrParams: string | Parameters<typeof streamText>[0],
params?: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>> {
return this.executeStreamWithPlugins(
'streamText',
modelId,
params,
async (finalModelId, transformedParams, streamTransforms) => {
const model = await this.getModelWithMiddlewares(finalModelId)
return await streamText({
model,
...transformedParams,
experimental_transform: streamTransforms.length > 0 ? streamTransforms : undefined
})
}
)
if (typeof modelIdOrParams === 'string') {
// 传统方式:使用内建逻辑
return this.executeStreamWithPlugins(
'streamText',
modelIdOrParams,
params!,
async (finalModelId, transformedParams, streamTransforms) => {
const model = await this.getModelWithMiddlewares(finalModelId)
return await streamText({
model,
...transformedParams,
experimental_transform: streamTransforms.length > 0 ? streamTransforms : undefined
})
}
)
} else {
// 外部 registry 方式:直接使用用户提供的 model
return await streamText(modelIdOrParams)
}
}
/**
* -
*
*
*/
async generateText(
@ -230,29 +241,60 @@ export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
}
/**
* -
*
*/
async generateObject(
modelId: string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>>
async generateObject(params: Parameters<typeof generateObject>[0]): Promise<ReturnType<typeof generateObject>>
async generateObject(
modelIdOrParams: string | Parameters<typeof generateObject>[0],
params?: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>> {
return this.executeWithPlugins('generateObject', modelId, params, async (finalModelId, transformedParams) => {
const model = await this.getModelWithMiddlewares(finalModelId)
return await generateObject({ model, ...transformedParams })
})
if (typeof modelIdOrParams === 'string') {
// 传统方式:使用内建逻辑
return this.executeWithPlugins(
'generateObject',
modelIdOrParams,
params!,
async (finalModelId, transformedParams) => {
const model = await this.getModelWithMiddlewares(finalModelId)
return await generateObject({ model, ...transformedParams })
}
)
} else {
// 外部 registry 方式:直接使用用户提供的 model
return await generateObject(modelIdOrParams)
}
}
/**
* -
* streamObject 使
*
*/
async streamObject(
modelId: string,
params: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>>
async streamObject(params: Parameters<typeof streamObject>[0]): Promise<ReturnType<typeof streamObject>>
async streamObject(
modelIdOrParams: string | Parameters<typeof streamObject>[0],
params?: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>> {
return this.executeWithPlugins('streamObject', modelId, params, async (finalModelId, transformedParams) => {
return await this.baseClient.streamObject(finalModelId, transformedParams)
})
if (typeof modelIdOrParams === 'string') {
// 传统方式:使用内建逻辑
return this.executeWithPlugins(
'streamObject',
modelIdOrParams,
params!,
async (finalModelId, transformedParams) => {
return await this.baseClient.streamObject(finalModelId, transformedParams)
}
)
} else {
// 外部 registry 方式:直接使用用户提供的 model
return await streamObject(modelIdOrParams)
}
}
/**