mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
refactor: migrate to v5 patch-1
This commit is contained in:
parent
cd42410d70
commit
3e5969b97c
@ -46,7 +46,8 @@
|
||||
"@openrouter/ai-sdk-provider": "^0.7.2",
|
||||
"ai": "5.0.0-beta.7",
|
||||
"anthropic-vertex-ai": "^1.0.2",
|
||||
"ollama-ai-provider": "^1.2.0"
|
||||
"ollama-ai-provider": "^1.2.0",
|
||||
"zod": "^3.25.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@ai-sdk/amazon-bedrock": {
|
||||
|
||||
@ -19,5 +19,6 @@ export {
|
||||
} from './models'
|
||||
|
||||
// 执行管理
|
||||
export type { MCPRequestContext } from './plugins/built-in/mcpPromptPlugin'
|
||||
export type { ExecutionOptions, ExecutorConfig } from './runtime'
|
||||
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||
import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||
import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||
import { type LanguageModelV1ProviderMetadata } from '@ai-sdk/provider'
|
||||
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
|
||||
|
||||
import { type OpenRouterProviderOptions } from './openrouter'
|
||||
|
||||
export type ProviderOptions<T extends keyof LanguageModelV1ProviderMetadata> = LanguageModelV1ProviderMetadata[T]
|
||||
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
|
||||
|
||||
/**
|
||||
* 供应商选项类型,如果map中没有,说明没有约束
|
||||
@ -28,4 +28,4 @@ export type TypedProviderOptions = {
|
||||
[K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
|
||||
} & {
|
||||
[K in string]?: Record<string, any>
|
||||
} & LanguageModelV1ProviderMetadata
|
||||
} & SharedV2ProviderMetadata
|
||||
|
||||
@ -3,8 +3,7 @@
|
||||
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||
* 内置默认逻辑,支持自定义覆盖
|
||||
*/
|
||||
import type { ToolSet } from 'ai'
|
||||
import { ToolExecutionError } from 'ai'
|
||||
import type { ModelMessage, TextStreamPart, ToolErrorUnion, ToolSet } from 'ai'
|
||||
|
||||
import { definePlugin } from '../index'
|
||||
import type { AiRequestContext } from '../types'
|
||||
@ -44,17 +43,17 @@ export interface MCPPromptConfig {
|
||||
// 是否启用(用于运行时开关)
|
||||
enabled?: boolean
|
||||
// 自定义系统提示符构建函数(可选,有默认实现)
|
||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => Promise<string>
|
||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
|
||||
// 自定义工具解析函数(可选,有默认实现)
|
||||
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
||||
createSystemMessage?: (systemPrompt: string, originalParams: any, context: MCPRequestContext) => string | null
|
||||
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
|
||||
}
|
||||
|
||||
/**
|
||||
* 扩展的 AI 请求上下文,支持 MCP 工具存储
|
||||
*/
|
||||
interface MCPRequestContext extends AiRequestContext {
|
||||
mcpTools?: ToolSet
|
||||
export interface MCPRequestContext extends AiRequestContext {
|
||||
mcpTools: ToolSet
|
||||
}
|
||||
|
||||
/**
|
||||
@ -201,7 +200,7 @@ function buildAvailableTools(tools: ToolSet): string {
|
||||
<name>${toolName}</name>
|
||||
<description>${tool.description || ''}</description>
|
||||
<arguments>
|
||||
${tool.parameters ? JSON.stringify(tool.parameters) : ''}
|
||||
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
|
||||
</arguments>
|
||||
</tool>
|
||||
`
|
||||
@ -215,7 +214,7 @@ ${availableTools}
|
||||
/**
|
||||
* 默认的系统提示符构建函数(提取自 Cherry Studio)
|
||||
*/
|
||||
async function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): Promise<string> {
|
||||
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
|
||||
const availableTools = buildAvailableTools(tools)
|
||||
|
||||
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
|
||||
@ -291,8 +290,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
|
||||
return definePlugin({
|
||||
name: 'built-in:mcp-prompt',
|
||||
|
||||
transformParams: async (params: any, context: MCPRequestContext) => {
|
||||
transformParams: (params: any, context: AiRequestContext) => {
|
||||
if (!enabled || !params.tools || typeof params.tools !== 'object') {
|
||||
return params
|
||||
}
|
||||
@ -303,7 +301,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
|
||||
// 构建系统提示符
|
||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||
const systemPrompt = await buildSystemPrompt(userSystemPrompt, params.tools)
|
||||
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
|
||||
let systemMessage: string | null = systemPrompt
|
||||
console.log('config.context', context)
|
||||
if (config.createSystemMessage) {
|
||||
@ -320,25 +318,30 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
console.log('transformedParams', transformedParams)
|
||||
return transformedParams
|
||||
},
|
||||
|
||||
// 流式处理:监听 step-finish 事件并处理工具调用
|
||||
transformStream: (_, context: MCPRequestContext) => () => {
|
||||
transformStream: (_: any, context: AiRequestContext) => () => {
|
||||
let textBuffer = ''
|
||||
let stepId = ''
|
||||
let executedResults: { toolCallId: string; toolName: string; result: any; isError?: boolean }[] = []
|
||||
|
||||
return new TransformStream<any>({
|
||||
async transform(chunk, controller) {
|
||||
if (!context.mcpTools) {
|
||||
throw new Error('No tools available')
|
||||
}
|
||||
type TOOLS = NonNullable<typeof context.mcpTools>
|
||||
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
|
||||
async transform(
|
||||
chunk: TextStreamPart<TOOLS>,
|
||||
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
|
||||
) {
|
||||
// console.log('chunk', chunk)
|
||||
// 收集文本内容
|
||||
if (chunk.type === 'text-delta') {
|
||||
textBuffer += chunk.textDelta || ''
|
||||
if (chunk.type === 'text') {
|
||||
textBuffer += chunk.text || ''
|
||||
stepId = chunk.id || ''
|
||||
// console.log('textBuffer', textBuffer)
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
// 监听 step-finish 事件
|
||||
if (chunk.type === 'step-finish' || chunk.type === 'finish') {
|
||||
if (chunk.type === 'finish-step') {
|
||||
// console.log('[MCP Prompt Stream] Received step-finish, checking for tool use...')
|
||||
|
||||
// 从 context 获取工具信息
|
||||
@ -364,17 +367,11 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
|
||||
// console.log('[MCP Prompt Stream] Found valid tool uses:', validToolUses.length)
|
||||
|
||||
// 修改 step-finish 事件,标记为工具调用
|
||||
if (chunk.type !== 'finish') {
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
finishReason: 'tool-call'
|
||||
})
|
||||
}
|
||||
|
||||
// 发送 step-start 事件(工具调用步骤开始)
|
||||
controller.enqueue({
|
||||
type: 'step-start'
|
||||
type: 'start-step',
|
||||
request: {},
|
||||
warnings: []
|
||||
})
|
||||
|
||||
// 执行工具调用
|
||||
@ -392,7 +389,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
type: 'tool-call',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
args: toolUse.arguments
|
||||
input: tool.inputSchema
|
||||
})
|
||||
|
||||
const result = await tool.execute(toolUse.arguments, {
|
||||
@ -406,8 +403,8 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
type: 'tool-result',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
args: toolUse.arguments,
|
||||
result
|
||||
input: toolUse.arguments,
|
||||
output: result
|
||||
})
|
||||
|
||||
executedResults.push({
|
||||
@ -420,39 +417,36 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
|
||||
|
||||
// 使用 AI SDK 标准错误格式
|
||||
const toolError = new ToolExecutionError({
|
||||
toolName: toolUse.toolName,
|
||||
toolArgs: toolUse.arguments,
|
||||
const toolError: ToolErrorUnion<typeof context.mcpTools> = {
|
||||
type: 'tool-error',
|
||||
toolCallId: toolUse.id,
|
||||
message: `Tool execution failed: ${error instanceof Error ? error.message : String(error)}`,
|
||||
cause: error instanceof Error ? error : undefined
|
||||
})
|
||||
toolName: toolUse.toolName,
|
||||
input: toolUse.arguments,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
}
|
||||
|
||||
controller.enqueue(toolError)
|
||||
|
||||
// 发送标准错误事件
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: toolError.message,
|
||||
name: toolError.name,
|
||||
toolName: toolError.toolName,
|
||||
toolCallId: toolError.toolCallId
|
||||
}
|
||||
error: toolError.error
|
||||
})
|
||||
|
||||
// 发送 tool-result 错误事件
|
||||
controller.enqueue({
|
||||
type: 'tool-result',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
args: toolUse.arguments,
|
||||
isError: true,
|
||||
result: toolError.message
|
||||
})
|
||||
// // 发送 tool-result 错误事件
|
||||
// controller.enqueue({
|
||||
// type: 'tool-result',
|
||||
// toolCallId: toolUse.id,
|
||||
// toolName: toolUse.toolName,
|
||||
// args: toolUse.arguments,
|
||||
// isError: true,
|
||||
// result: toolError.message
|
||||
// })
|
||||
|
||||
executedResults.push({
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result: toolError.message,
|
||||
result: toolError.error,
|
||||
isError: true
|
||||
})
|
||||
}
|
||||
@ -460,8 +454,11 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
|
||||
// 发送最终的 step-finish 事件
|
||||
controller.enqueue({
|
||||
type: 'step-finish',
|
||||
finishReason: 'tool-call'
|
||||
type: 'finish-step',
|
||||
finishReason: 'tool-calls',
|
||||
response: chunk.response,
|
||||
usage: chunk.usage,
|
||||
providerMetadata: chunk.providerMetadata
|
||||
})
|
||||
|
||||
// 递归调用逻辑
|
||||
@ -481,7 +478,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
.join('\n\n')
|
||||
// console.log('context.originalParams.messages', context.originalParams.messages)
|
||||
// 构建新的对话消息
|
||||
const newMessages = [
|
||||
const newMessages: ModelMessage[] = [
|
||||
...(context.originalParams.messages || []),
|
||||
{
|
||||
role: 'assistant',
|
||||
@ -540,8 +537,9 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
|
||||
// 继续发送文本增量,保持流的连续性
|
||||
controller.enqueue({
|
||||
type: 'text-delta',
|
||||
textDelta: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
type: 'text',
|
||||
id: stepId,
|
||||
text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
// 核心类型和接口
|
||||
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types'
|
||||
import { ProviderId } from '../providers/registry'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
// 插件管理器
|
||||
export { PluginManager } from './manager'
|
||||
|
||||
// 工具函数
|
||||
export function createContext(providerId: string, modelId: string, originalParams: any): AiRequestContext {
|
||||
export function createContext(providerId: ProviderId, modelId: string, originalParams: any): AiRequestContext {
|
||||
return {
|
||||
providerId,
|
||||
modelId,
|
||||
|
||||
@ -20,6 +20,7 @@ export interface AiRequestContext {
|
||||
requestId: string
|
||||
recursiveCall: RecursiveCallFn
|
||||
isRecursiveCall?: boolean
|
||||
mcpTools?: ToolSet
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
@ -47,7 +48,7 @@ export interface AiPlugin {
|
||||
transformStream?: (
|
||||
params: any,
|
||||
context: AiRequestContext
|
||||
) => <TOOLS extends ToolSet>(options: {
|
||||
) => <TOOLS extends ToolSet>(options?: {
|
||||
tools: TOOLS
|
||||
stopStream: () => void
|
||||
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
|
||||
|
||||
@ -16,7 +16,7 @@ export type {
|
||||
|
||||
// === 便捷工厂函数 ===
|
||||
|
||||
import { LanguageModelV1Middleware } from 'ai'
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||
import { type AiPlugin } from '../plugins'
|
||||
@ -54,7 +54,7 @@ export async function streamText<T extends ProviderId>(
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['streamText']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamText(modelId, params, { middlewares })
|
||||
@ -69,7 +69,7 @@ export async function generateText<T extends ProviderId>(
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateText']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateText(modelId, params, { middlewares })
|
||||
@ -84,7 +84,7 @@ export async function generateObject<T extends ProviderId>(
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateObject(modelId, params, { middlewares })
|
||||
@ -99,7 +99,7 @@ export async function streamObject<T extends ProviderId>(
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamObject(modelId, params, { middlewares })
|
||||
|
||||
@ -68,8 +68,10 @@ export type {
|
||||
TextStreamPart,
|
||||
// 工具相关类型
|
||||
Tool,
|
||||
ToolCallUnion,
|
||||
ToolModelMessage,
|
||||
ToolResultPart,
|
||||
ToolSet,
|
||||
UserModelMessage
|
||||
} from 'ai'
|
||||
export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware, smoothStream } from 'ai'
|
||||
|
||||
@ -3,7 +3,8 @@
|
||||
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
|
||||
*/
|
||||
|
||||
import { TextStreamPart } from '@cherrystudio/ai-core'
|
||||
import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
|
||||
import { MCPTool, WebSearchSource } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import { ToolCallChunkHandler } from './chunk/handleTooCallChunk'
|
||||
@ -24,8 +25,11 @@ export interface CherryStudioChunk {
|
||||
*/
|
||||
export class AiSdkToChunkAdapter {
|
||||
toolCallHandler: ToolCallChunkHandler
|
||||
constructor(private onChunk: (chunk: Chunk) => void) {
|
||||
this.toolCallHandler = new ToolCallChunkHandler(onChunk)
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
private mcpTools: MCPTool[] = []
|
||||
) {
|
||||
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -47,7 +51,7 @@ export class AiSdkToChunkAdapter {
|
||||
* 读取 fullStream 并转换为 Cherry Studio chunks
|
||||
* @param fullStream AI SDK 的 fullStream (ReadableStream)
|
||||
*/
|
||||
private async readFullStream(fullStream: ReadableStream<TextStreamPart<any>>) {
|
||||
private async readFullStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>) {
|
||||
const reader = fullStream.getReader()
|
||||
const final = {
|
||||
text: '',
|
||||
@ -73,84 +77,39 @@ export class AiSdkToChunkAdapter {
|
||||
* 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调
|
||||
* @param chunk AI SDK 的 chunk 数据
|
||||
*/
|
||||
private convertAndEmitChunk(chunk: any, final: { text: string; reasoning_content: string }) {
|
||||
private convertAndEmitChunk(chunk: TextStreamPart<any>, final: { text: string; reasoning_content: string }) {
|
||||
console.log('AI SDK chunk type:', chunk.type, chunk)
|
||||
switch (chunk.type) {
|
||||
// === 文本相关事件 ===
|
||||
case 'text-delta':
|
||||
final.text += chunk.textDelta || ''
|
||||
case 'text':
|
||||
final.text += chunk.text || ''
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: chunk.textDelta || ''
|
||||
text: chunk.text || ''
|
||||
})
|
||||
break
|
||||
case 'text-end':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: final.text || ''
|
||||
})
|
||||
break
|
||||
|
||||
case 'reasoning':
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: chunk.textDelta || '',
|
||||
// 自定义字段
|
||||
thinking_millsec: chunk.thinking_millsec || 0
|
||||
text: chunk.text || '',
|
||||
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
|
||||
})
|
||||
break
|
||||
case 'redacted-reasoning':
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: chunk.data || ''
|
||||
})
|
||||
break
|
||||
case 'reasoning-signature':
|
||||
case 'reasoning-end':
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: chunk.text || '',
|
||||
thinking_millsec: chunk.thinking_millsec || 0
|
||||
text: (chunk.providerMetadata?.metadata?.thinking_content as string) || '',
|
||||
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
|
||||
})
|
||||
break
|
||||
|
||||
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
|
||||
case 'tool-call-streaming-start':
|
||||
// 开始流式工具调用
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: [
|
||||
{
|
||||
id: chunk.toolCallId,
|
||||
name: chunk.toolName,
|
||||
args: {}
|
||||
}
|
||||
]
|
||||
})
|
||||
break
|
||||
|
||||
case 'tool-call-delta':
|
||||
// 工具调用参数的增量更新
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||
responses: [
|
||||
{
|
||||
id: chunk.toolCallId,
|
||||
tool: {
|
||||
id: chunk.toolName,
|
||||
// TODO: serverId,serverName
|
||||
serverId: 'ai-sdk',
|
||||
serverName: 'AI SDK',
|
||||
name: chunk.toolName,
|
||||
description: '',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
title: chunk.toolName,
|
||||
properties: {}
|
||||
}
|
||||
},
|
||||
arguments: {},
|
||||
status: 'invoking',
|
||||
response: chunk.argsTextDelta,
|
||||
toolCallId: chunk.toolCallId
|
||||
}
|
||||
]
|
||||
})
|
||||
break
|
||||
|
||||
case 'tool-call':
|
||||
// 原始的工具调用(未被中间件处理)
|
||||
this.toolCallHandler.handleToolCall(chunk)
|
||||
@ -160,6 +119,11 @@ export class AiSdkToChunkAdapter {
|
||||
// 原始的工具调用结果(未被中间件处理)
|
||||
this.toolCallHandler.handleToolResult(chunk)
|
||||
break
|
||||
// case 'start':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
|
||||
// === 步骤相关事件 ===
|
||||
// TODO: 需要区分接口开始和步骤开始
|
||||
@ -168,13 +132,17 @@ export class AiSdkToChunkAdapter {
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
case 'step-finish':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: final.text || '' // TEXT_COMPLETE 需要 text 字段
|
||||
})
|
||||
final.text = ''
|
||||
break
|
||||
// case 'step-finish':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.TEXT_COMPLETE,
|
||||
// text: final.text || '' // TEXT_COMPLETE 需要 text 字段
|
||||
// })
|
||||
// final.text = ''
|
||||
// break
|
||||
|
||||
// case 'finish-step': {
|
||||
// const { totalUsage, finishReason, providerMetadata } = chunk
|
||||
// }
|
||||
|
||||
case 'finish':
|
||||
this.onChunk({
|
||||
@ -183,13 +151,13 @@ export class AiSdkToChunkAdapter {
|
||||
text: final.text || '',
|
||||
reasoning_content: final.reasoning_content || '',
|
||||
usage: {
|
||||
completion_tokens: chunk.usage.completionTokens || 0,
|
||||
prompt_tokens: chunk.usage.promptTokens || 0,
|
||||
total_tokens: chunk.usage.totalTokens || 0
|
||||
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||
prompt_tokens: chunk.totalUsage.inputTokens || 0,
|
||||
total_tokens: chunk.totalUsage.totalTokens || 0
|
||||
},
|
||||
metrics: chunk.usage
|
||||
metrics: chunk.totalUsage
|
||||
? {
|
||||
completion_tokens: chunk.usage.completionTokens || 0,
|
||||
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||
time_completion_millsec: 0
|
||||
}
|
||||
: undefined
|
||||
@ -201,13 +169,13 @@ export class AiSdkToChunkAdapter {
|
||||
text: final.text || '',
|
||||
reasoning_content: final.reasoning_content || '',
|
||||
usage: {
|
||||
completion_tokens: chunk.usage.completionTokens || 0,
|
||||
prompt_tokens: chunk.usage.promptTokens || 0,
|
||||
total_tokens: chunk.usage.totalTokens || 0
|
||||
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||
prompt_tokens: chunk.totalUsage.inputTokens || 0,
|
||||
total_tokens: chunk.totalUsage.totalTokens || 0
|
||||
},
|
||||
metrics: chunk.usage
|
||||
metrics: chunk.totalUsage
|
||||
? {
|
||||
completion_tokens: chunk.usage.completionTokens || 0,
|
||||
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||
time_completion_millsec: 0
|
||||
}
|
||||
: undefined
|
||||
@ -217,30 +185,24 @@ export class AiSdkToChunkAdapter {
|
||||
|
||||
// === 源和文件相关事件 ===
|
||||
case 'source':
|
||||
// 源信息,可以映射到知识搜索完成
|
||||
this.onChunk({
|
||||
type: ChunkType.KNOWLEDGE_SEARCH_COMPLETE,
|
||||
knowledge: [
|
||||
{
|
||||
id: Number(chunk.source.id) || Date.now(),
|
||||
content: chunk.source.title || '',
|
||||
sourceUrl: chunk.source.url || '',
|
||||
type: 'url'
|
||||
}
|
||||
]
|
||||
})
|
||||
break
|
||||
|
||||
case 'file':
|
||||
// 文件相关事件,可能是图片生成
|
||||
this.onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [chunk.base64]
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
source: WebSearchSource.AISDK,
|
||||
results: [{}]
|
||||
}
|
||||
})
|
||||
break
|
||||
// case 'file':
|
||||
// // 文件相关事件,可能是图片生成
|
||||
// this.onChunk({
|
||||
// type: ChunkType.IMAGE_COMPLETE,
|
||||
// image: {
|
||||
// type: 'base64',
|
||||
// images: [chunk.base64]
|
||||
// }
|
||||
// })
|
||||
// break
|
||||
case 'error':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
|
||||
@ -4,8 +4,9 @@
|
||||
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||
*/
|
||||
|
||||
import { ToolCallUnion, ToolSet } from '@cherrystudio/ai-core/index'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { MCPToolResponse } from '@renderer/types'
|
||||
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
/**
|
||||
@ -19,10 +20,13 @@ export class ToolCallChunkHandler {
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool: MCPTool
|
||||
mcpTool: MCPTool
|
||||
}
|
||||
>()
|
||||
constructor(private onChunk: (chunk: Chunk) => void) {}
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
private mcpTools: MCPTool[]
|
||||
) {}
|
||||
|
||||
// /**
|
||||
// * 设置 onChunk 回调
|
||||
@ -34,10 +38,14 @@ export class ToolCallChunkHandler {
|
||||
/**
|
||||
* 处理工具调用事件
|
||||
*/
|
||||
public handleToolCall(chunk: any): void {
|
||||
public handleToolCall(
|
||||
chunk: {
|
||||
type: 'tool-call'
|
||||
} & ToolCallUnion<ToolSet>
|
||||
): void {
|
||||
const toolCallId = chunk.toolCallId
|
||||
const toolName = chunk.toolName
|
||||
const args = chunk.args || {}
|
||||
const args = chunk.input || {}
|
||||
|
||||
if (!toolCallId || !toolName) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
|
||||
@ -51,17 +59,14 @@ export class ToolCallChunkHandler {
|
||||
this.activeToolCalls.set(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args
|
||||
// mcpTool
|
||||
args,
|
||||
mcpTool: this.mcpTools.find((tool) => tool.name === toolName)!
|
||||
})
|
||||
|
||||
// 创建 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: {
|
||||
id: toolCallId,
|
||||
name: toolName
|
||||
},
|
||||
tool: this.activeToolCalls.get(toolCallId)!.mcpTool,
|
||||
arguments: args,
|
||||
status: 'invoking',
|
||||
toolCallId: toolCallId
|
||||
@ -98,10 +103,7 @@ export class ToolCallChunkHandler {
|
||||
// 创建工具调用结果的 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: {
|
||||
id: toolCallId,
|
||||
name: toolCallInfo.toolName
|
||||
},
|
||||
tool: toolCallInfo.mcpTool,
|
||||
arguments: toolCallInfo.args,
|
||||
status: 'done',
|
||||
response: {
|
||||
|
||||
@ -195,7 +195,7 @@ export default class ModernAiProvider {
|
||||
// 创建带有中间件的执行器
|
||||
if (middlewareConfig.onChunk) {
|
||||
// 流式处理 - 使用适配器
|
||||
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
|
||||
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk, middlewareConfig.mcpTools)
|
||||
console.log('最终params', params)
|
||||
const streamResult = await executor.streamText(
|
||||
modelId,
|
||||
|
||||
@ -1,8 +1,4 @@
|
||||
import {
|
||||
extractReasoningMiddleware,
|
||||
LanguageModelV1Middleware,
|
||||
simulateStreamingMiddleware
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { LanguageModelV2Middleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
|
||||
import type { MCPTool, Model, Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
@ -26,7 +22,7 @@ export interface AiSdkMiddlewareConfig {
|
||||
*/
|
||||
export interface NamedAiSdkMiddleware {
|
||||
name: string
|
||||
middleware: LanguageModelV1Middleware
|
||||
middleware: LanguageModelV2Middleware
|
||||
}
|
||||
|
||||
/**
|
||||
@ -75,7 +71,7 @@ export class AiSdkMiddlewareBuilder {
|
||||
/**
|
||||
* 构建最终的中间件数组
|
||||
*/
|
||||
public build(): LanguageModelV1Middleware[] {
|
||||
public build(): LanguageModelV2Middleware[] {
|
||||
return this.middlewares.map((m) => m.middleware)
|
||||
}
|
||||
|
||||
@ -106,7 +102,7 @@ export class AiSdkMiddlewareBuilder {
|
||||
* 根据配置构建AI SDK中间件的工厂函数
|
||||
* 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果
|
||||
*/
|
||||
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV1Middleware[] {
|
||||
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV2Middleware[] {
|
||||
const builder = new AiSdkMiddlewareBuilder()
|
||||
|
||||
// 1. 根据provider添加特定中间件
|
||||
@ -143,10 +139,10 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
|
||||
// Anthropic特定中间件
|
||||
break
|
||||
case 'openai':
|
||||
builder.add({
|
||||
name: 'thinking-tag-extraction',
|
||||
middleware: extractReasoningMiddleware({ tagName: 'think' })
|
||||
})
|
||||
// builder.add({
|
||||
// name: 'thinking-tag-extraction',
|
||||
// middleware: extractReasoningMiddleware({ tagName: 'think' })
|
||||
// })
|
||||
break
|
||||
case 'gemini':
|
||||
// Gemini特定中间件
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { definePlugin, TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
|
||||
|
||||
export default definePlugin({
|
||||
name: 'reasoningTimePlugin',
|
||||
@ -8,57 +8,62 @@ export default definePlugin({
|
||||
let thinkingStartTime = 0
|
||||
let hasStartedThinking = false
|
||||
let accumulatedThinkingContent = ''
|
||||
let reasoningBlockId = ''
|
||||
|
||||
return new TransformStream({
|
||||
transform(chunk, controller) {
|
||||
if (chunk.type !== 'reasoning') {
|
||||
// === 处理 reasoning 结束 ===
|
||||
if (hasStartedThinking) {
|
||||
console.log(`[ReasoningPlugin] Ending reasoning.`)
|
||||
|
||||
// 生成 reasoning-signature
|
||||
controller.enqueue({
|
||||
type: 'reasoning-signature',
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: performance.now() - thinkingStartTime
|
||||
})
|
||||
|
||||
// 重置状态
|
||||
accumulatedThinkingContent = ''
|
||||
hasStartedThinking = false
|
||||
thinkingStartTime = 0
|
||||
}
|
||||
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
return new TransformStream<TextStreamPart<ToolSet>, TextStreamPart<ToolSet>>({
|
||||
transform(chunk: TextStreamPart<ToolSet>, controller: TransformStreamDefaultController<TextStreamPart<ToolSet>>) {
|
||||
// === 处理 reasoning 类型 ===
|
||||
if (chunk.type === 'reasoning') {
|
||||
if (!hasStartedThinking) {
|
||||
hasStartedThinking = true
|
||||
thinkingStartTime = performance.now()
|
||||
reasoningBlockId = chunk.id
|
||||
}
|
||||
accumulatedThinkingContent += chunk.text
|
||||
|
||||
// 1. 时间跟踪逻辑
|
||||
if (!hasStartedThinking) {
|
||||
hasStartedThinking = true
|
||||
thinkingStartTime = performance.now()
|
||||
console.log(`[ReasoningPlugin] Starting reasoning session`)
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
providerMetadata: {
|
||||
...chunk.providerMetadata,
|
||||
metadata: {
|
||||
...chunk.providerMetadata?.metadata,
|
||||
thinking_millsec: performance.now() - thinkingStartTime,
|
||||
thinking_content: accumulatedThinkingContent
|
||||
}
|
||||
}
|
||||
})
|
||||
} else if (hasStartedThinking) {
|
||||
controller.enqueue({
|
||||
type: 'reasoning-end',
|
||||
id: reasoningBlockId,
|
||||
providerMetadata: {
|
||||
metadata: {
|
||||
thinking_millsec: performance.now() - thinkingStartTime,
|
||||
thinking_content: accumulatedThinkingContent
|
||||
}
|
||||
}
|
||||
})
|
||||
accumulatedThinkingContent = ''
|
||||
hasStartedThinking = false
|
||||
thinkingStartTime = 0
|
||||
reasoningBlockId = ''
|
||||
controller.enqueue(chunk)
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
accumulatedThinkingContent += chunk.textDelta
|
||||
|
||||
// 2. 直接透传 chunk,并附加上时间
|
||||
console.log(`[ReasoningPlugin] Forwarding reasoning chunk: "${chunk.textDelta}"`)
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
thinking_millsec: performance.now() - thinkingStartTime
|
||||
})
|
||||
},
|
||||
|
||||
// === flush 处理流结束时仍在reasoning状态的场景 ===
|
||||
flush(controller) {
|
||||
if (hasStartedThinking) {
|
||||
console.log(`[ReasoningPlugin] Final flush for reasoning-signature.`)
|
||||
controller.enqueue({
|
||||
type: 'reasoning-signature',
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: performance.now() - thinkingStartTime
|
||||
type: 'reasoning-end',
|
||||
id: reasoningBlockId,
|
||||
providerMetadata: {
|
||||
metadata: {
|
||||
thinking_millsec: performance.now() - thinkingStartTime,
|
||||
thinking_content: accumulatedThinkingContent
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,7 +48,7 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
|
||||
console.log('mcpTool', mcpTool.inputSchema)
|
||||
tools[mcpTool.name] = aiSdk.tool<any, ToolCallResult>({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
parameters: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
execute: async (params): Promise<ToolCallResult> => {
|
||||
console.log('execute_params', params)
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
|
||||
@ -545,7 +545,8 @@ export enum WebSearchSource {
|
||||
QWEN = 'qwen',
|
||||
HUNYUAN = 'hunyuan',
|
||||
ZHIPU = 'zhipu',
|
||||
GROK = 'grok'
|
||||
GROK = 'grok',
|
||||
AISDK = 'ai-sdk'
|
||||
}
|
||||
|
||||
export type WebSearchResponse = {
|
||||
|
||||
@ -956,6 +956,7 @@ __metadata:
|
||||
ollama-ai-provider: "npm:^1.2.0"
|
||||
tsdown: "npm:^0.12.9"
|
||||
typescript: "npm:^5.0.0"
|
||||
zod: "npm:^3.25.0"
|
||||
peerDependenciesMeta:
|
||||
"@ai-sdk/amazon-bedrock":
|
||||
optional: true
|
||||
@ -20140,6 +20141,13 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"zod@npm:^3.25.0":
|
||||
version: 3.25.74
|
||||
resolution: "zod@npm:3.25.74"
|
||||
checksum: 10c0/59e38b046ac333b5bd1ba325a83b6798721227cbfb1e69dfc7159bd7824b904241ab923026edb714fafefec3624265ae374a70aee9a5a45b365bd31781ffa105
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"zustand@npm:^4.4.0":
|
||||
version: 4.5.6
|
||||
resolution: "zustand@npm:4.5.6"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user