refactor: migrate to v5 patch-1

suyao 2025-07-06 04:25:11 +08:00
parent cd42410d70
commit 3e5969b97c
No known key found for this signature in database
16 changed files with 223 additions and 245 deletions

View File

@ -46,7 +46,8 @@
"@openrouter/ai-sdk-provider": "^0.7.2",
"ai": "5.0.0-beta.7",
"anthropic-vertex-ai": "^1.0.2",
"ollama-ai-provider": "^1.2.0"
"ollama-ai-provider": "^1.2.0",
"zod": "^3.25.0"
},
"peerDependenciesMeta": {
"@ai-sdk/amazon-bedrock": {

View File

@ -19,5 +19,6 @@ export {
} from './models'
// Execution management
export type { MCPRequestContext } from './plugins/built-in/mcpPromptPlugin'
export type { ExecutionOptions, ExecutorConfig } from './runtime'
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'

View File

@ -1,11 +1,11 @@
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import { type LanguageModelV1ProviderMetadata } from '@ai-sdk/provider'
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
import { type OpenRouterProviderOptions } from './openrouter'
export type ProviderOptions<T extends keyof LanguageModelV1ProviderMetadata> = LanguageModelV1ProviderMetadata[T]
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
/**
* Not present in the map
@ -28,4 +28,4 @@ export type TypedProviderOptions = {
[K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
} & {
[K in string]?: Record<string, any>
} & LanguageModelV1ProviderMetadata
} & SharedV2ProviderMetadata
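The hunk above swaps the AI SDK v4 index type LanguageModelV1ProviderMetadata for the v5 SharedV2ProviderMetadata. A minimal sketch of what TypedProviderOptions accepts after the rename; the import path, provider keys and option values are illustrative and not taken from this commit:

import type { TypedProviderOptions } from './types' // hypothetical path to the file shown above

const providerOptions: TypedProviderOptions = {
  // Known providers get the strongly typed entries from ProviderOptionsMap
  anthropic: { thinking: { type: 'enabled', budgetTokens: 1024 } },
  // Unknown providers fall back to the loose Record<string, any> index signature
  'my-custom-provider': { region: 'eu-west-1' }
}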

View File

@ -3,8 +3,7 @@
* Function Call prompt
*
*/
import type { ToolSet } from 'ai'
import { ToolExecutionError } from 'ai'
import type { ModelMessage, TextStreamPart, ToolErrorUnion, ToolSet } from 'ai'
import { definePlugin } from '../index'
import type { AiRequestContext } from '../types'
@ -44,17 +43,17 @@ export interface MCPPromptConfig {
// Whether the plugin is enabled (runtime switch)
enabled?: boolean
// Custom system prompt builder (optional; a default implementation is provided)
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => Promise<string>
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
// Custom tool-use parser (optional; a default implementation is provided)
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
createSystemMessage?: (systemPrompt: string, originalParams: any, context: MCPRequestContext) => string | null
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
}
/**
* AI request context extended with MCP tools
*/
interface MCPRequestContext extends AiRequestContext {
mcpTools?: ToolSet
export interface MCPRequestContext extends AiRequestContext {
mcpTools: ToolSet
}
/**
@ -201,7 +200,7 @@ function buildAvailableTools(tools: ToolSet): string {
<name>${toolName}</name>
<description>${tool.description || ''}</description>
<arguments>
${tool.parameters ? JSON.stringify(tool.parameters) : ''}
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
</arguments>
</tool>
`
@ -215,7 +214,7 @@ ${availableTools}
/**
* Default system prompt builder in Cherry Studio style
*/
async function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): Promise<string> {
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
const availableTools = buildAvailableTools(tools)
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
@ -291,8 +290,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
return definePlugin({
name: 'built-in:mcp-prompt',
transformParams: async (params: any, context: MCPRequestContext) => {
transformParams: (params: any, context: AiRequestContext) => {
if (!enabled || !params.tools || typeof params.tools !== 'object') {
return params
}
@ -303,7 +301,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
// Build the system prompt
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools)
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
let systemMessage: string | null = systemPrompt
console.log('config.context', context)
if (config.createSystemMessage) {
@ -320,25 +318,30 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
console.log('transformedParams', transformedParams)
return transformedParams
},
// Stream handling: watch step-finish events and process tool calls
transformStream: (_, context: MCPRequestContext) => () => {
transformStream: (_: any, context: AiRequestContext) => () => {
let textBuffer = ''
let stepId = ''
let executedResults: { toolCallId: string; toolName: string; result: any; isError?: boolean }[] = []
return new TransformStream<any>({
async transform(chunk, controller) {
if (!context.mcpTools) {
throw new Error('No tools available')
}
type TOOLS = NonNullable<typeof context.mcpTools>
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
async transform(
chunk: TextStreamPart<TOOLS>,
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
) {
// console.log('chunk', chunk)
// Collect text content
if (chunk.type === 'text-delta') {
textBuffer += chunk.textDelta || ''
if (chunk.type === 'text') {
textBuffer += chunk.text || ''
stepId = chunk.id || ''
// console.log('textBuffer', textBuffer)
controller.enqueue(chunk)
return
}
// Watch for the step-finish event
if (chunk.type === 'step-finish' || chunk.type === 'finish') {
if (chunk.type === 'finish-step') {
// console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...')
// Get the tool information from the context
@ -364,17 +367,11 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
// console.log('[MCP Prompt Stream] Found valid tool uses:', validToolUses.length)
// Rewrite the step-finish event, marking it as a tool call
if (chunk.type !== 'finish') {
controller.enqueue({
...chunk,
finishReason: 'tool-call'
})
}
// Emit a step-start event (the tool-call step begins)
controller.enqueue({
type: 'step-start'
type: 'start-step',
request: {},
warnings: []
})
// Execute the tool calls
@ -392,7 +389,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
type: 'tool-call',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
args: toolUse.arguments
input: toolUse.arguments
})
const result = await tool.execute(toolUse.arguments, {
@ -406,8 +403,8 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
type: 'tool-result',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
args: toolUse.arguments,
result
input: toolUse.arguments,
output: result
})
executedResults.push({
@ -420,39 +417,36 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
// Use the AI SDK standard error format
const toolError = new ToolExecutionError({
toolName: toolUse.toolName,
toolArgs: toolUse.arguments,
const toolError: ToolErrorUnion<typeof context.mcpTools> = {
type: 'tool-error',
toolCallId: toolUse.id,
message: `Tool execution failed: ${error instanceof Error ? error.message : String(error)}`,
cause: error instanceof Error ? error : undefined
})
toolName: toolUse.toolName,
input: toolUse.arguments,
error: error instanceof Error ? error.message : String(error)
}
controller.enqueue(toolError)
// Emit a standard error event
controller.enqueue({
type: 'error',
error: {
message: toolError.message,
name: toolError.name,
toolName: toolError.toolName,
toolCallId: toolError.toolCallId
}
error: toolError.error
})
// Emit a tool-result error event
controller.enqueue({
type: 'tool-result',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
args: toolUse.arguments,
isError: true,
result: toolError.message
})
// // Emit a tool-result error event
// controller.enqueue({
// type: 'tool-result',
// toolCallId: toolUse.id,
// toolName: toolUse.toolName,
// args: toolUse.arguments,
// isError: true,
// result: toolError.message
// })
executedResults.push({
toolCallId: toolUse.id,
toolName: toolUse.toolName,
result: toolError.message,
result: toolError.error,
isError: true
})
}
@ -460,8 +454,11 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
// Emit the final step-finish event
controller.enqueue({
type: 'step-finish',
finishReason: 'tool-call'
type: 'finish-step',
finishReason: 'tool-calls',
response: chunk.response,
usage: chunk.usage,
providerMetadata: chunk.providerMetadata
})
// Recursive call logic
@ -481,7 +478,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
.join('\n\n')
// console.log('context.originalParams.messages', context.originalParams.messages)
// Build the new conversation messages
const newMessages = [
const newMessages: ModelMessage[] = [
...(context.originalParams.messages || []),
{
role: 'assistant',
@ -540,8 +537,9 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
// Keep emitting a text delta so the stream stays continuous
controller.enqueue({
type: 'text-delta',
textDelta: '\n\n[工具执行后递归调用失败,继续对话...]'
type: 'text',
id: stepId,
text: '\n\n[工具执行后递归调用失败,继续对话...]'
})
}
}
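For reference, these are the v5-beta stream-part shapes the plugin above now reads and emits, reconstructed from the renames visible in this diff (illustrative values, not an authoritative list of the AI SDK's part types):

const exampleParts = [
  { type: 'text', id: 'step-1', text: 'Hello' }, // v4: 'text-delta' with textDelta
  { type: 'tool-call', toolCallId: 'call-1', toolName: 'read_file', input: { path: 'a.ts' } }, // v4: args
  { type: 'tool-result', toolCallId: 'call-1', toolName: 'read_file', input: { path: 'a.ts' }, output: '...' }, // v4: args / result
  { type: 'finish-step', finishReason: 'tool-calls' } // v4: 'step-finish' with finishReason 'tool-call'
]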

View File

@ -1,12 +1,13 @@
// Core types and interfaces
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types'
import { ProviderId } from '../providers/registry'
import type { AiPlugin, AiRequestContext } from './types'
// Plugin manager
export { PluginManager } from './manager'
// Utility functions
export function createContext(providerId: string, modelId: string, originalParams: any): AiRequestContext {
export function createContext(providerId: ProviderId, modelId: string, originalParams: any): AiRequestContext {
return {
providerId,
modelId,

View File

@ -20,6 +20,7 @@ export interface AiRequestContext {
requestId: string
recursiveCall: RecursiveCallFn
isRecursiveCall?: boolean
mcpTools?: ToolSet
[key: string]: any
}
@ -47,7 +48,7 @@ export interface AiPlugin {
transformStream?: (
params: any,
context: AiRequestContext
) => <TOOLS extends ToolSet>(options: {
) => <TOOLS extends ToolSet>(options?: {
tools: TOOLS
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
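A minimal sketch of a plugin written against the transformStream hook declared above; definePlugin is exported by this package, while the pass-through logic itself is illustrative only:

import { definePlugin } from '@cherrystudio/ai-core' // assumed public entry point

export const passthroughPlugin = definePlugin({
  name: 'example:passthrough',
  transformStream: () => () =>
    new TransformStream({
      transform(chunk, controller) {
        controller.enqueue(chunk) // forward every stream part unchanged
      }
    })
})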

View File

@ -16,7 +16,7 @@ export type {
// === Convenience factory functions ===
import { LanguageModelV1Middleware } from 'ai'
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { type AiPlugin } from '../plugins'
@ -54,7 +54,7 @@ export async function streamText<T extends ProviderId>(
modelId: string,
params: Parameters<RuntimeExecutor<T>['streamText']>[1],
plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.streamText(modelId, params, { middlewares })
@ -69,7 +69,7 @@ export async function generateText<T extends ProviderId>(
modelId: string,
params: Parameters<RuntimeExecutor<T>['generateText']>[1],
plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateText(modelId, params, { middlewares })
@ -84,7 +84,7 @@ export async function generateObject<T extends ProviderId>(
modelId: string,
params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.generateObject(modelId, params, { middlewares })
@ -99,7 +99,7 @@ export async function streamObject<T extends ProviderId>(
modelId: string,
params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
plugins?: AiPlugin[],
middlewares?: LanguageModelV1Middleware[]
middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
const executor = createExecutor(providerId, options, plugins)
return executor.streamObject(modelId, params, { middlewares })
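A hedged usage sketch of the convenience functions above after the switch to LanguageModelV2Middleware; the entry-point import, provider id, settings and model name are placeholders:

import { extractReasoningMiddleware, streamText } from '@cherrystudio/ai-core' // assumed re-exports

async function run() {
  return await streamText(
    'openai', // ProviderId
    { apiKey: 'sk-placeholder' }, // provider settings
    'gpt-4o-mini', // model id (placeholder)
    { prompt: 'Hello' },
    [], // plugins
    [extractReasoningMiddleware({ tagName: 'think' })] // LanguageModelV2Middleware[]
  )
}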

View File

@ -68,8 +68,10 @@ export type {
TextStreamPart,
// Tool-related types
Tool,
ToolCallUnion,
ToolModelMessage,
ToolResultPart,
ToolSet,
UserModelMessage
} from 'ai'
export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware, smoothStream } from 'ai'

View File

@ -3,7 +3,8 @@
* Converts AI SDK fullStream parts into Cherry Studio chunks
*/
import { TextStreamPart } from '@cherrystudio/ai-core'
import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
import { MCPTool, WebSearchSource } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { ToolCallChunkHandler } from './chunk/handleTooCallChunk'
@ -24,8 +25,11 @@ export interface CherryStudioChunk {
*/
export class AiSdkToChunkAdapter {
toolCallHandler: ToolCallChunkHandler
constructor(private onChunk: (chunk: Chunk) => void) {
this.toolCallHandler = new ToolCallChunkHandler(onChunk)
constructor(
private onChunk: (chunk: Chunk) => void,
private mcpTools: MCPTool[] = []
) {
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
}
/**
@ -47,7 +51,7 @@ export class AiSdkToChunkAdapter {
* Reads the fullStream and emits Cherry Studio chunks
* @param fullStream AI SDK fullStream (ReadableStream)
*/
private async readFullStream(fullStream: ReadableStream<TextStreamPart<any>>) {
private async readFullStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>) {
const reader = fullStream.getReader()
const final = {
text: '',
@ -73,84 +77,39 @@ export class AiSdkToChunkAdapter {
* Converts an AI SDK chunk into a Cherry Studio chunk and emits it
* @param chunk AI SDK chunk
*/
private convertAndEmitChunk(chunk: any, final: { text: string; reasoning_content: string }) {
private convertAndEmitChunk(chunk: TextStreamPart<any>, final: { text: string; reasoning_content: string }) {
console.log('AI SDK chunk type:', chunk.type, chunk)
switch (chunk.type) {
// === Text events ===
case 'text-delta':
final.text += chunk.textDelta || ''
case 'text':
final.text += chunk.text || ''
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: chunk.textDelta || ''
text: chunk.text || ''
})
break
case 'text-end':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: final.text || ''
})
break
case 'reasoning':
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.textDelta || '',
// Custom field
thinking_millsec: chunk.thinking_millsec || 0
text: chunk.text || '',
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
})
break
case 'redacted-reasoning':
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.data || ''
})
break
case 'reasoning-signature':
case 'reasoning-end':
this.onChunk({
type: ChunkType.THINKING_COMPLETE,
text: chunk.text || '',
thinking_millsec: chunk.thinking_millsec || 0
text: (chunk.providerMetadata?.metadata?.thinking_content as string) || '',
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
})
break
// === Tool-call events (raw AI SDK events, when not handled by middleware) ===
case 'tool-call-streaming-start':
// Start of a streaming tool call
this.onChunk({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [
{
id: chunk.toolCallId,
name: chunk.toolName,
args: {}
}
]
})
break
case 'tool-call-delta':
// Incremental update of the tool-call arguments
this.onChunk({
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [
{
id: chunk.toolCallId,
tool: {
id: chunk.toolName,
// TODO: serverId,serverName
serverId: 'ai-sdk',
serverName: 'AI SDK',
name: chunk.toolName,
description: '',
inputSchema: {
type: 'object',
title: chunk.toolName,
properties: {}
}
},
arguments: {},
status: 'invoking',
response: chunk.argsTextDelta,
toolCallId: chunk.toolCallId
}
]
})
break
case 'tool-call':
// Raw tool call (not handled by middleware)
this.toolCallHandler.handleToolCall(chunk)
@ -160,6 +119,11 @@ export class AiSdkToChunkAdapter {
// Raw tool-call result (not handled by middleware)
this.toolCallHandler.handleToolResult(chunk)
break
// case 'start':
// this.onChunk({
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
// === Step events ===
// TODO: distinguish the request start from the step start
@ -168,13 +132,17 @@ export class AiSdkToChunkAdapter {
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
case 'step-finish':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: final.text || '' // TEXT_COMPLETE requires a text field
})
final.text = ''
break
// case 'step-finish':
// this.onChunk({
// type: ChunkType.TEXT_COMPLETE,
// text: final.text || '' // TEXT_COMPLETE requires a text field
// })
// final.text = ''
// break
// case 'finish-step': {
// const { totalUsage, finishReason, providerMetadata } = chunk
// }
case 'finish':
this.onChunk({
@ -183,13 +151,13 @@ export class AiSdkToChunkAdapter {
text: final.text || '',
reasoning_content: final.reasoning_content || '',
usage: {
completion_tokens: chunk.usage.completionTokens || 0,
prompt_tokens: chunk.usage.promptTokens || 0,
total_tokens: chunk.usage.totalTokens || 0
completion_tokens: chunk.totalUsage.outputTokens || 0,
prompt_tokens: chunk.totalUsage.inputTokens || 0,
total_tokens: chunk.totalUsage.totalTokens || 0
},
metrics: chunk.usage
metrics: chunk.totalUsage
? {
completion_tokens: chunk.usage.completionTokens || 0,
completion_tokens: chunk.totalUsage.outputTokens || 0,
time_completion_millsec: 0
}
: undefined
@ -201,13 +169,13 @@ export class AiSdkToChunkAdapter {
text: final.text || '',
reasoning_content: final.reasoning_content || '',
usage: {
completion_tokens: chunk.usage.completionTokens || 0,
prompt_tokens: chunk.usage.promptTokens || 0,
total_tokens: chunk.usage.totalTokens || 0
completion_tokens: chunk.totalUsage.outputTokens || 0,
prompt_tokens: chunk.totalUsage.inputTokens || 0,
total_tokens: chunk.totalUsage.totalTokens || 0
},
metrics: chunk.usage
metrics: chunk.totalUsage
? {
completion_tokens: chunk.usage.completionTokens || 0,
completion_tokens: chunk.totalUsage.outputTokens || 0,
time_completion_millsec: 0
}
: undefined
@ -217,30 +185,24 @@ export class AiSdkToChunkAdapter {
// === Source and file events ===
case 'source':
// Source information; mapped to a web search completion
this.onChunk({
type: ChunkType.KNOWLEDGE_SEARCH_COMPLETE,
knowledge: [
{
id: Number(chunk.source.id) || Date.now(),
content: chunk.source.title || '',
sourceUrl: chunk.source.url || '',
type: 'url'
}
]
})
break
case 'file':
// File event, possibly image generation
this.onChunk({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [chunk.base64]
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
source: WebSearchSource.AISDK,
results: [{}]
}
})
break
// case 'file':
// // File event, possibly image generation
// this.onChunk({
// type: ChunkType.IMAGE_COMPLETE,
// image: {
// type: 'base64',
// images: [chunk.base64]
// }
// })
// break
case 'error':
this.onChunk({
type: ChunkType.ERROR,
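A small sketch of the usage mapping applied in the finish cases above: AI SDK v5 reports totalUsage.inputTokens / outputTokens / totalTokens where v4 exposed usage.promptTokens / completionTokens / totalTokens. The helper below is illustrative, not part of this commit:

interface V5TotalUsage {
  inputTokens?: number
  outputTokens?: number
  totalTokens?: number
}

// Fold the v5 usage object into the OpenAI-style usage block Cherry Studio expects
function toLegacyUsage(totalUsage: V5TotalUsage) {
  return {
    prompt_tokens: totalUsage.inputTokens || 0,
    completion_tokens: totalUsage.outputTokens || 0,
    total_tokens: totalUsage.totalTokens || 0
  }
}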

View File

@ -4,8 +4,9 @@
* For API use
*/
import { ToolCallUnion, ToolSet } from '@cherrystudio/ai-core/index'
import Logger from '@renderer/config/logger'
import { MCPToolResponse } from '@renderer/types'
import { MCPTool, MCPToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
/**
@ -19,10 +20,13 @@ export class ToolCallChunkHandler {
toolCallId: string
toolName: string
args: any
// mcpTool: MCPTool
mcpTool: MCPTool
}
>()
constructor(private onChunk: (chunk: Chunk) => void) {}
constructor(
private onChunk: (chunk: Chunk) => void,
private mcpTools: MCPTool[]
) {}
// /**
// * Set the onChunk callback
@ -34,10 +38,14 @@ export class ToolCallChunkHandler {
/**
*
*/
public handleToolCall(chunk: any): void {
public handleToolCall(
chunk: {
type: 'tool-call'
} & ToolCallUnion<ToolSet>
): void {
const toolCallId = chunk.toolCallId
const toolName = chunk.toolName
const args = chunk.args || {}
const args = chunk.input || {}
if (!toolCallId || !toolName) {
Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
@ -51,17 +59,14 @@ export class ToolCallChunkHandler {
this.activeToolCalls.set(toolCallId, {
toolCallId,
toolName,
args
// mcpTool
args,
mcpTool: this.mcpTools.find((tool) => tool.name === toolName)!
})
// Build the MCPToolResponse
const toolResponse: MCPToolResponse = {
id: toolCallId,
tool: {
id: toolCallId,
name: toolName
},
tool: this.activeToolCalls.get(toolCallId)!.mcpTool,
arguments: args,
status: 'invoking',
toolCallId: toolCallId
@ -98,10 +103,7 @@ export class ToolCallChunkHandler {
// Build the MCPToolResponse for the tool-call result
const toolResponse: MCPToolResponse = {
id: toolCallId,
tool: {
id: toolCallId,
name: toolCallInfo.toolName
},
tool: toolCallInfo.mcpTool,
arguments: toolCallInfo.args,
status: 'done',
response: {

View File

@ -195,7 +195,7 @@ export default class ModernAiProvider {
// Create an executor with middlewares
if (middlewareConfig.onChunk) {
// Streaming path: use the adapter
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk, middlewareConfig.mcpTools)
console.log('最终params', params)
const streamResult = await executor.streamText(
modelId,

View File

@ -1,8 +1,4 @@
import {
extractReasoningMiddleware,
LanguageModelV1Middleware,
simulateStreamingMiddleware
} from '@cherrystudio/ai-core'
import { LanguageModelV2Middleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
import type { MCPTool, Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
@ -26,7 +22,7 @@ export interface AiSdkMiddlewareConfig {
*/
export interface NamedAiSdkMiddleware {
name: string
middleware: LanguageModelV1Middleware
middleware: LanguageModelV2Middleware
}
/**
@ -75,7 +71,7 @@ export class AiSdkMiddlewareBuilder {
/**
*
*/
public build(): LanguageModelV1Middleware[] {
public build(): LanguageModelV2Middleware[] {
return this.middlewares.map((m) => m.middleware)
}
@ -106,7 +102,7 @@ export class AiSdkMiddlewareBuilder {
* Factory function that builds the AI SDK middlewares
*
*/
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV1Middleware[] {
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV2Middleware[] {
const builder = new AiSdkMiddlewareBuilder()
// 1. Add provider-specific middlewares
@ -143,10 +139,10 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
// Anthropic-specific middlewares
break
case 'openai':
builder.add({
name: 'thinking-tag-extraction',
middleware: extractReasoningMiddleware({ tagName: 'think' })
})
// builder.add({
// name: 'thinking-tag-extraction',
// middleware: extractReasoningMiddleware({ tagName: 'think' })
// })
break
case 'gemini':
// Gemini-specific middlewares

View File

@ -1,4 +1,4 @@
import { definePlugin } from '@cherrystudio/ai-core'
import { definePlugin, TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
export default definePlugin({
name: 'reasoningTimePlugin',
@ -8,57 +8,62 @@ export default definePlugin({
let thinkingStartTime = 0
let hasStartedThinking = false
let accumulatedThinkingContent = ''
let reasoningBlockId = ''
return new TransformStream({
transform(chunk, controller) {
if (chunk.type !== 'reasoning') {
// === Handle the end of reasoning ===
if (hasStartedThinking) {
console.log(`[ReasoningPlugin] Ending reasoning.`)
// Emit a reasoning-signature
controller.enqueue({
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: performance.now() - thinkingStartTime
})
// Reset state
accumulatedThinkingContent = ''
hasStartedThinking = false
thinkingStartTime = 0
}
controller.enqueue(chunk)
return
}
return new TransformStream<TextStreamPart<ToolSet>, TextStreamPart<ToolSet>>({
transform(chunk: TextStreamPart<ToolSet>, controller: TransformStreamDefaultController<TextStreamPart<ToolSet>>) {
// === Handle reasoning chunks ===
if (chunk.type === 'reasoning') {
if (!hasStartedThinking) {
hasStartedThinking = true
thinkingStartTime = performance.now()
reasoningBlockId = chunk.id
}
accumulatedThinkingContent += chunk.text
// 1. Time-tracking logic
if (!hasStartedThinking) {
hasStartedThinking = true
thinkingStartTime = performance.now()
console.log(`[ReasoningPlugin] Starting reasoning session`)
controller.enqueue({
...chunk,
providerMetadata: {
...chunk.providerMetadata,
metadata: {
...chunk.providerMetadata?.metadata,
thinking_millsec: performance.now() - thinkingStartTime,
thinking_content: accumulatedThinkingContent
}
}
})
} else if (hasStartedThinking) {
controller.enqueue({
type: 'reasoning-end',
id: reasoningBlockId,
providerMetadata: {
metadata: {
thinking_millsec: performance.now() - thinkingStartTime,
thinking_content: accumulatedThinkingContent
}
}
})
accumulatedThinkingContent = ''
hasStartedThinking = false
thinkingStartTime = 0
reasoningBlockId = ''
controller.enqueue(chunk)
} else {
controller.enqueue(chunk)
}
accumulatedThinkingContent += chunk.textDelta
// 2. Pass the chunk straight through, attaching the elapsed time
console.log(`[ReasoningPlugin] Forwarding reasoning chunk: "${chunk.textDelta}"`)
controller.enqueue({
...chunk,
thinking_millsec: performance.now() - thinkingStartTime
})
},
// === flush: handle streams that end while still in the reasoning state ===
flush(controller) {
if (hasStartedThinking) {
console.log(`[ReasoningPlugin] Final flush for reasoning-signature.`)
controller.enqueue({
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: performance.now() - thinkingStartTime
type: 'reasoning-end',
id: reasoningBlockId,
providerMetadata: {
metadata: {
thinking_millsec: performance.now() - thinkingStartTime,
thinking_content: accumulatedThinkingContent
}
}
})
}
}
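Illustrative shape of the part this plugin now emits when reasoning ends, and which AiSdkToChunkAdapter reads back in its 'reasoning' / 'reasoning-end' cases; the metadata keys are this repo's own convention rather than AI SDK fields:

const reasoningEndExample = {
  type: 'reasoning-end',
  id: 'reasoning-1', // id of the reasoning block being closed
  providerMetadata: {
    metadata: {
      thinking_millsec: 1280, // elapsed reasoning time in milliseconds
      thinking_content: 'accumulated reasoning text' // full buffered reasoning
    }
  }
}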

View File

@ -48,7 +48,7 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
console.log('mcpTool', mcpTool.inputSchema)
tools[mcpTool.name] = aiSdk.tool<any, ToolCallResult>({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
parameters: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params): Promise<ToolCallResult> => {
console.log('execute_params', params)
// Create an adapted MCPToolResponse object
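For comparison, a minimal standalone v5 tool definition showing the parameters-to-inputSchema rename targeted above; the schema and execute body are placeholders:

import { jsonSchema, tool } from 'ai'

const readFile = tool({
  description: 'Read a file from an MCP server',
  inputSchema: jsonSchema<{ path: string }>({
    type: 'object',
    properties: { path: { type: 'string' } },
    required: ['path']
  }),
  execute: async ({ path }) => `contents of ${path}` // placeholder result
})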

View File

@ -545,7 +545,8 @@ export enum WebSearchSource {
QWEN = 'qwen',
HUNYUAN = 'hunyuan',
ZHIPU = 'zhipu',
GROK = 'grok'
GROK = 'grok',
AISDK = 'ai-sdk'
}
export type WebSearchResponse = {

View File

@ -956,6 +956,7 @@ __metadata:
ollama-ai-provider: "npm:^1.2.0"
tsdown: "npm:^0.12.9"
typescript: "npm:^5.0.0"
zod: "npm:^3.25.0"
peerDependenciesMeta:
"@ai-sdk/amazon-bedrock":
optional: true
@ -20140,6 +20141,13 @@ __metadata:
languageName: node
linkType: hard
"zod@npm:^3.25.0":
version: 3.25.74
resolution: "zod@npm:3.25.74"
checksum: 10c0/59e38b046ac333b5bd1ba325a83b6798721227cbfb1e69dfc7159bd7824b904241ab923026edb714fafefec3624265ae374a70aee9a5a45b365bd31781ffa105
languageName: node
linkType: hard
"zustand@npm:^4.4.0":
version: 4.5.6
resolution: "zustand@npm:4.5.6"