feat: enhance AI Core runtime with advanced model handling and middleware support

- Introduced new high-level APIs for model creation and configuration, improving usability for advanced users.
- Enhanced the RuntimeExecutor to support both direct model usage and model ID resolution, allowing for more flexible execution options.
- Updated existing methods to accept middleware configurations, streamlining the integration of custom processing logic.
- Refactored the plugin system to better accommodate middleware, enhancing the overall extensibility of the AI Core.
- Improved documentation to reflect the new capabilities and usage patterns for the runtime APIs.
This commit is contained in:
MyPrototypeWhat 2025-06-25 17:25:45 +08:00
parent 7d8ed3a737
commit e4c0ea035f
6 changed files with 203 additions and 109 deletions

View File

@ -2,7 +2,7 @@
* *
* AI调用处理 * AI调用处理
*/ */
import { generateObject, generateText, LanguageModelV1, streamObject, streamText } from 'ai' import { generateObject, generateText, LanguageModel, LanguageModelV1Middleware, streamObject, streamText } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types' import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { createModel, getProviderInfo } from '../models' import { createModel, getProviderInfo } from '../models'
@ -28,24 +28,43 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginClient = new PluginEngine(config.providerId, config.plugins || []) this.pluginClient = new PluginEngine(config.providerId, config.plugins || [])
} }
// === 高阶重载:直接使用模型 ===
/** /**
* - 使modelId自动创建模型 * - 使
*/
async streamText(
model: LanguageModel,
params: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>>
/**
* - 使modelId + middleware
*/ */
async streamText( async streamText(
modelId: string, modelId: string,
params: Omit<Parameters<typeof streamText>[0], 'model'> params: Omit<Parameters<typeof streamText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof streamText>>
/**
* -
*/
async streamText(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof streamText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof streamText>> { ): Promise<ReturnType<typeof streamText>> {
// 1. 使用 createModel 创建模型 const model = await this.resolveModel(modelOrId, options?.middlewares)
const model = await createModel({
providerId: this.config.providerId,
modelId,
options: this.config.options
})
// 2. 执行插件处理 // 2. 执行插件处理
return this.pluginClient.executeStreamWithPlugins( return this.pluginClient.executeStreamWithPlugins(
'streamText', 'streamText',
modelId, typeof modelOrId === 'string' ? modelOrId : model.modelId,
params, params,
async (finalModelId, transformedParams, streamTransforms) => { async (finalModelId, transformedParams, streamTransforms) => {
const experimental_transform = const experimental_transform =
@ -60,46 +79,42 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
) )
} }
// === 其他方法的重载 ===
/** /**
* - 使 * - 使
*/ */
async streamTextWithModel( async generateText(
model: LanguageModelV1, model: LanguageModel,
params: Omit<Parameters<typeof streamText>[0], 'model'> params: Omit<Parameters<typeof generateText>[0], 'model'>
): Promise<ReturnType<typeof streamText>> { ): Promise<ReturnType<typeof generateText>>
return this.pluginClient.executeStreamWithPlugins(
'streamText',
model.modelId,
params,
async (finalModelId, transformedParams, streamTransforms) => {
const experimental_transform =
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
return await streamText({
model,
...transformedParams,
experimental_transform
})
}
)
}
/** /**
* * - 使modelId + middleware
*/ */
async generateText( async generateText(
modelId: string, modelId: string,
params: Omit<Parameters<typeof generateText>[0], 'model'> params: Omit<Parameters<typeof generateText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof generateText>>
/**
* -
*/
async generateText(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof generateText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof generateText>> { ): Promise<ReturnType<typeof generateText>> {
const model = await createModel({ const model = await this.resolveModel(modelOrId, options?.middlewares)
providerId: this.config.providerId,
modelId,
options: this.config.options
})
return this.pluginClient.executeWithPlugins( return this.pluginClient.executeWithPlugins(
'generateText', 'generateText',
modelId, typeof modelOrId === 'string' ? modelOrId : model.modelId,
params, params,
async (finalModelId, transformedParams) => { async (finalModelId, transformedParams) => {
return await generateText({ model, ...transformedParams }) return await generateText({ model, ...transformedParams })
@ -108,21 +123,39 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
} }
/** /**
* * - 使
*/ */
async generateObject( async generateObject(
modelId: string, model: LanguageModel,
params: Omit<Parameters<typeof generateObject>[0], 'model'> params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>>
/**
* - 使modelId + middleware
*/
async generateObject(
modelOrId: string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof generateObject>>
/**
* -
*/
async generateObject(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof generateObject>> { ): Promise<ReturnType<typeof generateObject>> {
const model = await createModel({ const model = await this.resolveModel(modelOrId, options?.middlewares)
providerId: this.config.providerId,
modelId,
options: this.config.options
})
return this.pluginClient.executeWithPlugins( return this.pluginClient.executeWithPlugins(
'generateObject', 'generateObject',
modelId, typeof modelOrId === 'string' ? modelOrId : model.modelId,
params, params,
async (finalModelId, transformedParams) => { async (finalModelId, transformedParams) => {
return await generateObject({ model, ...transformedParams }) return await generateObject({ model, ...transformedParams })
@ -131,27 +164,69 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
} }
/** /**
* * - 使
*/
async streamObject(
model: LanguageModel,
params: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>>
/**
* - 使modelId + middleware
*/ */
async streamObject( async streamObject(
modelId: string, modelId: string,
params: Omit<Parameters<typeof streamObject>[0], 'model'> params: Omit<Parameters<typeof streamObject>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof streamObject>>
/**
* -
*/
async streamObject(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof streamObject>> { ): Promise<ReturnType<typeof streamObject>> {
const model = await createModel({ const model = await this.resolveModel(modelOrId, options?.middlewares)
providerId: this.config.providerId,
modelId,
options: this.config.options
})
return this.pluginClient.executeWithPlugins( return this.pluginClient.executeWithPlugins(
'streamObject', 'streamObject',
modelId, typeof modelOrId === 'string' ? modelOrId : model.modelId,
params, params,
async (finalModelId, transformedParams) => { async (finalModelId, transformedParams) => {
return await streamObject({ model, ...transformedParams }) return await streamObject({ model, ...transformedParams })
} }
) )
} }
// === 辅助方法 ===
/**
*
*/
private async resolveModel(
modelOrId: LanguageModel | string,
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModel> {
if (typeof modelOrId === 'string') {
// 字符串modelId需要创建模型
return await createModel({
providerId: this.config.providerId,
modelId: modelOrId,
options: this.config.options,
middlewares
})
} else {
// 已经是模型,直接返回
return modelOrId
}
}
/** /**
* *
*/ */

View File

@ -16,6 +16,8 @@ export type {
// === 便捷工厂函数 === // === 便捷工厂函数 ===
import { LanguageModelV1Middleware } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types' import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { type AiPlugin } from '../plugins' import { type AiPlugin } from '../plugins'
import { RuntimeExecutor } from './executor' import { RuntimeExecutor } from './executor'
@ -44,59 +46,63 @@ export function createOpenAICompatibleExecutor(
// === 直接调用API无需创建executor实例=== // === 直接调用API无需创建executor实例===
/** /**
* * - middlewares
*/ */
export async function streamText<T extends ProviderId>( export async function streamText<T extends ProviderId>(
providerId: T, providerId: T,
options: ProviderSettingsMap[T], options: ProviderSettingsMap[T],
modelId: string, modelId: string,
params: Parameters<RuntimeExecutor<T>['streamText']>[1], params: Parameters<RuntimeExecutor<T>['streamText']>[1],
plugins?: AiPlugin[] plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> { ): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
const executor = createExecutor(providerId, options, plugins) const executor = createExecutor(providerId, options, plugins)
return executor.streamText(modelId, params) return executor.streamText(modelId, params, { middlewares })
} }
/** /**
* * - middlewares
*/ */
export async function generateText<T extends ProviderId>( export async function generateText<T extends ProviderId>(
providerId: T, providerId: T,
options: ProviderSettingsMap[T], options: ProviderSettingsMap[T],
modelId: string, modelId: string,
params: Parameters<RuntimeExecutor<T>['generateText']>[1], params: Parameters<RuntimeExecutor<T>['generateText']>[1],
plugins?: AiPlugin[] plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> { ): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
const executor = createExecutor(providerId, options, plugins) const executor = createExecutor(providerId, options, plugins)
return executor.generateText(modelId, params) return executor.generateText(modelId, params, { middlewares })
} }
/** /**
* * - middlewares
*/ */
export async function generateObject<T extends ProviderId>( export async function generateObject<T extends ProviderId>(
providerId: T, providerId: T,
options: ProviderSettingsMap[T], options: ProviderSettingsMap[T],
modelId: string, modelId: string,
params: Parameters<RuntimeExecutor<T>['generateObject']>[1], params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
plugins?: AiPlugin[] plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> { ): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
const executor = createExecutor(providerId, options, plugins) const executor = createExecutor(providerId, options, plugins)
return executor.generateObject(modelId, params) return executor.generateObject(modelId, params, { middlewares })
} }
/** /**
* * - middlewares
*/ */
export async function streamObject<T extends ProviderId>( export async function streamObject<T extends ProviderId>(
providerId: T, providerId: T,
options: ProviderSettingsMap[T], options: ProviderSettingsMap[T],
modelId: string, modelId: string,
params: Parameters<RuntimeExecutor<T>['streamObject']>[1], params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
plugins?: AiPlugin[] plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> { ): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
const executor = createExecutor(providerId, options, plugins) const executor = createExecutor(providerId, options, plugins)
return executor.streamObject(modelId, params) return executor.streamObject(modelId, params, { middlewares })
} }
// === Agent 功能预留 === // === Agent 功能预留 ===

View File

@ -55,17 +55,6 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
return this.pluginManager.getPlugins() return this.pluginManager.getPlugins()
} }
// /**
// * 使用core模块创建模型包含中间件
// */
// async createModelWithMiddlewares(modelId: string): Promise<any> {
// // 使用core模块的resolveConfig解析配置
// const config = resolveConfig(this.providerId, modelId, this.options, this.pluginManager.getPlugins())
// // 使用core模块创建包装好的模型
// return createModelFromConfig(config)
// }
/** /**
* *
* AiExecutor使用 * AiExecutor使用

View File

@ -15,6 +15,9 @@ import { ProviderId, type ProviderSettingsMap } from './types'
// ==================== 主要用户接口 ==================== // ==================== 主要用户接口 ====================
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime' export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
// ==================== 高级API ====================
export { createModel, type ModelConfig } from './core/models'
// ==================== 插件系统 ==================== // ==================== 插件系统 ====================
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './core/plugins' export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './core/plugins'
export { createContext, definePlugin, PluginManager } from './core/plugins' export { createContext, definePlugin, PluginManager } from './core/plugins'

View File

@ -9,8 +9,7 @@
*/ */
import { import {
AiClient, createExecutor,
createClient,
ProviderConfigFactory, ProviderConfigFactory,
type ProviderId, type ProviderId,
type ProviderSettingsMap, type ProviderSettingsMap,
@ -101,7 +100,7 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
} }
export default class ModernAiProvider { export default class ModernAiProvider {
private modernClient?: AiClient private modernExecutor?: ReturnType<typeof createExecutor>
private legacyProvider: LegacyAiProvider private legacyProvider: LegacyAiProvider
private provider: Provider private provider: Provider
@ -109,9 +108,10 @@ export default class ModernAiProvider {
this.provider = provider this.provider = provider
this.legacyProvider = new LegacyAiProvider(provider) this.legacyProvider = new LegacyAiProvider(provider)
// TODO:如果后续在调用completions时需要切换provider的话,
// 初始化时不构建中间件,等到需要时再构建 // 初始化时不构建中间件,等到需要时再构建
const config = providerToAiSdkConfig(provider) const config = providerToAiSdkConfig(provider)
this.modernClient = createClient(config.providerId, config.options) this.modernExecutor = createExecutor(config.providerId, config.options)
} }
public async completions( public async completions(
@ -145,7 +145,7 @@ export default class ModernAiProvider {
params: StreamTextParams, params: StreamTextParams,
middlewareConfig: AiSdkMiddlewareConfig middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> { ): Promise<CompletionsResult> {
if (!this.modernClient) { if (!this.modernExecutor) {
throw new Error('Modern AI SDK client not initialized') throw new Error('Modern AI SDK client not initialized')
} }
@ -160,26 +160,25 @@ export default class ModernAiProvider {
// 动态构建中间件数组 // 动态构建中间件数组
const middlewares = buildAiSdkMiddlewares(finalConfig) const middlewares = buildAiSdkMiddlewares(finalConfig)
console.log( console.log('构建的中间件:', middlewares.length)
'构建的中间件:',
middlewares.map((m) => m.name)
)
// 创建带有中间件的客户端
const config = providerToAiSdkConfig(this.provider)
const clientWithMiddlewares = createClient(config.providerId, config.options, middlewares)
// 创建带有中间件的执行器
if (middlewareConfig.onChunk) { if (middlewareConfig.onChunk) {
// 流式处理 - 使用适配器 // 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk) const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
const streamResult = await clientWithMiddlewares.streamText(modelId, { const streamResult = await this.modernExecutor.streamText(
...params, modelId,
experimental_transform: smoothStream({ {
delayInMs: 80, ...params,
// 中文3个字符一个chunk,英文一个单词一个chunk experimental_transform: smoothStream({
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/ delayInMs: 80,
}) // 中文3个字符一个chunk,英文一个单词一个chunk
}) chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
})
},
middlewares.length > 0 ? { middlewares } : undefined
)
const finalText = await adapter.processStream(streamResult) const finalText = await adapter.processStream(streamResult)
return { return {
@ -187,7 +186,11 @@ export default class ModernAiProvider {
} }
} else { } else {
// 流式处理但没有 onChunk 回调 // 流式处理但没有 onChunk 回调
const streamResult = await clientWithMiddlewares.streamText(modelId, params) const streamResult = await this.modernExecutor.streamText(
modelId,
params,
middlewares.length > 0 ? { middlewares } : undefined
)
const finalText = await streamResult.text const finalText = await streamResult.text
return { return {

View File

@ -1,4 +1,8 @@
import { AiPlugin, extractReasoningMiddleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core' import {
extractReasoningMiddleware,
LanguageModelV1Middleware,
simulateStreamingMiddleware
} from '@cherrystudio/ai-core'
import { isReasoningModel } from '@renderer/config/models' import { isReasoningModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types' import type { Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk' import type { Chunk } from '@renderer/types/chunk'
@ -21,7 +25,10 @@ export interface AiSdkMiddlewareConfig {
/** /**
* AI SDK * AI SDK
*/ */
export type NamedAiSdkMiddleware = AiPlugin export interface NamedAiSdkMiddleware {
name: string
middleware: LanguageModelV1Middleware
}
/** /**
* AI SDK * AI SDK
@ -69,7 +76,14 @@ export class AiSdkMiddlewareBuilder {
/** /**
* *
*/ */
public build(): NamedAiSdkMiddleware[] { public build(): LanguageModelV1Middleware[] {
return this.middlewares.map((m) => m.middleware)
}
/**
*
*/
public buildNamed(): NamedAiSdkMiddleware[] {
return [...this.middlewares] return [...this.middlewares]
} }
@ -93,14 +107,14 @@ export class AiSdkMiddlewareBuilder {
* AI SDK中间件的工厂函数 * AI SDK中间件的工厂函数
* *
*/ */
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdkMiddleware[] { export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV1Middleware[] {
const builder = new AiSdkMiddlewareBuilder() const builder = new AiSdkMiddlewareBuilder()
// 1. 思考模型且有onChunk回调时添加思考时间中间件 // 1. 思考模型且有onChunk回调时添加思考时间中间件
if (config.onChunk && config.model && isReasoningModel(config.model)) { if (config.onChunk && config.model && isReasoningModel(config.model)) {
builder.add({ builder.add({
name: 'thinking-time', name: 'thinking-time',
aiSdkMiddlewares: [thinkingTimeMiddleware()] middleware: thinkingTimeMiddleware()
}) })
} }
@ -121,7 +135,7 @@ export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdk
if (config.streamOutput === false) { if (config.streamOutput === false) {
builder.add({ builder.add({
name: 'simulate-streaming', name: 'simulate-streaming',
aiSdkMiddlewares: [simulateStreamingMiddleware()] middleware: simulateStreamingMiddleware()
}) })
} }
@ -142,7 +156,7 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
case 'openai': case 'openai':
builder.add({ builder.add({
name: 'thinking-tag-extraction', name: 'thinking-tag-extraction',
aiSdkMiddlewares: [extractReasoningMiddleware({ tagName: 'think' })] middleware: extractReasoningMiddleware({ tagName: 'think' })
}) })
break break
case 'gemini': case 'gemini':
@ -183,8 +197,12 @@ export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfi
const builder = new AiSdkMiddlewareBuilder() const builder = new AiSdkMiddlewareBuilder()
const defaultMiddlewares = buildAiSdkMiddlewares(config) const defaultMiddlewares = buildAiSdkMiddlewares(config)
defaultMiddlewares.forEach((middleware) => { // 将普通中间件数组转换为具名中间件并添加
builder.add(middleware) defaultMiddlewares.forEach((middleware, index) => {
builder.add({
name: `default-middleware-${index}`,
middleware
})
}) })
return builder return builder