From e4c0ea035f159413faef85d4257fd581e0033948 Mon Sep 17 00:00:00 2001
From: MyPrototypeWhat
Date: Wed, 25 Jun 2025 17:25:45 +0800
Subject: [PATCH] feat: enhance AI Core runtime with advanced model handling and middleware support

- Introduced new high-level APIs for model creation and configuration, improving usability for advanced users.
- Enhanced the RuntimeExecutor to support both direct model usage and model ID resolution, allowing for more flexible execution options.
- Updated existing methods to accept middleware configurations, streamlining the integration of custom processing logic.
- Refactored the plugin system to better accommodate middleware, enhancing the overall extensibility of the AI Core.
- Improved documentation to reflect the new capabilities and usage patterns for the runtime APIs.
---
 packages/aiCore/src/core/runtime/executor.ts | 185 ++++++++++++------
 packages/aiCore/src/core/runtime/index.ts    |  30 +--
 .../aiCore/src/core/runtime/pluginEngine.ts  |  11 --
 packages/aiCore/src/index.ts                 |   3 +
 src/renderer/src/aiCore/index_new.ts         |  47 ++---
 .../aisdk/AiSdkMiddlewareBuilder.ts          |  36 +++-
 6 files changed, 203 insertions(+), 109 deletions(-)

diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts
index d6f1bd03e4..1df259a9ba 100644
--- a/packages/aiCore/src/core/runtime/executor.ts
+++ b/packages/aiCore/src/core/runtime/executor.ts
@@ -2,7 +2,7 @@
  * Runtime executor
  * Focused on plugin-based handling of AI calls
  */
-import { generateObject, generateText, LanguageModelV1, streamObject, streamText } from 'ai'
+import { generateObject, generateText, LanguageModel, LanguageModelV1Middleware, streamObject, streamText } from 'ai'
 
 import { type ProviderId, type ProviderSettingsMap } from '../../types'
 import { createModel, getProviderInfo } from '../models'
@@ -28,24 +28,43 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
     this.pluginClient = new PluginEngine(config.providerId, config.plugins || [])
   }
 
+  // === High-level overloads: use a created model directly ===
+
   /**
-   * Streaming text generation - creates the model from a modelId automatically
+   * Streaming text generation - with an already created model (advanced usage)
+   */
+  async streamText(
+    model: LanguageModel,
+    params: Omit<Parameters<typeof streamText>[0], 'model'>
+  ): Promise<ReturnType<typeof streamText>>
+
+  /**
+   * Streaming text generation - modelId plus optional middlewares (flexible usage)
    */
   async streamText(
     modelId: string,
-    params: Omit<Parameters<typeof streamText>[0], 'model'>
+    params: Omit<Parameters<typeof streamText>[0], 'model'>,
+    options?: {
+      middlewares?: LanguageModelV1Middleware[]
+    }
+  ): Promise<ReturnType<typeof streamText>>
+
+  /**
+   * Streaming text generation - internal implementation (handles both overloads)
+   */
+  async streamText(
+    modelOrId: LanguageModel | string,
+    params: Omit<Parameters<typeof streamText>[0], 'model'>,
+    options?: {
+      middlewares?: LanguageModelV1Middleware[]
+    }
   ): Promise<ReturnType<typeof streamText>> {
-    // 1. Create the model with createModel
-    const model = await createModel({
-      providerId: this.config.providerId,
-      modelId,
-      options: this.config.options
-    })
+    const model = await this.resolveModel(modelOrId, options?.middlewares)
 
     // 2. Run the plugin pipeline
     return this.pluginClient.executeStreamWithPlugins(
       'streamText',
-      modelId,
+      typeof modelOrId === 'string' ? modelOrId : model.modelId,
       params,
       async (finalModelId, transformedParams, streamTransforms) => {
         const experimental_transform =
@@ -60,46 +79,42 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
     )
   }
 
+  // === Overloads for the remaining methods ===
+
   /**
-   * Streaming text generation - with an already created model
+   * Text generation - with an already created model
    */
-  async streamTextWithModel(
-    model: LanguageModelV1,
-    params: Omit<Parameters<typeof streamText>[0], 'model'>
-  ): Promise<ReturnType<typeof streamText>> {
-    return this.pluginClient.executeStreamWithPlugins(
-      'streamText',
-      model.modelId,
-      params,
-      async (finalModelId, transformedParams, streamTransforms) => {
-        const experimental_transform =
-          params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
-
-        return await streamText({
-          model,
-          ...transformedParams,
-          experimental_transform
-        })
-      }
-    )
-  }
+  async generateText(
+    model: LanguageModel,
+    params: Omit<Parameters<typeof generateText>[0], 'model'>
+  ): Promise<ReturnType<typeof generateText>>
 
   /**
-   * Text generation
+   * Text generation - modelId plus optional middlewares
    */
   async generateText(
     modelId: string,
-    params: Omit<Parameters<typeof generateText>[0], 'model'>
+    params: Omit<Parameters<typeof generateText>[0], 'model'>,
+    options?: {
+      middlewares?: LanguageModelV1Middleware[]
+    }
+  ): Promise<ReturnType<typeof generateText>>
+
+  /**
+   * Text generation - internal implementation
+   */
+  async generateText(
+    modelOrId: LanguageModel | string,
+    params: Omit<Parameters<typeof generateText>[0], 'model'>,
+    options?: {
+      middlewares?: LanguageModelV1Middleware[]
+    }
   ): Promise<ReturnType<typeof generateText>> {
-    const model = await createModel({
-      providerId: this.config.providerId,
-      modelId,
-      options: this.config.options
-    })
+    const model = await this.resolveModel(modelOrId, options?.middlewares)
 
     return this.pluginClient.executeWithPlugins(
       'generateText',
-      modelId,
+      typeof modelOrId === 'string' ? modelOrId : model.modelId,
       params,
       async (finalModelId, transformedParams) => {
         return await generateText({ model, ...transformedParams })
@@ -108,21 +123,39 @@
   }
 
   /**
-   * Structured object generation
+   * Structured object generation - with an already created model
    */
   async generateObject(
-    modelId: string,
+    model: LanguageModel,
     params: Omit<Parameters<typeof generateObject>[0], 'model'>
+  ): Promise<ReturnType<typeof generateObject>>
+
+  /**
+   * Structured object generation - modelId plus optional middlewares
+   */
+  async generateObject(
+    modelOrId: string,
+    params: Omit<Parameters<typeof generateObject>[0], 'model'>,
+    options?: {
+      middlewares?: LanguageModelV1Middleware[]
+    }
+  ): Promise<ReturnType<typeof generateObject>>
+
+  /**
+   * Structured object generation - internal implementation
+   */
+  async generateObject(
+    modelOrId: LanguageModel | string,
+    params: Omit<Parameters<typeof generateObject>[0], 'model'>,
+    options?: {
+      middlewares?: LanguageModelV1Middleware[]
+    }
   ): Promise<ReturnType<typeof generateObject>> {
-    const model = await createModel({
-      providerId: this.config.providerId,
-      modelId,
-      options: this.config.options
-    })
+    const model = await this.resolveModel(modelOrId, options?.middlewares)
 
     return this.pluginClient.executeWithPlugins(
       'generateObject',
-      modelId,
+      typeof modelOrId === 'string' ? modelOrId : model.modelId,
       params,
       async (finalModelId, transformedParams) => {
         return await generateObject({ model, ...transformedParams })
@@ -131,27 +164,69 @@
   }
 
   /**
-   * Streaming structured object generation
+   * Streaming structured object generation - with an already created model
+   */
+  async streamObject(
+    model: LanguageModel,
+    params: Omit<Parameters<typeof streamObject>[0], 'model'>
+  ): Promise<ReturnType<typeof streamObject>>
+
+  /**
+   * Streaming structured object generation - modelId plus optional middlewares
    */
   async streamObject(
     modelId: string,
-    params: Omit<Parameters<typeof streamObject>[0], 'model'>
+    params: Omit<Parameters<typeof streamObject>[0], 'model'>,
+    options?: {
+      middlewares?: LanguageModelV1Middleware[]
+    }
+  ): Promise<ReturnType<typeof streamObject>>
+
+  /**
+   * Streaming structured object generation - internal implementation
+   */
+  async streamObject(
+    modelOrId: LanguageModel | string,
+    params: Omit<Parameters<typeof streamObject>[0], 'model'>,
+    options?: {
+      middlewares?: LanguageModelV1Middleware[]
+    }
  ): Promise<ReturnType<typeof streamObject>> {
-    const model = await createModel({
-      providerId: this.config.providerId,
-      modelId,
-      options: this.config.options
-    })
+    const model = await this.resolveModel(modelOrId, options?.middlewares)
 
     return this.pluginClient.executeWithPlugins(
       'streamObject',
-      modelId,
+      typeof modelOrId === 'string' ? modelOrId : model.modelId,
       params,
       async (finalModelId, transformedParams) => {
         return await streamObject({ model, ...transformedParams })
       }
     )
   }
+
+  // === Helpers ===
+
+  /**
+   * Resolve the model: create it when given a string modelId, return it as-is when given a model
+   */
+  private async resolveModel(
+    modelOrId: LanguageModel | string,
+    middlewares?: LanguageModelV1Middleware[]
+  ): Promise<LanguageModel> {
+    if (typeof modelOrId === 'string') {
+      // A string modelId: the model still has to be created
+      return await createModel({
+        providerId: this.config.providerId,
+        modelId: modelOrId,
+        options: this.config.options,
+        middlewares
+      })
+    } else {
+      // Already a model: return it as-is
+      return modelOrId
+    }
+  }
+
   /**
    * Get client information
    */
diff --git a/packages/aiCore/src/core/runtime/index.ts b/packages/aiCore/src/core/runtime/index.ts
index 6a8c9a06ee..8cb9285c3c 100644
--- a/packages/aiCore/src/core/runtime/index.ts
+++ b/packages/aiCore/src/core/runtime/index.ts
@@ -16,6 +16,8 @@ export type {
 
 // === Convenience factory functions ===
 
+import { LanguageModelV1Middleware } from 'ai'
+
 import { type ProviderId, type ProviderSettingsMap } from '../../types'
 import { type AiPlugin } from '../plugins'
 import { RuntimeExecutor } from './executor'
@@ -44,59 +46,63 @@ export function createOpenAICompatibleExecutor(
 // === Direct-call APIs (no executor instance needed) ===
 
 /**
- * Direct streaming text generation
+ * Direct streaming text generation - supports middlewares
  */
 export async function streamText<T extends ProviderId>(
   providerId: T,
   options: ProviderSettingsMap[T],
   modelId: string,
   params: Parameters<RuntimeExecutor<T>['streamText']>[1],
-  plugins?: AiPlugin[]
+  plugins?: AiPlugin[],
+  middlewares?: LanguageModelV1Middleware[]
 ): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
   const executor = createExecutor(providerId, options, plugins)
-  return executor.streamText(modelId, params)
+  return executor.streamText(modelId, params, { middlewares })
 }
 
 /**
- * Direct text generation
+ * Direct text generation - supports middlewares
  */
 export async function generateText<T extends ProviderId>(
   providerId: T,
   options: ProviderSettingsMap[T],
   modelId: string,
   params: Parameters<RuntimeExecutor<T>['generateText']>[1],
-  plugins?: AiPlugin[]
+  plugins?: AiPlugin[],
+  middlewares?: LanguageModelV1Middleware[]
 ): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
   const executor = createExecutor(providerId, options, plugins)
-  return executor.generateText(modelId, params)
+  return executor.generateText(modelId, params, { middlewares })
 }
 
 /**
- * Direct structured object generation
+ * Direct structured object generation - supports middlewares
  */
 export async function generateObject<T extends ProviderId>(
   providerId: T,
   options: ProviderSettingsMap[T],
   modelId: string,
   params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
-  plugins?: AiPlugin[]
+  plugins?: AiPlugin[],
+  middlewares?: LanguageModelV1Middleware[]
 ): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
   const executor = createExecutor(providerId, options, plugins)
-  return executor.generateObject(modelId, params)
+  return executor.generateObject(modelId, params, { middlewares })
 }
 
 /**
- * Direct streaming structured object generation
+ * Direct streaming structured object generation - supports middlewares
  */
 export async function streamObject<T extends ProviderId>(
   providerId: T,
   options: ProviderSettingsMap[T],
   modelId: string,
   params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
-  plugins?: AiPlugin[]
+  plugins?: AiPlugin[],
+  middlewares?: LanguageModelV1Middleware[]
 ): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
   const executor = createExecutor(providerId, options, plugins)
-  return executor.streamObject(modelId, params)
+  return executor.streamObject(modelId, params, { middlewares })
 }
 
 // === Reserved for Agent functionality ===
diff --git a/packages/aiCore/src/core/runtime/pluginEngine.ts b/packages/aiCore/src/core/runtime/pluginEngine.ts
index f944522b72..f63bb32eac 100644
--- a/packages/aiCore/src/core/runtime/pluginEngine.ts
+++ b/packages/aiCore/src/core/runtime/pluginEngine.ts
@@ -55,17 +55,6 @@ export class PluginEngine {
     return this.pluginManager.getPlugins()
   }
 
-  // /**
-  //  * Create a model via the core module (middlewares included)
-  //  */
-  // async createModelWithMiddlewares(modelId: string): Promise<LanguageModelV1> {
-  //   // Resolve the configuration with the core module's resolveConfig
-  //   const config = resolveConfig(this.providerId, modelId, this.options, this.pluginManager.getPlugins())
-
-  //   // Create the wrapped model via the core module
-  //   return createModelFromConfig(config)
-  // }
-
   /**
    * Execute an operation with plugins (non-streaming)
    * Provided for use by AiExecutor
diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts
index 2c37c8b2e1..b6b6d3e013 100644
--- a/packages/aiCore/src/index.ts
+++ b/packages/aiCore/src/index.ts
@@ -15,6 +15,9 @@ import { ProviderId, type ProviderSettingsMap } from './types'
 // ==================== Primary user-facing API ====================
 export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
 
+// ==================== Advanced API ====================
+export { createModel, type ModelConfig } from './core/models'
+
 // ==================== Plugin system ====================
 export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './core/plugins'
 export { createContext, definePlugin, PluginManager } from './core/plugins'
diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts
index 70d063fa1e..7b2b872b6f 100644
--- a/src/renderer/src/aiCore/index_new.ts
+++ b/src/renderer/src/aiCore/index_new.ts
@@ -9,8 +9,7 @@
  */
 
 import {
-  AiClient,
-  createClient,
+  createExecutor,
   ProviderConfigFactory,
   type ProviderId,
   type ProviderSettingsMap,
@@ -101,7 +100,7 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
 }
 
 export default class ModernAiProvider {
-  private modernClient?: AiClient
+  private modernExecutor?: ReturnType<typeof createExecutor>
   private legacyProvider: LegacyAiProvider
   private provider: Provider
 
@@ -109,9 +108,10 @@ export default class ModernAiProvider {
     this.provider = provider
     this.legacyProvider = new LegacyAiProvider(provider)
 
+    // TODO: in case the provider needs to be switched later when calling completions,
     // don't build middlewares at initialization; build them only when needed
     const config = providerToAiSdkConfig(provider)
-    this.modernClient = createClient(config.providerId, config.options)
+    this.modernExecutor = createExecutor(config.providerId, config.options)
   }
 
   public async completions(
@@ -145,7 +145,7 @@ export default class ModernAiProvider {
     params: StreamTextParams,
     middlewareConfig: AiSdkMiddlewareConfig
  ): Promise<CompletionsResult> {
-    if (!this.modernClient) {
+    if (!this.modernExecutor) {
       throw new Error('Modern AI SDK client not initialized')
     }
 
@@ -160,26 +160,25 @@ export default class ModernAiProvider {
     // Build the middleware array dynamically
     const middlewares = buildAiSdkMiddlewares(finalConfig)
-    console.log(
-      '构建的中间件:',
-      middlewares.map((m) => m.name)
-    )
-
-    // Create a client with the middlewares applied
-    const config = providerToAiSdkConfig(this.provider)
-    const clientWithMiddlewares = createClient(config.providerId, config.options, middlewares)
+    console.log('构建的中间件:', middlewares.length)
 
+    // Create the executor call with the middlewares applied
     if (middlewareConfig.onChunk) {
       // Streaming - use the adapter
       const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
-      const streamResult = await clientWithMiddlewares.streamText(modelId, {
-        ...params,
-        experimental_transform: smoothStream({
-          delayInMs: 80,
-          // Chinese: one chunk per 3 characters; English: one chunk per word
-          chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
-        })
-      })
+      const streamResult = await this.modernExecutor.streamText(
+        modelId,
+        {
+          ...params,
+          experimental_transform: smoothStream({
+            delayInMs: 80,
+            // Chinese: one chunk per 3 characters; English: one chunk per word
+            chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
+          })
+        },
+        middlewares.length > 0 ? { middlewares } : undefined
+      )
+
       const finalText = await adapter.processStream(streamResult)
 
       return {
@@ -187,7 +186,11 @@ export default class ModernAiProvider {
       }
     } else {
       // Streaming, but without an onChunk callback
-      const streamResult = await clientWithMiddlewares.streamText(modelId, params)
+      const streamResult = await this.modernExecutor.streamText(
+        modelId,
+        params,
+        middlewares.length > 0 ? { middlewares } : undefined
+      )
       const finalText = await streamResult.text
 
       return {
diff --git a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts
index 96c3626bef..f6756ef31a 100644
--- a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts
+++ b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts
@@ -1,4 +1,8 @@
-import { AiPlugin, extractReasoningMiddleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
+import {
+  extractReasoningMiddleware,
+  LanguageModelV1Middleware,
+  simulateStreamingMiddleware
+} from '@cherrystudio/ai-core'
 import { isReasoningModel } from '@renderer/config/models'
 import type { Model, Provider } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
@@ -21,7 +25,10 @@ export interface AiSdkMiddlewareConfig {
 /**
  * A named AI SDK middleware
  */
-export type NamedAiSdkMiddleware = AiPlugin
+export interface NamedAiSdkMiddleware {
+  name: string
+  middleware: LanguageModelV1Middleware
+}
 
 /**
  * AI SDK middleware builder
@@ -69,7 +76,14 @@ export class AiSdkMiddlewareBuilder {
   /**
    * Build the final middleware array
    */
-  public build(): NamedAiSdkMiddleware[] {
+  public build(): LanguageModelV1Middleware[] {
+    return this.middlewares.map((m) => m.middleware)
+  }
+
+  /**
+   * Get the named middleware array (for debugging)
+   */
+  public buildNamed(): NamedAiSdkMiddleware[] {
     return [...this.middlewares]
   }
 
@@ -93,14 +107,14 @@ export class AiSdkMiddlewareBuilder {
 /**
  * Factory that builds the AI SDK middlewares from the config
  * Note the build order here: some middlewares depend on the output of others
  */
-export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdkMiddleware[] {
+export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV1Middleware[] {
   const builder = new AiSdkMiddlewareBuilder()
 
   // 1. Add the thinking-time middleware for reasoning models with an onChunk callback
   if (config.onChunk && config.model && isReasoningModel(config.model)) {
     builder.add({
       name: 'thinking-time',
-      aiSdkMiddlewares: [thinkingTimeMiddleware()]
+      middleware: thinkingTimeMiddleware()
     })
   }
 
@@ -121,7 +135,7 @@ export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdk
   if (config.streamOutput === false) {
     builder.add({
       name: 'simulate-streaming',
-      aiSdkMiddlewares: [simulateStreamingMiddleware()]
+      middleware: simulateStreamingMiddleware()
     })
   }
 
@@ -142,7 +156,7 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
     case 'openai':
       builder.add({
         name: 'thinking-tag-extraction',
-        aiSdkMiddlewares: [extractReasoningMiddleware({ tagName: 'think' })]
+        middleware: extractReasoningMiddleware({ tagName: 'think' })
       })
       break
     case 'gemini':
@@ -183,8 +197,12 @@ export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfi
   const builder = new AiSdkMiddlewareBuilder()
   const defaultMiddlewares = buildAiSdkMiddlewares(config)
 
-  defaultMiddlewares.forEach((middleware) => {
-    builder.add(middleware)
+  // Convert the plain middleware array to named middlewares before adding
+  defaultMiddlewares.forEach((middleware, index) => {
+    builder.add({
+      name: `default-middleware-${index}`,
+      middleware
+    })
  })
 
  return builder
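
Usage note (reviewer addition, not part of the patch): the sketch below illustrates how the new executor API introduced in this change might be called from application code. It is a minimal sketch only; the provider id 'openai', the options object with an apiKey field, and the model id are placeholders assumed for illustration, while createExecutor, extractReasoningMiddleware, the streamText(modelId, params, { middlewares }) overload, and the awaited result.text all come from the code in this diff.

import { createExecutor, extractReasoningMiddleware } from '@cherrystudio/ai-core'

async function demo(): Promise<void> {
  // Provider id and settings shape are assumptions for this sketch.
  const executor = createExecutor('openai', { apiKey: 'sk-...' })

  // Optional per-call middlewares; they are applied when the modelId is resolved into a model.
  const middlewares = [extractReasoningMiddleware({ tagName: 'think' })]

  const result = await executor.streamText(
    'example-model-id', // placeholder model id
    { prompt: 'Hello from the new runtime API' },
    middlewares.length > 0 ? { middlewares } : undefined
  )

  // The stream result exposes the final text as a promise, as used in index_new.ts above.
  console.log(await result.text)
}

demo().catch(console.error)

Note that the middlewares option only takes effect on the modelId overload: resolveModel forwards it to createModel when it receives a string, while an already constructed LanguageModel instance is used as-is.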