feat: enhance AI Core runtime with advanced model handling and middleware support

- Introduced new high-level APIs for model creation and configuration, improving usability for advanced users.
- Enhanced the RuntimeExecutor to support both direct model usage and model ID resolution, allowing for more flexible execution options.
- Updated existing methods to accept middleware configurations, streamlining the integration of custom processing logic.
- Refactored the plugin system to better accommodate middleware, enhancing the overall extensibility of the AI Core.
- Improved documentation to reflect the new capabilities and usage patterns for the runtime APIs.
This commit is contained in:
MyPrototypeWhat 2025-06-25 17:25:45 +08:00
parent 7d8ed3a737
commit e4c0ea035f
6 changed files with 203 additions and 109 deletions

View File

@ -2,7 +2,7 @@
*
* AI调用处理
*/
import { generateObject, generateText, LanguageModelV1, streamObject, streamText } from 'ai'
import { generateObject, generateText, LanguageModel, LanguageModelV1Middleware, streamObject, streamText } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { createModel, getProviderInfo } from '../models'
@ -28,24 +28,43 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginClient = new PluginEngine(config.providerId, config.plugins || [])
}
// === 高阶重载:直接使用模型 ===
/**
* - 使modelId自动创建模型
* - 使
*/
async streamText(
model: LanguageModel,
params: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>>
/**
* - 使modelId + middleware
*/
async streamText(
modelId: string,
params: Omit<Parameters<typeof streamText>[0], 'model'>
params: Omit<Parameters<typeof streamText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof streamText>>
/**
* -
*/
async streamText(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof streamText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof streamText>> {
// 1. 使用 createModel 创建模型
const model = await createModel({
providerId: this.config.providerId,
modelId,
options: this.config.options
})
const model = await this.resolveModel(modelOrId, options?.middlewares)
// 2. 执行插件处理
return this.pluginClient.executeStreamWithPlugins(
'streamText',
modelId,
typeof modelOrId === 'string' ? modelOrId : model.modelId,
params,
async (finalModelId, transformedParams, streamTransforms) => {
const experimental_transform =
@ -60,46 +79,42 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
)
}
// === 其他方法的重载 ===
/**
* - 使
* - 使
*/
async streamTextWithModel(
model: LanguageModelV1,
params: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>> {
return this.pluginClient.executeStreamWithPlugins(
'streamText',
model.modelId,
params,
async (finalModelId, transformedParams, streamTransforms) => {
const experimental_transform =
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
return await streamText({
model,
...transformedParams,
experimental_transform
})
}
)
}
async generateText(
model: LanguageModel,
params: Omit<Parameters<typeof generateText>[0], 'model'>
): Promise<ReturnType<typeof generateText>>
/**
*
* - 使modelId + middleware
*/
async generateText(
modelId: string,
params: Omit<Parameters<typeof generateText>[0], 'model'>
params: Omit<Parameters<typeof generateText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof generateText>>
/**
* -
*/
async generateText(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof generateText>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof generateText>> {
const model = await createModel({
providerId: this.config.providerId,
modelId,
options: this.config.options
})
const model = await this.resolveModel(modelOrId, options?.middlewares)
return this.pluginClient.executeWithPlugins(
'generateText',
modelId,
typeof modelOrId === 'string' ? modelOrId : model.modelId,
params,
async (finalModelId, transformedParams) => {
return await generateText({ model, ...transformedParams })
@ -108,21 +123,39 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
}
/**
*
* - 使
*/
async generateObject(
modelId: string,
model: LanguageModel,
params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>>
/**
* - 使modelId + middleware
*/
async generateObject(
modelOrId: string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof generateObject>>
/**
* -
*/
async generateObject(
modelOrId: LanguageModel | string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
options?: {
middlewares?: LanguageModelV1Middleware[]
}
): Promise<ReturnType<typeof generateObject>> {
const model = await createModel({
providerId: this.config.providerId,
modelId,
options: this.config.options
})
const model = await this.resolveModel(modelOrId, options?.middlewares)
return this.pluginClient.executeWithPlugins(
'generateObject',
modelId,
typeof modelOrId === 'string' ? modelOrId : model.modelId,
params,
async (finalModelId, transformedParams) => {
return await generateObject({ model, ...transformedParams })
@ -131,27 +164,69 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
}
/**
 * streamObject — stream a structured object from a model.
 *
 * Overload 1: pass a ready-made LanguageModel directly.
 */
async streamObject(
  model: LanguageModel,
  params: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>>
/**
 * Overload 2: pass a modelId string plus optional middlewares; the model
 * is created internally via resolveModel/createModel.
 */
async streamObject(
  modelId: string,
  params: Omit<Parameters<typeof streamObject>[0], 'model'>,
  options?: {
    middlewares?: LanguageModelV1Middleware[]
  }
): Promise<ReturnType<typeof streamObject>>
/**
 * Implementation signature — accepts either a model instance or a modelId.
 */
async streamObject(
  modelOrId: LanguageModel | string,
  params: Omit<Parameters<typeof streamObject>[0], 'model'>,
  options?: {
    middlewares?: LanguageModelV1Middleware[]
  }
): Promise<ReturnType<typeof streamObject>> {
  // Resolve a usable model: strings are materialized via createModel
  // (with any middlewares applied); model instances pass through as-is.
  const model = await this.resolveModel(modelOrId, options?.middlewares)
  return this.pluginClient.executeWithPlugins(
    'streamObject',
    // Plugins receive the original id when one was given, otherwise the
    // resolved model's own id.
    typeof modelOrId === 'string' ? modelOrId : model.modelId,
    params,
    async (finalModelId, transformedParams) => {
      return await streamObject({ model, ...transformedParams })
    }
  )
}
// === 辅助方法 ===
/**
*
*/
/**
 * Resolve either a model instance or a model id into a usable LanguageModel.
 * String ids are materialized through createModel with the executor's
 * provider configuration and any supplied middlewares; model instances
 * are returned unchanged.
 */
private async resolveModel(
  modelOrId: LanguageModel | string,
  middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModel> {
  // Already a model instance — nothing to do.
  if (typeof modelOrId !== 'string') {
    return modelOrId
  }
  // A bare model id: build the model via the core factory.
  return createModel({
    providerId: this.config.providerId,
    modelId: modelOrId,
    options: this.config.options,
    middlewares
  })
}
/**
*
*/

View File

@ -16,6 +16,8 @@ export type {
// === 便捷工厂函数 ===
import { LanguageModelV1Middleware } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { type AiPlugin } from '../plugins'
import { RuntimeExecutor } from './executor'
@ -44,59 +46,63 @@ export function createOpenAICompatibleExecutor(
// === 直接调用API无需创建executor实例===
/**
*
* - middlewares
*/
export async function streamText<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
modelId: string,
params: Parameters<RuntimeExecutor<T>['streamText']>[1],
plugins?: AiPlugin[]
plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.streamText(modelId, params)
return executor.streamText(modelId, params, { middlewares })
}
/**
*
* - middlewares
*/
export async function generateText<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
modelId: string,
params: Parameters<RuntimeExecutor<T>['generateText']>[1],
plugins?: AiPlugin[]
plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateText(modelId, params)
return executor.generateText(modelId, params, { middlewares })
}
/**
*
* - middlewares
*/
export async function generateObject<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
modelId: string,
params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
plugins?: AiPlugin[]
plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateObject(modelId, params)
return executor.generateObject(modelId, params, { middlewares })
}
/**
*
* - middlewares
*/
export async function streamObject<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
modelId: string,
params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
plugins?: AiPlugin[]
plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.streamObject(modelId, params)
return executor.streamObject(modelId, params, { middlewares })
}
// === Agent 功能预留 ===

View File

@ -55,17 +55,6 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
return this.pluginManager.getPlugins()
}
// /**
// * 使用core模块创建模型包含中间件
// */
// async createModelWithMiddlewares(modelId: string): Promise<any> {
// // 使用core模块的resolveConfig解析配置
// const config = resolveConfig(this.providerId, modelId, this.options, this.pluginManager.getPlugins())
// // 使用core模块创建包装好的模型
// return createModelFromConfig(config)
// }
/**
*
* AiExecutor使用

View File

@ -15,6 +15,9 @@ import { ProviderId, type ProviderSettingsMap } from './types'
// ==================== 主要用户接口 ====================
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
// ==================== 高级API ====================
export { createModel, type ModelConfig } from './core/models'
// ==================== 插件系统 ====================
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './core/plugins'
export { createContext, definePlugin, PluginManager } from './core/plugins'

View File

@ -9,8 +9,7 @@
*/
import {
AiClient,
createClient,
createExecutor,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap,
@ -101,7 +100,7 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
}
export default class ModernAiProvider {
private modernClient?: AiClient
private modernExecutor?: ReturnType<typeof createExecutor>
private legacyProvider: LegacyAiProvider
private provider: Provider
@ -109,9 +108,10 @@ export default class ModernAiProvider {
this.provider = provider
this.legacyProvider = new LegacyAiProvider(provider)
// TODO:如果后续在调用completions时需要切换provider的话,
// 初始化时不构建中间件,等到需要时再构建
const config = providerToAiSdkConfig(provider)
this.modernClient = createClient(config.providerId, config.options)
this.modernExecutor = createExecutor(config.providerId, config.options)
}
public async completions(
@ -145,7 +145,7 @@ export default class ModernAiProvider {
params: StreamTextParams,
middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
if (!this.modernClient) {
if (!this.modernExecutor) {
throw new Error('Modern AI SDK client not initialized')
}
@ -160,26 +160,25 @@ export default class ModernAiProvider {
// 动态构建中间件数组
const middlewares = buildAiSdkMiddlewares(finalConfig)
console.log(
'构建的中间件:',
middlewares.map((m) => m.name)
)
// 创建带有中间件的客户端
const config = providerToAiSdkConfig(this.provider)
const clientWithMiddlewares = createClient(config.providerId, config.options, middlewares)
console.log('构建的中间件:', middlewares.length)
// 创建带有中间件的执行器
if (middlewareConfig.onChunk) {
// 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
const streamResult = await clientWithMiddlewares.streamText(modelId, {
...params,
experimental_transform: smoothStream({
delayInMs: 80,
// 中文3个字符一个chunk,英文一个单词一个chunk
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
})
})
const streamResult = await this.modernExecutor.streamText(
modelId,
{
...params,
experimental_transform: smoothStream({
delayInMs: 80,
// 中文3个字符一个chunk,英文一个单词一个chunk
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
})
},
middlewares.length > 0 ? { middlewares } : undefined
)
const finalText = await adapter.processStream(streamResult)
return {
@ -187,7 +186,11 @@ export default class ModernAiProvider {
}
} else {
// 流式处理但没有 onChunk 回调
const streamResult = await clientWithMiddlewares.streamText(modelId, params)
const streamResult = await this.modernExecutor.streamText(
modelId,
params,
middlewares.length > 0 ? { middlewares } : undefined
)
const finalText = await streamResult.text
return {

View File

@ -1,4 +1,8 @@
import { AiPlugin, extractReasoningMiddleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
import {
extractReasoningMiddleware,
LanguageModelV1Middleware,
simulateStreamingMiddleware
} from '@cherrystudio/ai-core'
import { isReasoningModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
@ -21,7 +25,10 @@ export interface AiSdkMiddlewareConfig {
/**
 * A named AI SDK middleware entry: pairs a human-readable name (used for
 * logging/debugging) with the underlying LanguageModelV1Middleware.
 */
export interface NamedAiSdkMiddleware {
  name: string
  middleware: LanguageModelV1Middleware
}
/**
* AI SDK
@ -69,7 +76,14 @@ export class AiSdkMiddlewareBuilder {
/**
 * Build the final middleware list: returns only the middleware functions,
 * in registration order, dropping the names.
 */
public build(): LanguageModelV1Middleware[] {
  return this.middlewares.map((m) => m.middleware)
}
/**
 * Build a snapshot of the named middleware entries. Returns a shallow
 * copy so callers cannot mutate the builder's internal list.
 */
public buildNamed(): NamedAiSdkMiddleware[] {
  return [...this.middlewares]
}
@ -93,14 +107,14 @@ export class AiSdkMiddlewareBuilder {
* AI SDK中间件的工厂函数
*
*/
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdkMiddleware[] {
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV1Middleware[] {
const builder = new AiSdkMiddlewareBuilder()
// 1. 思考模型且有onChunk回调时添加思考时间中间件
if (config.onChunk && config.model && isReasoningModel(config.model)) {
builder.add({
name: 'thinking-time',
aiSdkMiddlewares: [thinkingTimeMiddleware()]
middleware: thinkingTimeMiddleware()
})
}
@ -121,7 +135,7 @@ export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdk
if (config.streamOutput === false) {
builder.add({
name: 'simulate-streaming',
aiSdkMiddlewares: [simulateStreamingMiddleware()]
middleware: simulateStreamingMiddleware()
})
}
@ -142,7 +156,7 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
case 'openai':
builder.add({
name: 'thinking-tag-extraction',
aiSdkMiddlewares: [extractReasoningMiddleware({ tagName: 'think' })]
middleware: extractReasoningMiddleware({ tagName: 'think' })
})
break
case 'gemini':
@ -183,8 +197,12 @@ export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfi
const builder = new AiSdkMiddlewareBuilder()
const defaultMiddlewares = buildAiSdkMiddlewares(config)
defaultMiddlewares.forEach((middleware) => {
builder.add(middleware)
// 将普通中间件数组转换为具名中间件并添加
defaultMiddlewares.forEach((middleware, index) => {
builder.add({
name: `default-middleware-${index}`,
middleware
})
})
return builder