refactor(ToolCall): refactor:mcp-tool-state-management (#10028)

MyPrototypeWhat, 2025-09-08 23:29:34 +08:00, committed by GitHub
parent 7df1060370, commit f6ffd574bf
No known key found for this signature in database (GPG Key ID: B5690EEEBB952194)
28 changed files with 594 additions and 788 deletions

View File

@ -1,103 +0,0 @@
/**
* Hub Provider usage example
*
* Demonstrates how to use the Hub Provider feature to route requests to multiple underlying providers
*/
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
async function demonstrateHubProvider() {
try {
// 1. Initialize the underlying providers
console.log('📦 Initializing underlying providers...')
initializeProvider('openai', {
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
})
initializeProvider('anthropic', {
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
})
// 2. Create the Hub Provider (automatically includes all initialized providers)
console.log('🌐 Creating Hub Provider...')
const aihubmixProvider = createHubProvider({
hubId: 'aihubmix',
debug: true
})
// 3. Register the Hub Provider
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
console.log('✅ Hub Provider "aihubmix" registered successfully')
// 4. Use the Hub Provider to access different models
console.log('\n🚀 Using Hub models...')
// Route to OpenAI through the Hub
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
console.log('✓ OpenAI model resolved:', openaiModel.modelId)
// Route to Anthropic through the Hub
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
console.log('✓ Anthropic model resolved:', anthropicModel.modelId)
// 5. Demonstrate error handling
console.log('\n❌ Demonstrating error handling...')
try {
// Attempt to access a provider that has not been initialized
providerRegistry.languageModel('aihubmix:google:gemini-pro')
} catch (error) {
console.log('Expected error:', error.message)
}
try {
// Attempt to use a malformed model ID
providerRegistry.languageModel('aihubmix:invalid-format')
} catch (error) {
console.log('Expected error:', error.message)
}
// 6. Multiple Hub Providers
console.log('\n🔄 Creating multiple Hub Providers...')
const localHubProvider = createHubProvider({
hubId: 'local-ai'
})
providerRegistry.registerProvider('local-ai', localHubProvider)
console.log('✅ Hub Provider "local-ai" registered successfully')
console.log('\n🎉 Hub Provider demo complete')
} catch (error) {
console.error('💥 Error during the demo:', error)
}
}
// Demonstrate the simplified usage pattern
function simplifiedUsageExample() {
console.log('\n📝 Simplified usage example:')
console.log(`
// 1. Initialize providers
initializeProvider('openai', { apiKey: 'sk-xxx' })
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
// 2. Create and register the Hub Provider
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
providerRegistry.registerProvider('aihubmix', hubProvider)
// 3. Use it directly
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
`)
}
// Run the demo
if (require.main === module) {
demonstrateHubProvider()
simplifiedUsageExample()
}
export { demonstrateHubProvider, simplifiedUsageExample }

View File

@ -1,167 +0,0 @@
/**
* Image Generation Example
* Demonstrates how to use aiCore to generate images
*/
import { createExecutor, generateImage } from '../src/index'
async function main() {
// Option 1: use an executor instance
console.log('📸 Creating the OpenAI image generation executor...')
const executor = createExecutor('openai', {
apiKey: process.env.OPENAI_API_KEY!
})
try {
console.log('🎨 Generating an image with the executor...')
const result1 = await executor.generateImage('dall-e-3', {
prompt: 'A futuristic cityscape at sunset with flying cars',
size: '1024x1024',
n: 1
})
console.log('✅ Image generated successfully!')
console.log('📊 Result:', {
imagesCount: result1.images.length,
mediaType: result1.image.mediaType,
hasBase64: !!result1.image.base64,
providerMetadata: result1.providerMetadata
})
} catch (error) {
console.error('❌ Executor generation failed:', error)
}
// Option 2: call the API directly
try {
console.log('🎨 Generating an image via the direct API...')
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
prompt: 'A magical forest with glowing mushrooms and fairy lights',
aspectRatio: '16:9',
providerOptions: {
openai: {
quality: 'hd',
style: 'vivid'
}
}
})
console.log('✅ Direct API generation succeeded!')
console.log('📊 Result:', {
imagesCount: result2.images.length,
mediaType: result2.image.mediaType,
hasBase64: !!result2.image.base64
})
} catch (error) {
console.error('❌ Direct API generation failed:', error)
}
// Option 3: other providers are supported (Google Imagen)
if (process.env.GOOGLE_API_KEY) {
try {
console.log('🎨 Generating an image with Google Imagen...')
const googleExecutor = createExecutor('google', {
apiKey: process.env.GOOGLE_API_KEY!
})
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
prompt: 'A serene mountain lake at dawn with mist rising from the water',
aspectRatio: '1:1'
})
console.log('✅ Google Imagen generation succeeded!')
console.log('📊 Result:', {
imagesCount: result3.images.length,
mediaType: result3.image.mediaType,
hasBase64: !!result3.image.base64
})
} catch (error) {
console.error('❌ Google Imagen generation failed:', error)
}
}
// Option 4: the plugin system is supported
const pluginExample = async () => {
console.log('🔌 Demonstrating the plugin system...')
// Create a sample plugin that modifies the prompt
const promptEnhancerPlugin = {
name: 'prompt-enhancer',
transformParams: async (params: any) => {
console.log('🔧 Plugin: enhancing the prompt...')
return {
...params,
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
}
},
transformResult: async (result: any) => {
console.log('🔧 Plugin: processing the result...')
return {
...result,
enhanced: true
}
}
}
const executorWithPlugin = createExecutor(
'openai',
{
apiKey: process.env.OPENAI_API_KEY!
},
[promptEnhancerPlugin]
)
try {
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A cute robot playing in a garden'
})
console.log('✅ Plugin-based generation succeeded!')
console.log('📊 Result:', {
imagesCount: result4.images.length,
enhanced: (result4 as any).enhanced,
mediaType: result4.image.mediaType
})
} catch (error) {
console.error('❌ Plugin-based generation failed:', error)
}
}
await pluginExample()
}
// Error handling demo
async function errorHandlingExample() {
console.log('⚠️ Demonstrating error handling...')
try {
const executor = createExecutor('openai', {
apiKey: 'invalid-key'
})
await executor.generateImage('dall-e-3', {
prompt: 'Test image'
})
} catch (error: any) {
console.log('✅ Error caught successfully:', error.constructor.name)
console.log('📋 Error message:', error.message)
console.log('🏷️ Provider ID:', error.providerId)
console.log('🏷️ Model ID:', error.modelId)
}
}
// Run the examples
if (require.main === module) {
main()
.then(() => {
console.log('🎉 All examples completed!')
return errorHandlingExample()
})
.then(() => {
console.log('🎯 Example program finished')
process.exit(0)
})
.catch((error) => {
console.error('💥 Program execution failed:', error)
process.exit(1)
})
}

View File

@ -1,6 +1,6 @@
{
"name": "@cherrystudio/ai-core",
"version": "1.0.0-alpha.12",
"version": "1.0.0-alpha.13",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "dist/index.js",
"module": "dist/index.mjs",
@ -39,8 +39,8 @@
"@ai-sdk/anthropic": "^2.0.5",
"@ai-sdk/azure": "^2.0.16",
"@ai-sdk/deepseek": "^1.0.9",
"@ai-sdk/google": "^2.0.7",
"@ai-sdk/openai": "^2.0.19",
"@ai-sdk/google": "^2.0.13",
"@ai-sdk/openai": "^2.0.26",
"@ai-sdk/openai-compatible": "^1.0.9",
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.4",

View File

@ -84,7 +84,6 @@ export class ModelResolver {
*/
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
console.log('fullModelId', fullModelId)
return globalRegistryManagement.languageModel(fullModelId as any)
}

View File

@ -27,10 +27,20 @@ export class StreamEventManager {
/**
* Send the finish-step event for the current step
*/
sendStepFinishEvent(controller: StreamController, chunk: any): void {
sendStepFinishEvent(
controller: StreamController,
chunk: any,
context: AiRequestContext,
finishReason: string = 'stop'
): void {
// Accumulate the usage for the current step
if (chunk.usage && context.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, chunk.usage)
}
controller.enqueue({
type: 'finish-step',
finishReason: 'stop',
finishReason,
response: chunk.response,
usage: chunk.usage,
providerMetadata: chunk.providerMetadata
@ -43,28 +53,32 @@ export class StreamEventManager {
async handleRecursiveCall(
controller: StreamController,
recursiveParams: any,
context: AiRequestContext,
stepId: string
context: AiRequestContext
): Promise<void> {
try {
console.log('[MCP Prompt] Starting recursive call after tool execution...')
// try {
// Reset the tool-execution flag to prepare for the new step
context.hasExecutedToolsInCurrentStep = false
const recursiveResult = await context.recursiveCall(recursiveParams)
const recursiveResult = await context.recursiveCall(recursiveParams)
if (recursiveResult && recursiveResult.fullStream) {
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
} else {
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
}
} catch (error) {
this.handleRecursiveCallError(controller, error, stepId)
if (recursiveResult && recursiveResult.fullStream) {
await this.pipeRecursiveStream(controller, recursiveResult.fullStream, context)
} else {
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
}
// } catch (error) {
// this.handleRecursiveCallError(controller, error, stepId)
// }
}
/**
* Pipe the recursive call's stream into the current stream
*/
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
private async pipeRecursiveStream(
controller: StreamController,
recursiveStream: ReadableStream,
context?: AiRequestContext
): Promise<void> {
const reader = recursiveStream.getReader()
try {
while (true) {
@ -73,9 +87,16 @@ export class StreamEventManager {
break
}
if (value.type === 'finish') {
// The recursive stream does not emit finish
// The recursive stream does not emit finish, but its usage still needs to be accumulated
if (value.usage && context?.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, value.usage)
}
break
}
// For finish-step chunks, accumulate their usage
if (value.type === 'finish-step' && value.usage && context?.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, value.usage)
}
// Forward data from the recursive stream into the current stream
controller.enqueue(value)
}
@ -87,25 +108,25 @@ export class StreamEventManager {
/**
* Handle errors from the recursive call (currently disabled)
*/
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
console.error('[MCP Prompt] Recursive call failed:', error)
// private handleRecursiveCallError(controller: StreamController, error: unknown): void {
// console.error('[MCP Prompt] Recursive call failed:', error)
// Use the AI SDK standard error format without aborting the stream
controller.enqueue({
type: 'error',
error: {
message: error instanceof Error ? error.message : String(error),
name: error instanceof Error ? error.name : 'RecursiveCallError'
}
})
// // Use the AI SDK standard error format without aborting the stream
// controller.enqueue({
// type: 'error',
// error: {
// message: error instanceof Error ? error.message : String(error),
// name: error instanceof Error ? error.name : 'RecursiveCallError'
// }
// })
// Keep emitting text deltas to preserve stream continuity
controller.enqueue({
type: 'text-delta',
id: stepId,
text: '\n\n[Recursive call after tool execution failed, continuing the conversation...]'
})
}
// // // Keep emitting text deltas to preserve stream continuity
// // controller.enqueue({
// // type: 'text-delta',
// // id: stepId,
// // text: '\n\n[Recursive call after tool execution failed, continuing the conversation...]'
// // })
// }
/**
*
@ -136,4 +157,18 @@ export class StreamEventManager {
return recursiveParams
}
/**
* Accumulate usage (token counts) into the running total
*/
private accumulateUsage(target: any, source: any): void {
if (!target || !source) return
// Accumulate each token type
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
target.reasoningTokens = (target.reasoningTokens || 0) + (source.reasoningTokens || 0)
target.cachedInputTokens = (target.cachedInputTokens || 0) + (source.cachedInputTokens || 0)
}
}
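The accumulation above is a plain field-wise sum over each step's usage. A minimal sketch of the intended behavior (the usage shape below is inferred from the fields accessed above, not taken from the AI SDK types):
type AccumulatedUsage = {
  inputTokens: number
  outputTokens: number
  totalTokens: number
  reasoningTokens: number
  cachedInputTokens: number
}
// Field-wise accumulation, mirroring accumulateUsage above
function addUsage(target: AccumulatedUsage, source: Partial<AccumulatedUsage>): void {
  target.inputTokens += source.inputTokens || 0
  target.outputTokens += source.outputTokens || 0
  target.totalTokens += source.totalTokens || 0
  target.reasoningTokens += source.reasoningTokens || 0
  target.cachedInputTokens += source.cachedInputTokens || 0
}
// Example: an initial step plus one recursive step after a tool call
const total: AccumulatedUsage = { inputTokens: 0, outputTokens: 0, totalTokens: 0, reasoningTokens: 0, cachedInputTokens: 0 }
addUsage(total, { inputTokens: 120, outputTokens: 40, totalTokens: 160 })
addUsage(total, { inputTokens: 300, outputTokens: 25, totalTokens: 325 })
// total.totalTokens is now 485, which is what the final 'finish' chunk reports as totalUsage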

View File

@ -4,7 +4,7 @@
* Tool executor
* Split out from promptToolUsePlugin.ts
*/
import type { ToolSet } from 'ai'
import type { ToolSet, TypedToolError } from 'ai'
import type { ToolUseResult } from './type'
@ -38,7 +38,6 @@ export class ToolExecutor {
controller: StreamController
): Promise<ExecutedResult[]> {
const executedResults: ExecutedResult[] = []
for (const toolUse of toolUses) {
try {
const tool = tools[toolUse.toolName]
@ -46,17 +45,12 @@ export class ToolExecutor {
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
}
// Send tool-call start events
this.sendToolStartEvents(controller, toolUse)
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
// Emit the tool-call event
controller.enqueue({
type: 'tool-call',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: tool.inputSchema
input: toolUse.arguments
})
const result = await tool.execute(toolUse.arguments, {
@ -111,45 +105,46 @@ export class ToolExecutor {
/**
* Send tool-call start events (currently disabled)
*/
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
// Send the tool-input-start event
controller.enqueue({
type: 'tool-input-start',
id: toolUse.id,
toolName: toolUse.toolName
})
}
// private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
// // Send the tool-input-start event
// controller.enqueue({
// type: 'tool-input-start',
// id: toolUse.id,
// toolName: toolUse.toolName
// })
// }
/**
* Handle a tool execution error
*/
private handleToolError(
private handleToolError<T extends ToolSet>(
toolUse: ToolUseResult,
error: unknown,
controller: StreamController
// _tools: ToolSet
): ExecutedResult {
// Use the AI SDK standard error format
// const toolError: TypedToolError<typeof _tools> = {
// type: 'tool-error',
// toolCallId: toolUse.id,
// toolName: toolUse.toolName,
// input: toolUse.arguments,
// error: error instanceof Error ? error.message : String(error)
// }
const toolError: TypedToolError<T> = {
type: 'tool-error',
toolCallId: toolUse.id,
toolName: toolUse.toolName,
input: toolUse.arguments,
error
}
// controller.enqueue(toolError)
controller.enqueue(toolError)
// Send a standard error event
controller.enqueue({
type: 'error',
error: error instanceof Error ? error.message : String(error)
})
// controller.enqueue({
// type: 'tool-error',
// toolCallId: toolUse.id,
// error: error instanceof Error ? error.message : String(error),
// input: toolUse.arguments
// })
return {
toolCallId: toolUse.id,
toolName: toolUse.toolName,
result: error instanceof Error ? error.message : String(error),
result: error,
isError: true
}
}
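The tool-error chunk enqueued above follows the AI SDK's TypedToolError shape as used in this diff; a rough sketch of what a downstream consumer receives (values are illustrative, not from the source):
const exampleToolError = {
  type: 'tool-error' as const,
  toolCallId: 'call_123',
  toolName: 'get_weather',
  input: { city: 'Tokyo' },
  error: new Error('upstream MCP server timed out')
}
// AiSdkToChunkAdapter routes such chunks to ToolCallChunkHandler.handleToolError(),
// which marks the corresponding tool block as 'error' and carries the error as the response.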

View File

@ -8,9 +8,19 @@ import type { TextStreamPart, ToolSet } from 'ai'
import { definePlugin } from '../../index'
import type { AiRequestContext } from '../../types'
import { StreamEventManager } from './StreamEventManager'
import { type TagConfig, TagExtractor } from './tagExtraction'
import { ToolExecutor } from './ToolExecutor'
import { PromptToolUseConfig, ToolUseResult } from './type'
/**
* Tag configuration for tool-use blocks
*/
const TOOL_USE_TAG_CONFIG: TagConfig = {
openingTag: '<tool_use>',
closingTag: '</tool_use>',
separator: '\n'
}
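With this configuration, the TagExtractor used below splits streamed text into tool-tag content (buffered for parsing) and plain text (forwarded to the UI). A rough usage sketch, assuming only the processText behavior visible in this diff (segments with isTagContent and content):
const extractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
// Two simulated text deltas: plain prose followed by a <tool_use> block
const deltas = ['Let me check the weather. ', '<tool_use>{"name":"get_weather"}</tool_use>']
for (const delta of deltas) {
  for (const segment of extractor.processText(delta)) {
    if (segment.isTagContent) {
      console.log('tool content (kept out of the UI stream):', segment.content)
    } else if (segment.content) {
      console.log('plain text (forwarded as a text-delta):', segment.content)
    }
  }
}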
/**
* Cherry Studio
*/
@ -249,13 +259,11 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
}
context.mcpTools = params.tools
console.log('tools stored in context', params.tools)
// Build the system prompt
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
let systemMessage: string | null = systemPrompt
console.log('config.context', context)
if (config.createSystemMessage) {
// 🎯 If a custom handler was provided, use it
systemMessage = config.createSystemMessage(systemPrompt, params, context)
@ -268,20 +276,40 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
tools: undefined
}
context.originalParams = transformedParams
console.log('transformedParams', transformedParams)
return transformedParams
},
transformStream: (_: any, context: AiRequestContext) => () => {
let textBuffer = ''
let stepId = ''
// let stepId = ''
if (!context.mcpTools) {
throw new Error('No tools available')
}
// Create the tool executor and stream event manager
// Get or initialize the usage accumulator from the context
if (!context.accumulatedUsage) {
context.accumulatedUsage = {
inputTokens: 0,
outputTokens: 0,
totalTokens: 0,
reasoningTokens: 0,
cachedInputTokens: 0
}
}
// Create the tool executor, stream event manager, and tag extractor
const toolExecutor = new ToolExecutor()
const streamEventManager = new StreamEventManager()
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
// Initialize the tool-execution flag on the context so state is not lost across recursive calls
if (!context.hasExecutedToolsInCurrentStep) {
context.hasExecutedToolsInCurrentStep = false
}
// Used to hold the text-start event until non-tool-tag content is confirmed
let pendingTextStart: TextStreamPart<TOOLS> | null = null
let hasStartedText = false
type TOOLS = NonNullable<typeof context.mcpTools>
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
@ -289,83 +317,106 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
chunk: TextStreamPart<TOOLS>,
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
) {
// Collect the text content
if (chunk.type === 'text-delta') {
textBuffer += chunk.text || ''
stepId = chunk.id || ''
controller.enqueue(chunk)
// Hold the text-start event until non-tool-tag content is confirmed
if ((chunk as any).type === 'text-start') {
pendingTextStart = chunk
return
}
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
const tools = context.mcpTools
if (!tools || Object.keys(tools).length === 0) {
controller.enqueue(chunk)
return
}
// During text-delta, collect the text and filter out tool tags
if (chunk.type === 'text-delta') {
textBuffer += chunk.text || ''
// stepId = chunk.id || ''
// Parse tool calls
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
// Use the TagExtractor to filter tool tags; only non-tag content is passed to the UI layer
const extractionResults = tagExtractor.processText(chunk.text || '')
// If there are no valid tool calls, pass the original event through
if (validToolUses.length === 0) {
controller.enqueue(chunk)
return
}
if (chunk.type === 'text-end') {
controller.enqueue({
type: 'text-end',
id: stepId,
providerMetadata: {
text: {
value: parsedContent
}
for (const result of extractionResults) {
// Only forward non-tag content to the UI layer
if (!result.isTagContent && result.content) {
// If text-start has not been sent yet and one is pending, send it first
if (!hasStartedText && pendingTextStart) {
controller.enqueue(pendingTextStart)
hasStartedText = true
pendingTextStart = null
}
})
return
const filteredChunk = {
...chunk,
text: result.content
}
controller.enqueue(filteredChunk)
}
}
return
}
if (chunk.type === 'text-end') {
// Only emit text-end if text-start has already been emitted
if (hasStartedText) {
controller.enqueue(chunk)
}
return
}
if (chunk.type === 'finish-step') {
// Check and execute tool calls uniformly at the finish-step stage
const tools = context.mcpTools
if (tools && Object.keys(tools).length > 0 && !context.hasExecutedToolsInCurrentStep) {
// Parse the complete textBuffer to detect tool calls
const { results: parsedTools } = parseToolUse(textBuffer, tools)
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
if (validToolUses.length > 0) {
context.hasExecutedToolsInCurrentStep = true
// Execute the tool calls (no need to emit start-step manually; the outer stream already handles it)
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
// Emit the step-finish event with 'tool-calls' as the finishReason
streamEventManager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
// Handle the recursive call
const toolResultsText = toolExecutor.formatToolResults(executedResults)
const recursiveParams = streamEventManager.buildRecursiveParams(
context,
textBuffer,
toolResultsText,
tools
)
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context)
return
}
}
controller.enqueue({
...chunk,
finishReason: 'tool-calls'
})
// Emit the step-start event
streamEventManager.sendStepStartEvent(controller)
// Execute the tool calls
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
// Emit the step-finish event
streamEventManager.sendStepFinishEvent(controller, chunk)
// Handle the recursive call
if (validToolUses.length > 0) {
const toolResultsText = toolExecutor.formatToolResults(executedResults)
const recursiveParams = streamEventManager.buildRecursiveParams(
context,
textBuffer,
toolResultsText,
tools
)
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
}
// If no tool calls were executed, pass the original finish-step event through
controller.enqueue(chunk)
// Reset state
textBuffer = ''
return
}
// Pass other event types through unchanged
controller.enqueue(chunk)
// For finish chunks, use the accumulated totalUsage
if (chunk.type === 'finish') {
controller.enqueue({
...chunk,
totalUsage: context.accumulatedUsage
})
return
}
// Pass other event types through unchanged (text-start is excluded; it is handled above)
if ((chunk as any).type !== 'text-start') {
controller.enqueue(chunk)
}
},
flush() {
// Clean-up when the stream ends
console.log('[MCP Prompt] Stream ended, cleaning up...')
// Clear pending state
pendingTextStart = null
hasStartedText = false
}
})
}

View File

@ -27,7 +27,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
case 'openai': {
if (config.openai) {
if (!params.tools) params.tools = {}
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
params.tools.web_search = openai.tools.webSearch(config.openai)
}
break
}

View File

@ -1,5 +1,8 @@
// Core types and interfaces
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
import type { ImageModelV2 } from '@ai-sdk/provider'
import type { LanguageModel } from 'ai'
import type { ProviderId } from '../providers'
import type { AiPlugin, AiRequestContext } from './types'
@ -9,16 +12,16 @@ export { PluginManager } from './manager'
// Utility functions
export function createContext<T extends ProviderId>(
providerId: T,
modelId: string,
model: LanguageModel | ImageModelV2,
originalParams: any
): AiRequestContext {
return {
providerId,
modelId,
model,
originalParams,
metadata: {},
startTime: Date.now(),
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
// Placeholder
recursiveCall: () => Promise.resolve(null)
}
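Because LanguageModel can be either a bare model-id string or a resolved model object, the requestId above handles both cases. A small sketch (provider and model values are illustrative):
// With a plain model-id string:
const ctxFromString = createContext('openai', 'gpt-4o-mini', { messages: [] })
// ctxFromString.requestId looks roughly like 'openai-gpt-4o-mini-<timestamp>-<random>'
// With a resolved model object (assuming it exposes modelId, as the V2 model interfaces do):
// const resolvedModel = providerRegistry.languageModel('openai:gpt-4o-mini')
// const ctxFromModel = createContext('openai', resolvedModel, { messages: [] })
// ctxFromModel.requestId then uses resolvedModel.modelId instead of the raw string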

View File

@ -14,7 +14,7 @@ export type RecursiveCallFn = (newParams: any) => Promise<any>
*/
export interface AiRequestContext {
providerId: ProviderId
modelId: string
model: LanguageModel | ImageModelV2
originalParams: any
metadata: Record<string, any>
startTime: number

View File

@ -83,7 +83,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
// Create the request context with the proper createContext
const context = _context ? _context : createContext(this.providerId, modelId, params)
const context = _context ? _context : createContext(this.providerId, model, params)
// 🔥 Give the context recursive-call capability
context.recursiveCall = async (newParams: any): Promise<TResult> => {
@ -159,7 +159,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
// Create the request context with the proper createContext
const context = _context ? _context : createContext(this.providerId, modelId, params)
const context = _context ? _context : createContext(this.providerId, model, params)
// 🔥 Give the context recursive-call capability
context.recursiveCall = async (newParams: any): Promise<TResult> => {
@ -235,7 +235,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
// Create the request context
const context = _context ? _context : createContext(this.providerId, modelId, params)
const context = _context ? _context : createContext(this.providerId, model, params)
// 🔥 Give the context recursive-call capability
context.recursiveCall = async (newParams: any): Promise<TResult> => {

View File

@ -152,12 +152,14 @@ export class AiSdkToChunkAdapter {
// this.toolCallHandler.handleToolCallCreated(chunk)
// break
case 'tool-call':
// Raw tool call (not handled by middleware)
this.toolCallHandler.handleToolCall(chunk)
break
case 'tool-error':
this.toolCallHandler.handleToolError(chunk)
break
case 'tool-result':
// Raw tool result (not handled by middleware)
this.toolCallHandler.handleToolResult(chunk)
break
@ -167,7 +169,6 @@ export class AiSdkToChunkAdapter {
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
// TODO: distinguish request start from step start
// case 'start-step':
// this.onChunk({
// type: ChunkType.BLOCK_CREATED
@ -305,8 +306,6 @@ export class AiSdkToChunkAdapter {
break
default:
// Other chunk types can be ignored or logged
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
}
}
}

View File

@ -8,34 +8,61 @@ import { loggerService } from '@logger'
import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
// import type {
// AnthropicSearchOutput,
// WebSearchPluginConfig
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
import type { ToolSet, TypedToolCall, TypedToolError, TypedToolResult } from 'ai'
const logger = loggerService.withContext('ToolCallChunkHandler')
export type ToolcallsMap = {
toolCallId: string
toolName: string
args: any
// mcpTool can now be an MCPTool or the generic type we create for provider tools
tool: BaseTool
}
/**
* Handles tool-call related stream chunks
*/
export class ToolCallChunkHandler {
// private onChunk: (chunk: Chunk) => void
private activeToolCalls = new Map<
string,
{
toolCallId: string
toolName: string
args: any
// mcpTool can now be an MCPTool or the generic type we create for provider tools
tool: BaseTool
}
>()
private static globalActiveToolCalls = new Map<string, ToolcallsMap>()
private activeToolCalls = ToolCallChunkHandler.globalActiveToolCalls
constructor(
private onChunk: (chunk: Chunk) => void,
private mcpTools: MCPTool[]
) {}
/**
* Register an active tool call in the shared map (implementation)
*/
private static addActiveToolCallImpl(toolCallId: string, map: ToolcallsMap): boolean {
if (!ToolCallChunkHandler.globalActiveToolCalls.has(toolCallId)) {
ToolCallChunkHandler.globalActiveToolCalls.set(toolCallId, map)
return true
}
return false
}
/**
* Instance wrapper around the shared implementation
*/
private addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
}
/**
* Get the globally shared active tool-call map
*/
public static getActiveToolCalls() {
return ToolCallChunkHandler.globalActiveToolCalls
}
/**
* Static accessor so external code can register active tool calls
*/
public static addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
}
// /**
// * Set the onChunk callback
// */
@ -43,103 +70,103 @@ export class ToolCallChunkHandler {
// this.onChunk = callback
// }
handleToolCallCreated(
chunk:
| {
type: 'tool-input-start'
id: string
toolName: string
providerMetadata?: ProviderMetadata
providerExecuted?: boolean
}
| {
type: 'tool-input-end'
id: string
providerMetadata?: ProviderMetadata
}
| {
type: 'tool-input-delta'
id: string
delta: string
providerMetadata?: ProviderMetadata
}
): void {
switch (chunk.type) {
case 'tool-input-start': {
// If it can be looked up here, it is an mcpTool
// if (this.activeToolCalls.get(chunk.id)) return
// handleToolCallCreated(
// chunk:
// | {
// type: 'tool-input-start'
// id: string
// toolName: string
// providerMetadata?: ProviderMetadata
// providerExecuted?: boolean
// }
// | {
// type: 'tool-input-end'
// id: string
// providerMetadata?: ProviderMetadata
// }
// | {
// type: 'tool-input-delta'
// id: string
// delta: string
// providerMetadata?: ProviderMetadata
// }
// ): void {
// switch (chunk.type) {
// case 'tool-input-start': {
// // If it can be looked up here, it is an mcpTool
// // if (this.activeToolCalls.get(chunk.id)) return
const tool: BaseTool | MCPTool = {
id: chunk.id,
name: chunk.toolName,
description: chunk.toolName,
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
}
this.activeToolCalls.set(chunk.id, {
toolCallId: chunk.id,
toolName: chunk.toolName,
args: '',
tool
})
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: chunk.id,
tool: tool,
arguments: {},
status: 'pending',
toolCallId: chunk.id
}
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING,
responses: [toolResponse]
})
break
}
case 'tool-input-delta': {
const toolCall = this.activeToolCalls.get(chunk.id)
if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
toolCall.args += chunk.delta
break
}
case 'tool-input-end': {
const toolCall = this.activeToolCalls.get(chunk.id)
this.activeToolCalls.delete(chunk.id)
if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
// const toolResponse: ToolCallResponse = {
// id: toolCall.toolCallId,
// tool: toolCall.tool,
// arguments: toolCall.args,
// status: 'pending',
// toolCallId: toolCall.toolCallId
// }
// logger.debug('toolResponse', toolResponse)
// this.onChunk({
// type: ChunkType.MCP_TOOL_PENDING,
// responses: [toolResponse]
// })
break
}
}
// if (!toolCall) {
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// return
// }
// this.onChunk({
// type: ChunkType.MCP_TOOL_CREATED,
// tool_calls: [
// {
// id: chunk.id,
// name: chunk.toolName,
// status: 'pending'
// }
// ]
// })
}
// const tool: BaseTool | MCPTool = {
// id: chunk.id,
// name: chunk.toolName,
// description: chunk.toolName,
// type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
// }
// this.activeToolCalls.set(chunk.id, {
// toolCallId: chunk.id,
// toolName: chunk.toolName,
// args: '',
// tool
// })
// const toolResponse: MCPToolResponse | NormalToolResponse = {
// id: chunk.id,
// tool: tool,
// arguments: {},
// status: 'pending',
// toolCallId: chunk.id
// }
// this.onChunk({
// type: ChunkType.MCP_TOOL_PENDING,
// responses: [toolResponse]
// })
// break
// }
// case 'tool-input-delta': {
// const toolCall = this.activeToolCalls.get(chunk.id)
// if (!toolCall) {
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// return
// }
// toolCall.args += chunk.delta
// break
// }
// case 'tool-input-end': {
// const toolCall = this.activeToolCalls.get(chunk.id)
// this.activeToolCalls.delete(chunk.id)
// if (!toolCall) {
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// return
// }
// // const toolResponse: ToolCallResponse = {
// // id: toolCall.toolCallId,
// // tool: toolCall.tool,
// // arguments: toolCall.args,
// // status: 'pending',
// // toolCallId: toolCall.toolCallId
// // }
// // logger.debug('toolResponse', toolResponse)
// // this.onChunk({
// // type: ChunkType.MCP_TOOL_PENDING,
// // responses: [toolResponse]
// // })
// break
// }
// }
// // if (!toolCall) {
// // Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// // return
// // }
// // this.onChunk({
// // type: ChunkType.MCP_TOOL_CREATED,
// // tool_calls: [
// // {
// // id: chunk.id,
// // name: chunk.toolName,
// // status: 'pending'
// // }
// // ]
// // })
// }
/**
* Handle a tool-call chunk
@ -158,7 +185,6 @@ export class ToolCallChunkHandler {
let tool: BaseTool
let mcpTool: MCPTool | undefined
// Branch on the providerExecuted flag
if (providerExecuted) {
// If the tool is executed by the provider (e.g. web_search)
@ -196,27 +222,25 @@ export class ToolCallChunkHandler {
}
}
// Record the active tool call
this.activeToolCalls.set(toolCallId, {
this.addActiveToolCall(toolCallId, {
toolCallId,
toolName,
args,
tool
})
// Build the MCPToolResponse shape
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: tool,
arguments: args,
status: 'pending',
status: 'pending', // Always use the pending status
toolCallId: toolCallId
}
// Invoke onChunk
if (this.onChunk) {
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING,
type: ChunkType.MCP_TOOL_PENDING, // Always emit the pending status
responses: [toolResponse]
})
}
@ -276,4 +300,33 @@ export class ToolCallChunkHandler {
})
}
}
handleToolError(
chunk: {
type: 'tool-error'
} & TypedToolError<ToolSet>
): void {
const { toolCallId, error, input } = chunk
const toolCallInfo = this.activeToolCalls.get(toolCallId)
if (!toolCallInfo) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
return
}
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: toolCallInfo.tool,
arguments: input,
status: 'error',
response: error,
toolCallId: toolCallId
}
this.activeToolCalls.delete(toolCallId)
if (this.onChunk) {
this.onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [toolResponse]
})
}
}
}
export const addActiveToolCall = ToolCallChunkHandler.addActiveToolCall.bind(ToolCallChunkHandler)
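The map is static so that a tool call registered in one place (for example by the prompt-tool-use plugin inside aiCore) can still be looked up by whichever ToolCallChunkHandler instance later receives the matching tool-result or tool-error chunk. A rough sketch of the intended flow, with illustrative values:
// Register the call when it starts (returns false if the id is already tracked)
addActiveToolCall('call_123', {
  toolCallId: 'call_123',
  toolName: 'get_weather',
  args: { city: 'Tokyo' },
  tool: { id: 'call_123', name: 'get_weather', description: 'get_weather', type: 'provider' }
})
// Later, any handler instance resolves the same entry through the shared map
const active = ToolCallChunkHandler.getActiveToolCalls().get('call_123')
if (active) {
  console.log(`tool ${active.toolName} is still awaiting a result`)
}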

View File

@ -265,15 +265,15 @@ export default class ModernAiProvider {
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
const modelId = this.model!.id
logger.info('Starting modernCompletions', {
modelId,
providerId: this.config!.providerId,
topicId: config.topicId,
hasOnChunk: !!config.onChunk,
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
toolCount: params.tools ? Object.keys(params.tools).length : 0
})
// const modelId = this.model!.id
// logger.info('Starting modernCompletions', {
// modelId,
// providerId: this.config!.providerId,
// topicId: config.topicId,
// hasOnChunk: !!config.onChunk,
// hasTools: !!params.tools && Object.keys(params.tools).length > 0,
// toolCount: params.tools ? Object.keys(params.tools).length : 0
// })
// Build the plugin array based on the configuration
const plugins = buildPlugins(config)

View File

@ -35,7 +35,7 @@ export function buildPlugins(
plugins.push(webSearchPlugin())
}
// 2. Add the search plugin when tool calling is supported
if (middlewareConfig.isSupportedToolUse) {
if (middlewareConfig.isSupportedToolUse || middlewareConfig.isPromptToolUse) {
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant, middlewareConfig.topicId || ''))
}
@ -45,12 +45,13 @@ export function buildPlugins(
}
// 4. Add the tool plugin when prompt-based tool calling is enabled
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
if (middlewareConfig.isPromptToolUse) {
plugins.push(
createPromptToolUsePlugin({
enabled: true,
createSystemMessage: (systemPrompt, params, context) => {
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
const modelId = typeof context.model === 'string' ? context.model : context.model.modelId
if (modelId.includes('o1-mini') || modelId.includes('o1-preview')) {
if (context.isRecursiveCall) {
return null
}

View File

@ -19,7 +19,8 @@ import store from '@renderer/store'
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
import type { Assistant } from '@renderer/types'
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
import type { ModelMessage } from 'ai'
import type { LanguageModel, ModelMessage } from 'ai'
import { generateText } from 'ai'
import { isEmpty } from 'lodash'
import { MemoryProcessor } from '../../services/MemoryProcessor'
@ -76,9 +77,7 @@ async function analyzeSearchIntent(
shouldKnowledgeSearch?: boolean
shouldMemorySearch?: boolean
lastAnswer?: ModelMessage
context: AiRequestContext & {
isAnalyzing?: boolean
}
context: AiRequestContext
topicId: string
}
): Promise<ExtractResults | undefined> {
@ -122,9 +121,7 @@ async function analyzeSearchIntent(
logger.error('Provider not found or missing API key')
return getFallbackResult()
}
// console.log('formattedPrompt', schema)
try {
context.isAnalyzing = true
logger.info('Starting intent analysis generateText call', {
modelId: model.id,
topicId: options.topicId,
@ -133,18 +130,16 @@ async function analyzeSearchIntent(
hasKnowledgeSearch: needKnowledgeExtract
})
const { text: result } = await context.executor
.generateText(model.id, {
prompt: formattedPrompt
})
.finally(() => {
context.isAnalyzing = false
logger.info('Intent analysis generateText call completed', {
modelId: model.id,
topicId: options.topicId,
requestId: context.requestId
})
const { text: result } = await generateText({
model: context.model as LanguageModel,
prompt: formattedPrompt
}).finally(() => {
logger.info('Intent analysis generateText call completed', {
modelId: model.id,
topicId: options.topicId,
requestId: context.requestId
})
})
const parsedResult = extractInfoFromXML(result)
logger.debug('Intent analysis result', { parsedResult })
@ -183,7 +178,6 @@ async function storeConversationMemory(
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled || !assistant.enableMemory) {
// console.log('Memory storage is disabled')
return
}
@ -245,25 +239,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
// Store intent-analysis results
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
const userMessages: { [requestId: string]: ModelMessage } = {}
let currentContext: AiRequestContext | null = null
return definePlugin({
name: 'search-orchestration',
enforce: 'pre', // 确保在其他插件之前执行
configureContext: (context: AiRequestContext) => {
if (currentContext) {
context.isAnalyzing = currentContext.isAnalyzing
}
currentContext = context
},
/**
* 🔍 Step 1: intent recognition phase
*/
onRequestStart: async (context: AiRequestContext) => {
if (context.isAnalyzing) return
// Skip intent analysis when no search feature is enabled
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
@ -315,7 +298,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
* 🔧 Step 2: tool configuration phase
*/
transformParams: async (params: any, context: AiRequestContext) => {
if (context.isAnalyzing) return params
// logger.info('🔧 Configuring tools based on intent...', context.requestId)
try {
@ -409,7 +391,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
// context.isAnalyzing = false
// logger.info('context.isAnalyzing', context, result)
// logger.info('💾 Starting memory storage...', context.requestId)
if (context.isAnalyzing) return
try {
const messages = context.originalParams.messages

View File

@ -19,8 +19,6 @@ export const memorySearchTool = () => {
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
}),
execute: async ({ query, limit = 5 }) => {
// console.log('🧠 [memorySearchTool] Searching memories:', { query, limit })
try {
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled) {
@ -29,7 +27,6 @@ export const memorySearchTool = () => {
const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
// console.warn('Memory search skipped: embedding or LLM model not configured')
return []
}
@ -40,12 +37,10 @@ export const memorySearchTool = () => {
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
if (relevantMemories?.length > 0) {
// console.log('🧠 [memorySearchTool] Found memories:', relevantMemories.length)
return relevantMemories
}
return []
} catch (error) {
// console.error('🧠 [memorySearchTool] Error:', error)
return []
}
}
@ -84,8 +79,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
.optional()
}) satisfies z.ZodSchema<MemorySearchWithExtractionInput>,
execute: async ({ userMessage }) => {
// console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer })
try {
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled || !assistant.enableMemory) {
@ -97,7 +90,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
// console.warn('Memory search skipped: embedding or LLM model not configured')
return {
extractedKeywords: 'Memory models not configured',
searchResults: []
@ -125,7 +117,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
)
if (relevantMemories?.length > 0) {
// console.log('🧠 [memorySearchToolWithExtraction] Found memories:', relevantMemories.length)
return {
extractedKeywords: content,
searchResults: relevantMemories
@ -137,7 +128,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
searchResults: []
}
} catch (error) {
// console.error('🧠 [memorySearchToolWithExtraction] Error:', error)
return {
extractedKeywords: 'Search failed',
searchResults: []

View File

@ -1,7 +1,5 @@
import { loggerService } from '@logger'
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
import { MCPTool, MCPToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools'
import { requestToolConfirmation } from '@renderer/utils/userConfirmation'
import { type Tool, type ToolSet } from 'ai'
@ -33,8 +31,36 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
tools[mcpTool.name] = tool({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params, { toolCallId, experimental_context }) => {
const { onChunk } = experimental_context as { onChunk: (chunk: Chunk) => void }
execute: async (params, { toolCallId }) => {
// Check whether auto-approval is enabled
const server = getMcpServerByTool(mcpTool)
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
let confirmed = true
if (!isAutoApproveEnabled) {
// Ask the user for confirmation
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
confirmed = await requestToolConfirmation(toolCallId)
}
if (!confirmed) {
// The user declined to run the tool
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
return {
content: [
{
type: 'text',
text: `User declined to execute tool "${mcpTool.name}".`
}
],
isError: false
}
}
// User confirmed or auto-approved; execute the tool
logger.debug(`Executing tool: ${mcpTool.name}`)
// Build the adapted MCPToolResponse object
const toolResponse: MCPToolResponse = {
id: toolCallId,
@ -44,53 +70,18 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
toolCallId
}
try {
// Check whether auto-approval is enabled
const server = getMcpServerByTool(mcpTool)
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
const result = await callMCPTool(toolResponse)
let confirmed = true
if (!isAutoApproveEnabled) {
// Ask the user for confirmation
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
confirmed = await requestToolConfirmation(toolResponse.id)
}
if (!confirmed) {
// The user declined to run the tool
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
return {
content: [
{
type: 'text',
text: `User declined to execute tool "${mcpTool.name}".`
}
],
isError: false
}
}
// User confirmed or auto-approved; execute the tool
toolResponse.status = 'invoking'
logger.debug(`Executing tool: ${mcpTool.name}`)
onChunk({
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [toolResponse]
})
const result = await callMCPTool(toolResponse)
// Return the result; the AI SDK handles serialization
if (result.isError) {
throw new Error(result.content?.[0]?.text || 'Tool execution failed')
}
// Return the tool execution result
return result
} catch (error) {
logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
throw error
// Return the result; the AI SDK handles serialization
if (result.isError) {
// throw new Error(result.content?.[0]?.text || 'Tool execution failed')
return Promise.reject(result)
}
// Return the tool execution result
return result
// } catch (error) {
// logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
// }
}
})
}
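A minimal sketch of how the converted ToolSet is meant to be consumed (streamText is from the AI SDK; the model, mcpTools, and prompt are assumed to be supplied by the surrounding app code):
import { type LanguageModel, streamText } from 'ai'
async function chatWithMcpTools(model: LanguageModel, mcpTools: MCPTool[], prompt: string) {
  // Each tool's execute() now goes through the confirmation / auto-approve flow above
  const tools = convertMcpToolsToAiSdkTools(mcpTools)
  const result = streamText({ model, prompt, tools })
  for await (const part of result.fullStream) {
    if (part.type === 'tool-result') {
      console.log('tool finished:', part.toolName)
    }
  }
  return result
}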

View File

@ -105,7 +105,6 @@ const ThinkingBlock: React.FC<Props> = ({ block }) => {
const ThinkingTimeSeconds = memo(
({ blockThinkingTime, isThinking }: { blockThinkingTime: number; isThinking: boolean }) => {
const { t } = useTranslation()
// console.log('blockThinkingTime', blockThinkingTime)
// const [thinkingTime, setThinkingTime] = useState(blockThinkingTime || 0)
// FIXME: the time measured here differs slightly from the one measured at the request site

View File

@ -186,7 +186,6 @@ const MessageMenubar: FC<Props> = (props) => {
try {
await translateText(mainTextContent, language, translationUpdater)
} catch (error) {
// console.error('Translation failed:', error)
window.message.error({ content: t('translate.error.failed'), key: 'translate-message' })
// There should be only one
const translationBlocks = findTranslationBlocksById(message.id)

View File

@ -60,14 +60,29 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
const toolResponse = block.metadata?.rawMcpToolResponse
const { id, tool, status, response } = toolResponse!
const isPending = status === 'pending'
const isInvoking = status === 'invoking'
const isDone = status === 'done'
const isError = status === 'error'
const isAutoApproved = useMemo(
() =>
isToolAutoApproved(
tool,
mcpServers.find((s) => s.id === tool.serverId)
),
[tool, mcpServers]
)
// Local state to track user confirmation
const [isConfirmed, setIsConfirmed] = useState(isAutoApproved)
// Derive the different UI states
const isWaitingConfirmation = isPending && !isAutoApproved && !isConfirmed
const isExecuting = isPending && (isAutoApproved || isConfirmed)
const timer = useRef<NodeJS.Timeout | null>(null)
useEffect(() => {
if (!isPending) return
if (!isWaitingConfirmation) return
if (countdown > 0) {
timer.current = setTimeout(() => {
@ -75,6 +90,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
setCountdown((prev) => prev - 1)
}, 1000)
} else if (countdown === 0) {
setIsConfirmed(true)
confirmToolAction(id)
}
@ -83,7 +99,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
clearTimeout(timer.current)
}
}
}, [countdown, id, isPending])
}, [countdown, id, isWaitingConfirmation])
useEffect(() => {
const removeListener = window.electron.ipcRenderer.on(
@ -146,6 +162,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
const handleConfirmTool = () => {
cancelCountdown()
setIsConfirmed(true)
confirmToolAction(id)
}
@ -195,6 +212,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
updateMCPServer(updatedServer)
// Also confirm the current tool
setIsConfirmed(true)
confirmToolAction(id)
window.message.success({
@ -206,32 +224,31 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
const renderStatusIndicator = (status: string, hasError: boolean) => {
let label = ''
let icon: React.ReactNode | null = null
switch (status) {
case 'pending':
if (status === 'pending') {
if (isWaitingConfirmation) {
label = t('message.tools.pending', 'Awaiting Approval')
icon = <LoadingIcon style={{ marginLeft: 6, color: 'var(--status-color-warning)' }} />
break
case 'invoking':
} else if (isExecuting) {
label = t('message.tools.invoking')
icon = <LoadingIcon style={{ marginLeft: 6 }} />
break
case 'cancelled':
label = t('message.tools.cancelled')
icon = <X size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
break
case 'done':
if (hasError) {
label = t('message.tools.error')
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
} else {
label = t('message.tools.completed')
icon = <Check size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
}
break
default:
label = ''
icon = null
}
} else if (status === 'cancelled') {
label = t('message.tools.cancelled')
icon = <X size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
} else if (status === 'done') {
if (hasError) {
label = t('message.tools.error')
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
} else {
label = t('message.tools.completed')
icon = <Check size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
}
} else if (status === 'error') {
label = t('message.tools.error')
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
}
return (
<StatusIndicator status={status} hasError={hasError}>
{label}
@ -248,7 +265,6 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
params: toolResponse.arguments,
response: toolResponse.response
}
items.push({
key: id,
label: (
@ -283,7 +299,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
<Maximize size={14} />
</ActionButton>
</Tooltip>
{!isPending && !isInvoking && (
{!isPending && (
<Tooltip title={t('common.copy')} mouseEnterDelay={0.5}>
<ActionButton
className="message-action-button"
@ -301,7 +317,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
</MessageTitleLabel>
),
children:
isDone && result ? (
(isDone || isError) && result ? (
<ToolResponseContainer
style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
@ -370,7 +386,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
}
}}>
<ToolContainer>
<ToolContentWrapper className={status}>
<ToolContentWrapper className={isPending ? 'pending' : status}>
<CollapseContainer
ghost
activeKey={activeKeys}
@ -383,14 +399,16 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
<ExpandIcon $isActive={isActive} size={18} color="var(--color-text-3)" strokeWidth={1.5} />
)}
/>
{(isPending || isInvoking) && (
{isPending && (
<ActionsBar>
<ActionLabel>
{isPending ? t('settings.mcp.tools.autoApprove.tooltip.confirm') : t('message.tools.invoking')}
{isWaitingConfirmation
? t('settings.mcp.tools.autoApprove.tooltip.confirm')
: t('message.tools.invoking')}
</ActionLabel>
<ActionButtonsGroup>
{isPending && (
{isWaitingConfirmation && (
<Button
color="danger"
variant="filled"
@ -402,7 +420,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
{t('common.cancel')}
</Button>
)}
{isInvoking && toolResponse?.id ? (
{isExecuting && toolResponse?.id ? (
<Button
size="small"
color="danger"
@ -416,29 +434,31 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
{t('chat.input.pause')}
</Button>
) : (
<StyledDropdownButton
size="small"
type="primary"
icon={<ChevronDown size={14} />}
onClick={() => {
handleConfirmTool()
}}
menu={{
items: [
{
key: 'autoApprove',
label: t('settings.mcp.tools.autoApprove.label'),
onClick: () => {
handleAutoApprove()
isWaitingConfirmation && (
<StyledDropdownButton
size="small"
type="primary"
icon={<ChevronDown size={14} />}
onClick={() => {
handleConfirmTool()
}}
menu={{
items: [
{
key: 'autoApprove',
label: t('settings.mcp.tools.autoApprove.label'),
onClick: () => {
handleAutoApprove()
}
}
}
]
}}>
<CirclePlay size={15} className="lucide-custom" />
<CountdownText>
{t('settings.mcp.tools.run', 'Run')} ({countdown}s)
</CountdownText>
</StyledDropdownButton>
]
}}>
<CirclePlay size={15} className="lucide-custom" />
<CountdownText>
{t('settings.mcp.tools.run', 'Run')} ({countdown}s)
</CountdownText>
</StyledDropdownButton>
)
)}
</ActionButtonsGroup>
</ActionsBar>
@ -542,8 +562,7 @@ const ToolContentWrapper = styled.div`
border: 1px solid var(--color-border);
}
&.pending,
&.invoking {
&.pending {
background-color: var(--color-background-soft);
.ant-collapse {
border: none;
@ -663,6 +682,8 @@ const StatusIndicator = styled.span<{ status: string; hasError?: boolean }>`
return 'var(--status-color-error)'
case 'done':
return props.hasError ? 'var(--status-color-error)' : 'var(--status-color-success)'
case 'error':
return 'var(--status-color-error)'
default:
return 'var(--color-text)'
}

View File

@ -99,6 +99,8 @@ export async function fetchChatCompletion({
const provider = AI.getActualProvider()
const mcpTools: MCPTool[] = []
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
if (isPromptToolUse(assistant) || isSupportedToolUse(assistant)) {
mcpTools.push(...(await fetchMcpTools(assistant)))
}
@ -137,7 +139,6 @@ export async function fetchChatCompletion({
}
// --- Call AI Completions ---
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
await AI.completions(modelId, aiSdkParams, {
...middlewareConfig,
assistant,

View File

@ -178,12 +178,7 @@ class WebSearchService {
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
}
// try {
return await webSearchEngine.search(formattedQuery, websearch, httpOptions)
// } catch (error) {
// console.error('Search failed:', error)
// throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
// }
}
/**

View File

@ -69,16 +69,6 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
})
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
},
// onBlockCreated: async () => {
// if (blockManager.hasInitialPlaceholder) {
// return
// }
// console.log('onBlockCreated')
// const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
// status: MessageBlockStatus.PROCESSING
// })
// await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
// },
onError: async (error: AISDKError) => {
logger.debug('onError', error)

View File

@ -51,28 +51,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
}
},
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
// 根据 toolResponse.id 查找对应的块ID
const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
if (targetBlockId && toolResponse.status === 'invoking') {
const changes = {
status: MessageBlockStatus.PROCESSING,
metadata: { rawMcpToolResponse: toolResponse }
}
blockManager.smartBlockUpdate(targetBlockId, changes, MessageBlockType.TOOL)
} else if (!targetBlockId) {
logger.warn(
`[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`,
Array.from(toolCallIdToBlockIdMap.entries())
)
} else {
logger.warn(
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
)
}
},
onToolCallComplete: (toolResponse: MCPToolResponse) => {
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
toolCallIdToBlockIdMap.delete(toolResponse.id)
@ -104,7 +82,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
stack: null
}
}
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
// Handle citation block creation for web search results
@ -132,7 +109,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
citationBlockId = citationBlock.id
blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
}
// TODO: handle memory references
} else {
logger.warn(
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`

View File

@ -458,18 +458,18 @@ describe('streamCallback Integration Tests', () => {
}
]
},
{
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [
{
id: 'tool-call-1',
tool: mockTool,
arguments: { testArg: 'value' },
status: 'invoking' as const,
response: ''
}
]
},
// {
// type: ChunkType.MCP_TOOL_PENDING,
// responses: [
// {
// id: 'tool-call-1',
// tool: mockTool,
// arguments: { testArg: 'value' },
// status: 'invoking' as const,
// response: ''
// }
// ]
// },
{
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [
@ -611,18 +611,18 @@ describe('streamCallback Integration Tests', () => {
}
]
},
{
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [
{
id: 'tool-call-1',
tool: mockCalculatorTool,
arguments: { operation: 'add', a: 1, b: 2 },
status: 'invoking' as const,
response: ''
}
]
},
// {
// type: ChunkType.MCP_TOOL_PENDING,
// responses: [
// {
// id: 'tool-call-1',
// tool: mockCalculatorTool,
// arguments: { operation: 'add', a: 1, b: 2 },
// status: 'invoking' as const,
// response: ''
// }
// ]
// },
{
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [

View File

@ -826,9 +826,7 @@ export function mcpToolCallResponseToAwsBedrockMessage(
}
/**
* Determines whether the assistant uses native tool calling (function call)
* @param assistant
* @returns whether native function calling is used
*/

View File

@ -179,15 +179,15 @@ __metadata:
languageName: node
linkType: hard
"@ai-sdk/google@npm:^2.0.7":
version: 2.0.7
resolution: "@ai-sdk/google@npm:2.0.7"
"@ai-sdk/google@npm:^2.0.13":
version: 2.0.13
resolution: "@ai-sdk/google@npm:2.0.13"
dependencies:
"@ai-sdk/provider": "npm:2.0.0"
"@ai-sdk/provider-utils": "npm:3.0.4"
"@ai-sdk/provider-utils": "npm:3.0.8"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/bde4c95a2a167355cda18de9d5b273d562d2a724f650ca69016daa8df2766280487e143cf0cdd96f6654c255d587a680c6a937b280eb734ca2c35d6f9b9e943c
checksum: 10c0/a05210de11d7ab41d49bcd0330c37f4116441b149d8ccc9b6bc5eaa12ea42bae82364dc2cd09502734b15115071f07395525806ea4998930b285b1ce74102186
languageName: node
linkType: hard
@ -227,15 +227,15 @@ __metadata:
languageName: node
linkType: hard
"@ai-sdk/openai@npm:^2.0.19":
version: 2.0.19
resolution: "@ai-sdk/openai@npm:2.0.19"
"@ai-sdk/openai@npm:^2.0.26":
version: 2.0.26
resolution: "@ai-sdk/openai@npm:2.0.26"
dependencies:
"@ai-sdk/provider": "npm:2.0.0"
"@ai-sdk/provider-utils": "npm:3.0.5"
"@ai-sdk/provider-utils": "npm:3.0.8"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/04db695669d783a810b80283e0cd48f6e7654667fd76ca2d35c7cffae6fdd68fb0473118e4e097ef1352f4432dd7c15c07f873d712b940c72495e5839b0ede98
checksum: 10c0/b8cb01c0c38525c38901f41f1693cd15589932a2aceddea14bed30f44719532a5e74615fb0e974eff1a0513048ac204c27456ff8829a9c811d1461cc635c9cc5
languageName: node
linkType: hard
@ -267,20 +267,6 @@ __metadata:
languageName: node
linkType: hard
"@ai-sdk/provider-utils@npm:3.0.5":
version: 3.0.5
resolution: "@ai-sdk/provider-utils@npm:3.0.5"
dependencies:
"@ai-sdk/provider": "npm:2.0.0"
"@standard-schema/spec": "npm:^1.0.0"
eventsource-parser: "npm:^3.0.3"
zod-to-json-schema: "npm:^3.24.1"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/4057810b320bda149a178dc1bfc9cdd592ca88b736c3c22bd0c1f8111c75ef69beec4a523f363e5d0d120348b876942fd66c0bb4965864da4c12c5cfddee15a3
languageName: node
linkType: hard
"@ai-sdk/provider-utils@npm:3.0.7":
version: 3.0.7
resolution: "@ai-sdk/provider-utils@npm:3.0.7"
@ -294,6 +280,19 @@ __metadata:
languageName: node
linkType: hard
"@ai-sdk/provider-utils@npm:3.0.8":
version: 3.0.8
resolution: "@ai-sdk/provider-utils@npm:3.0.8"
dependencies:
"@ai-sdk/provider": "npm:2.0.0"
"@standard-schema/spec": "npm:^1.0.0"
eventsource-parser: "npm:^3.0.5"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/f466657c886cbb9f7ecbcd2dd1abc51a88af9d3f1cff030f7e97e70a4790a99f3338ad886e9c0dccf04dacdcc84522c7d57119b9a4e8e1d84f2dae9c893c397e
languageName: node
linkType: hard
"@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0":
version: 2.0.0
resolution: "@ai-sdk/provider@npm:2.0.0"
@ -2249,8 +2248,8 @@ __metadata:
"@ai-sdk/anthropic": "npm:^2.0.5"
"@ai-sdk/azure": "npm:^2.0.16"
"@ai-sdk/deepseek": "npm:^1.0.9"
"@ai-sdk/google": "npm:^2.0.7"
"@ai-sdk/openai": "npm:^2.0.19"
"@ai-sdk/google": "npm:^2.0.13"
"@ai-sdk/openai": "npm:^2.0.26"
"@ai-sdk/openai-compatible": "npm:^1.0.9"
"@ai-sdk/provider": "npm:^2.0.0"
"@ai-sdk/provider-utils": "npm:^3.0.4"