refactor(ToolCall): refactor:mcp-tool-state-management (#10028)

This commit is contained in:
MyPrototypeWhat 2025-09-08 23:29:34 +08:00 committed by GitHub
parent 7df1060370
commit f6ffd574bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 594 additions and 788 deletions

View File

@ -1,103 +0,0 @@
/**
 * Hub Provider 使用示例
 *
 * 演示如何使用 Hub Provider 功能来路由到多个底层 provider
 */
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
/**
 * Demonstrates the Hub Provider workflow end to end:
 * initializes the underlying providers, creates and registers hub providers,
 * resolves models through the hub, and shows the expected error cases.
 *
 * All errors are handled internally; the function never rejects.
 */
async function demonstrateHubProvider() {
  try {
    // 1. Initialize the underlying providers.
    console.log('📦 初始化底层providers...')
    initializeProvider('openai', {
      apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
    })
    initializeProvider('anthropic', {
      apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
    })

    // 2. Create the Hub Provider (automatically includes all initialized providers).
    console.log('🌐 创建Hub Provider...')
    const aihubmixProvider = createHubProvider({
      hubId: 'aihubmix',
      debug: true
    })

    // 3. Register the Hub Provider.
    providerRegistry.registerProvider('aihubmix', aihubmixProvider)
    console.log('✅ Hub Provider "aihubmix" 注册成功')

    // 4. Access different models through the Hub Provider.
    console.log('\n🚀 使用Hub模型...')

    // Route to OpenAI via the hub.
    const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
    console.log('✓ OpenAI模型已获取:', openaiModel.modelId)

    // Route to Anthropic via the hub.
    const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
    console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)

    // 5. Demonstrate error handling.
    console.log('\n❌ 演示错误处理...')
    try {
      // Accessing a provider that was never initialized should throw.
      providerRegistry.languageModel('aihubmix:google:gemini-pro')
    } catch (error) {
      // `error` is `unknown` under strict mode; narrow before reading `.message`.
      console.log('预期错误:', error instanceof Error ? error.message : String(error))
    }

    try {
      // A malformed model id should throw as well.
      providerRegistry.languageModel('aihubmix:invalid-format')
    } catch (error) {
      console.log('预期错误:', error instanceof Error ? error.message : String(error))
    }

    // 6. Multiple Hub Providers can coexist in the same registry.
    console.log('\n🔄 创建多个Hub Provider...')
    const localHubProvider = createHubProvider({
      hubId: 'local-ai'
    })
    providerRegistry.registerProvider('local-ai', localHubProvider)
    console.log('✅ Hub Provider "local-ai" 注册成功')

    console.log('\n🎉 Hub Provider演示完成')
  } catch (error) {
    console.error('💥 演示过程中发生错误:', error)
  }
}
// 演示简化的使用方式
/**
 * Prints a short, copy-pasteable snippet showing the simplified Hub Provider
 * usage pattern. Output-only; the function has no side effects beyond logging.
 */
function simplifiedUsageExample() {
  const snippet = `
// 1. 初始化providers
initializeProvider('openai', { apiKey: 'sk-xxx' })
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
// 2. 创建并注册Hub Provider
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
providerRegistry.registerProvider('aihubmix', hubProvider)
// 3. 直接使用
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
`
  console.log('\n📝 简化使用示例:')
  console.log(snippet)
}
// 运行演示
if (require.main === module) {
demonstrateHubProvider()
simplifiedUsageExample()
}
export { demonstrateHubProvider, simplifiedUsageExample }

View File

@ -1,167 +0,0 @@
/**
 * Image Generation Example
 *
 * 演示如何使用 aiCore 的图像生成功能
 */
import { createExecutor, generateImage } from '../src/index'
/**
 * Image generation demo entry point. Runs four examples in sequence:
 * 1) an executor instance, 2) the direct API call, 3) Google Imagen
 * (skipped without GOOGLE_API_KEY), and 4) the plugin system.
 * Each example handles its own errors; only executor creation in example 1
 * can propagate a rejection (matching the original behavior).
 */
async function main() {
  await executorExample()
  await directApiExample()
  await googleImagenExample()
  await pluginSystemExample()
}

/** 方式1: generate an image through an executor instance. */
async function executorExample(): Promise<void> {
  console.log('📸 创建 OpenAI 图像生成执行器...')
  const executor = createExecutor('openai', {
    apiKey: process.env.OPENAI_API_KEY!
  })

  try {
    console.log('🎨 使用执行器生成图像...')
    const result1 = await executor.generateImage('dall-e-3', {
      prompt: 'A futuristic cityscape at sunset with flying cars',
      size: '1024x1024',
      n: 1
    })
    console.log('✅ 图像生成成功!')
    console.log('📊 结果:', {
      imagesCount: result1.images.length,
      mediaType: result1.image.mediaType,
      hasBase64: !!result1.image.base64,
      providerMetadata: result1.providerMetadata
    })
  } catch (error) {
    console.error('❌ 执行器生成失败:', error)
  }
}

/** 方式2: generate an image through the direct call API. */
async function directApiExample(): Promise<void> {
  try {
    console.log('🎨 使用直接 API 生成图像...')
    const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
      prompt: 'A magical forest with glowing mushrooms and fairy lights',
      aspectRatio: '16:9',
      providerOptions: {
        openai: {
          quality: 'hd',
          style: 'vivid'
        }
      }
    })
    console.log('✅ 直接 API 生成成功!')
    console.log('📊 结果:', {
      imagesCount: result2.images.length,
      mediaType: result2.image.mediaType,
      hasBase64: !!result2.image.base64
    })
  } catch (error) {
    console.error('❌ 直接 API 生成失败:', error)
  }
}

/** 方式3: other providers (Google Imagen); no-op unless GOOGLE_API_KEY is set. */
async function googleImagenExample(): Promise<void> {
  if (!process.env.GOOGLE_API_KEY) {
    return
  }
  try {
    console.log('🎨 使用 Google Imagen 生成图像...')
    const googleExecutor = createExecutor('google', {
      apiKey: process.env.GOOGLE_API_KEY!
    })
    const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
      prompt: 'A serene mountain lake at dawn with mist rising from the water',
      aspectRatio: '1:1'
    })
    console.log('✅ Google Imagen 生成成功!')
    console.log('📊 结果:', {
      imagesCount: result3.images.length,
      mediaType: result3.image.mediaType,
      hasBase64: !!result3.image.base64
    })
  } catch (error) {
    console.error('❌ Google Imagen 生成失败:', error)
  }
}

/** 方式4: plugin system — a plugin that rewrites the prompt and tags the result. */
async function pluginSystemExample(): Promise<void> {
  console.log('🔌 演示插件系统...')
  // Sample plugin: appends quality keywords to the prompt before the request
  // and marks the result as enhanced afterwards.
  const promptEnhancerPlugin = {
    name: 'prompt-enhancer',
    transformParams: async (params: any) => {
      console.log('🔧 插件: 增强提示词...')
      return {
        ...params,
        prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
      }
    },
    transformResult: async (result: any) => {
      console.log('🔧 插件: 处理结果...')
      return {
        ...result,
        enhanced: true
      }
    }
  }
  const executorWithPlugin = createExecutor(
    'openai',
    {
      apiKey: process.env.OPENAI_API_KEY!
    },
    [promptEnhancerPlugin]
  )
  try {
    const result4 = await executorWithPlugin.generateImage('dall-e-3', {
      prompt: 'A cute robot playing in a garden'
    })
    console.log('✅ 插件系统生成成功!')
    console.log('📊 结果:', {
      imagesCount: result4.images.length,
      enhanced: (result4 as any).enhanced,
      mediaType: result4.image.mediaType
    })
  } catch (error) {
    console.error('❌ 插件系统生成失败:', error)
  }
}
// 错误处理演示
/**
 * Demonstrates error handling: triggers a failure with an invalid API key
 * and inspects the thrown error plus its SDK-attached metadata.
 *
 * Uses the strict-mode default `unknown` catch variable instead of `any`.
 */
async function errorHandlingExample() {
  console.log('⚠️ 演示错误处理...')
  try {
    const executor = createExecutor('openai', {
      apiKey: 'invalid-key'
    })
    await executor.generateImage('dall-e-3', {
      prompt: 'Test image'
    })
  } catch (error) {
    // `providerId`/`modelId` are custom fields the SDK attaches to its error
    // type — TODO: narrow with the SDK's own error class/guard if exported.
    const err = error as Error & { providerId?: string; modelId?: string }
    console.log('✅ 成功捕获错误:', err.constructor.name)
    console.log('📋 错误信息:', err.message)
    console.log('🏷️ 提供商ID:', err.providerId)
    console.log('🏷️ 模型ID:', err.modelId)
  }
}
// 运行示例
if (require.main === module) {
main()
.then(() => {
console.log('🎉 所有示例完成!')
return errorHandlingExample()
})
.then(() => {
console.log('🎯 示例程序结束')
process.exit(0)
})
.catch((error) => {
console.error('💥 程序执行出错:', error)
process.exit(1)
})
}

View File

@ -1,6 +1,6 @@
{ {
"name": "@cherrystudio/ai-core", "name": "@cherrystudio/ai-core",
"version": "1.0.0-alpha.12", "version": "1.0.0-alpha.13",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK", "description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "dist/index.js", "main": "dist/index.js",
"module": "dist/index.mjs", "module": "dist/index.mjs",
@ -39,8 +39,8 @@
"@ai-sdk/anthropic": "^2.0.5", "@ai-sdk/anthropic": "^2.0.5",
"@ai-sdk/azure": "^2.0.16", "@ai-sdk/azure": "^2.0.16",
"@ai-sdk/deepseek": "^1.0.9", "@ai-sdk/deepseek": "^1.0.9",
"@ai-sdk/google": "^2.0.7", "@ai-sdk/google": "^2.0.13",
"@ai-sdk/openai": "^2.0.19", "@ai-sdk/openai": "^2.0.26",
"@ai-sdk/openai-compatible": "^1.0.9", "@ai-sdk/openai-compatible": "^1.0.9",
"@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.4", "@ai-sdk/provider-utils": "^3.0.4",

View File

@ -84,7 +84,6 @@ export class ModelResolver {
*/ */
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 { private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}` const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
console.log('fullModelId', fullModelId)
return globalRegistryManagement.languageModel(fullModelId as any) return globalRegistryManagement.languageModel(fullModelId as any)
} }

View File

@ -27,10 +27,20 @@ export class StreamEventManager {
/** /**
* *
*/ */
sendStepFinishEvent(controller: StreamController, chunk: any): void { sendStepFinishEvent(
controller: StreamController,
chunk: any,
context: AiRequestContext,
finishReason: string = 'stop'
): void {
// 累加当前步骤的 usage
if (chunk.usage && context.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, chunk.usage)
}
controller.enqueue({ controller.enqueue({
type: 'finish-step', type: 'finish-step',
finishReason: 'stop', finishReason,
response: chunk.response, response: chunk.response,
usage: chunk.usage, usage: chunk.usage,
providerMetadata: chunk.providerMetadata providerMetadata: chunk.providerMetadata
@ -43,28 +53,32 @@ export class StreamEventManager {
async handleRecursiveCall( async handleRecursiveCall(
controller: StreamController, controller: StreamController,
recursiveParams: any, recursiveParams: any,
context: AiRequestContext, context: AiRequestContext
stepId: string
): Promise<void> { ): Promise<void> {
try { // try {
console.log('[MCP Prompt] Starting recursive call after tool execution...') // 重置工具执行状态,准备处理新的步骤
context.hasExecutedToolsInCurrentStep = false
const recursiveResult = await context.recursiveCall(recursiveParams) const recursiveResult = await context.recursiveCall(recursiveParams)
if (recursiveResult && recursiveResult.fullStream) { if (recursiveResult && recursiveResult.fullStream) {
await this.pipeRecursiveStream(controller, recursiveResult.fullStream) await this.pipeRecursiveStream(controller, recursiveResult.fullStream, context)
} else { } else {
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult) console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
}
} catch (error) {
this.handleRecursiveCallError(controller, error, stepId)
} }
// } catch (error) {
// this.handleRecursiveCallError(controller, error, stepId)
// }
} }
/** /**
* *
*/ */
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> { private async pipeRecursiveStream(
controller: StreamController,
recursiveStream: ReadableStream,
context?: AiRequestContext
): Promise<void> {
const reader = recursiveStream.getReader() const reader = recursiveStream.getReader()
try { try {
while (true) { while (true) {
@ -73,9 +87,16 @@ export class StreamEventManager {
break break
} }
if (value.type === 'finish') { if (value.type === 'finish') {
// 迭代的流不发finish // 迭代的流不发finish但需要累加其 usage
if (value.usage && context?.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, value.usage)
}
break break
} }
// 对于 finish-step 类型,累加其 usage
if (value.type === 'finish-step' && value.usage && context?.accumulatedUsage) {
this.accumulateUsage(context.accumulatedUsage, value.usage)
}
// 将递归流的数据传递到当前流 // 将递归流的数据传递到当前流
controller.enqueue(value) controller.enqueue(value)
} }
@ -87,25 +108,25 @@ export class StreamEventManager {
/** /**
* *
*/ */
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void { // private handleRecursiveCallError(controller: StreamController, error: unknown): void {
console.error('[MCP Prompt] Recursive call failed:', error) // console.error('[MCP Prompt] Recursive call failed:', error)
// 使用 AI SDK 标准错误格式,但不中断流 // // 使用 AI SDK 标准错误格式,但不中断流
controller.enqueue({ // controller.enqueue({
type: 'error', // type: 'error',
error: { // error: {
message: error instanceof Error ? error.message : String(error), // message: error instanceof Error ? error.message : String(error),
name: error instanceof Error ? error.name : 'RecursiveCallError' // name: error instanceof Error ? error.name : 'RecursiveCallError'
} // }
}) // })
// 继续发送文本增量,保持流的连续性 // // // 继续发送文本增量,保持流的连续性
controller.enqueue({ // // controller.enqueue({
type: 'text-delta', // // type: 'text-delta',
id: stepId, // // id: stepId,
text: '\n\n[工具执行后递归调用失败,继续对话...]' // // text: '\n\n[工具执行后递归调用失败,继续对话...]'
}) // // })
} // }
/** /**
* *
@ -136,4 +157,18 @@ export class StreamEventManager {
return recursiveParams return recursiveParams
} }
/**
* usage
*/
private accumulateUsage(target: any, source: any): void {
if (!target || !source) return
// 累加各种 token 类型
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
target.reasoningTokens = (target.reasoningTokens || 0) + (source.reasoningTokens || 0)
target.cachedInputTokens = (target.cachedInputTokens || 0) + (source.cachedInputTokens || 0)
}
} }

View File

@ -4,7 +4,7 @@
* *
* promptToolUsePlugin.ts * promptToolUsePlugin.ts
*/ */
import type { ToolSet } from 'ai' import type { ToolSet, TypedToolError } from 'ai'
import type { ToolUseResult } from './type' import type { ToolUseResult } from './type'
@ -38,7 +38,6 @@ export class ToolExecutor {
controller: StreamController controller: StreamController
): Promise<ExecutedResult[]> { ): Promise<ExecutedResult[]> {
const executedResults: ExecutedResult[] = [] const executedResults: ExecutedResult[] = []
for (const toolUse of toolUses) { for (const toolUse of toolUses) {
try { try {
const tool = tools[toolUse.toolName] const tool = tools[toolUse.toolName]
@ -46,17 +45,12 @@ export class ToolExecutor {
throw new Error(`Tool "${toolUse.toolName}" has no execute method`) throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
} }
// 发送工具调用开始事件
this.sendToolStartEvents(controller, toolUse)
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
// 发送 tool-call 事件 // 发送 tool-call 事件
controller.enqueue({ controller.enqueue({
type: 'tool-call', type: 'tool-call',
toolCallId: toolUse.id, toolCallId: toolUse.id,
toolName: toolUse.toolName, toolName: toolUse.toolName,
input: tool.inputSchema input: toolUse.arguments
}) })
const result = await tool.execute(toolUse.arguments, { const result = await tool.execute(toolUse.arguments, {
@ -111,45 +105,46 @@ export class ToolExecutor {
/** /**
* *
*/ */
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void { // private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
// 发送 tool-input-start 事件 // // 发送 tool-input-start 事件
controller.enqueue({ // controller.enqueue({
type: 'tool-input-start', // type: 'tool-input-start',
id: toolUse.id, // id: toolUse.id,
toolName: toolUse.toolName // toolName: toolUse.toolName
}) // })
} // }
/** /**
* *
*/ */
private handleToolError( private handleToolError<T extends ToolSet>(
toolUse: ToolUseResult, toolUse: ToolUseResult,
error: unknown, error: unknown,
controller: StreamController controller: StreamController
// _tools: ToolSet
): ExecutedResult { ): ExecutedResult {
// 使用 AI SDK 标准错误格式 // 使用 AI SDK 标准错误格式
// const toolError: TypedToolError<typeof _tools> = { const toolError: TypedToolError<T> = {
// type: 'tool-error', type: 'tool-error',
// toolCallId: toolUse.id, toolCallId: toolUse.id,
// toolName: toolUse.toolName, toolName: toolUse.toolName,
// input: toolUse.arguments, input: toolUse.arguments,
// error: error instanceof Error ? error.message : String(error) error
// } }
// controller.enqueue(toolError) controller.enqueue(toolError)
// 发送标准错误事件 // 发送标准错误事件
controller.enqueue({ // controller.enqueue({
type: 'error', // type: 'tool-error',
error: error instanceof Error ? error.message : String(error) // toolCallId: toolUse.id,
}) // error: error instanceof Error ? error.message : String(error),
// input: toolUse.arguments
// })
return { return {
toolCallId: toolUse.id, toolCallId: toolUse.id,
toolName: toolUse.toolName, toolName: toolUse.toolName,
result: error instanceof Error ? error.message : String(error), result: error,
isError: true isError: true
} }
} }

View File

@ -8,9 +8,19 @@ import type { TextStreamPart, ToolSet } from 'ai'
import { definePlugin } from '../../index' import { definePlugin } from '../../index'
import type { AiRequestContext } from '../../types' import type { AiRequestContext } from '../../types'
import { StreamEventManager } from './StreamEventManager' import { StreamEventManager } from './StreamEventManager'
import { type TagConfig, TagExtractor } from './tagExtraction'
import { ToolExecutor } from './ToolExecutor' import { ToolExecutor } from './ToolExecutor'
import { PromptToolUseConfig, ToolUseResult } from './type' import { PromptToolUseConfig, ToolUseResult } from './type'
/**
* 使
*/
const TOOL_USE_TAG_CONFIG: TagConfig = {
openingTag: '<tool_use>',
closingTag: '</tool_use>',
separator: '\n'
}
/** /**
* Cherry Studio * Cherry Studio
*/ */
@ -249,13 +259,11 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
} }
context.mcpTools = params.tools context.mcpTools = params.tools
console.log('tools stored in context', params.tools)
// 构建系统提示符 // 构建系统提示符
const userSystemPrompt = typeof params.system === 'string' ? params.system : '' const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools) const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
let systemMessage: string | null = systemPrompt let systemMessage: string | null = systemPrompt
console.log('config.context', context)
if (config.createSystemMessage) { if (config.createSystemMessage) {
// 🎯 如果用户提供了自定义处理函数,使用它 // 🎯 如果用户提供了自定义处理函数,使用它
systemMessage = config.createSystemMessage(systemPrompt, params, context) systemMessage = config.createSystemMessage(systemPrompt, params, context)
@ -268,20 +276,40 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
tools: undefined tools: undefined
} }
context.originalParams = transformedParams context.originalParams = transformedParams
console.log('transformedParams', transformedParams)
return transformedParams return transformedParams
}, },
transformStream: (_: any, context: AiRequestContext) => () => { transformStream: (_: any, context: AiRequestContext) => () => {
let textBuffer = '' let textBuffer = ''
let stepId = '' // let stepId = ''
if (!context.mcpTools) { if (!context.mcpTools) {
throw new Error('No tools available') throw new Error('No tools available')
} }
// 创建工具执行器和流事件管理器 // 从 context 中获取或初始化 usage 累加器
if (!context.accumulatedUsage) {
context.accumulatedUsage = {
inputTokens: 0,
outputTokens: 0,
totalTokens: 0,
reasoningTokens: 0,
cachedInputTokens: 0
}
}
// 创建工具执行器、流事件管理器和标签提取器
const toolExecutor = new ToolExecutor() const toolExecutor = new ToolExecutor()
const streamEventManager = new StreamEventManager() const streamEventManager = new StreamEventManager()
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
// 在context中初始化工具执行状态避免递归调用时状态丢失
if (!context.hasExecutedToolsInCurrentStep) {
context.hasExecutedToolsInCurrentStep = false
}
// 用于hold text-start事件直到确认有非工具标签内容
let pendingTextStart: TextStreamPart<TOOLS> | null = null
let hasStartedText = false
type TOOLS = NonNullable<typeof context.mcpTools> type TOOLS = NonNullable<typeof context.mcpTools>
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({ return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
@ -289,83 +317,106 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
chunk: TextStreamPart<TOOLS>, chunk: TextStreamPart<TOOLS>,
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>> controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
) { ) {
// 收集文本内容 // Hold住text-start事件直到确认有非工具标签内容
if (chunk.type === 'text-delta') { if ((chunk as any).type === 'text-start') {
textBuffer += chunk.text || '' pendingTextStart = chunk
stepId = chunk.id || ''
controller.enqueue(chunk)
return return
} }
if (chunk.type === 'text-end' || chunk.type === 'finish-step') { // text-delta阶段收集文本内容并过滤工具标签
const tools = context.mcpTools if (chunk.type === 'text-delta') {
if (!tools || Object.keys(tools).length === 0) { textBuffer += chunk.text || ''
controller.enqueue(chunk) // stepId = chunk.id || ''
return
}
// 解析工具调用 // 使用TagExtractor过滤工具标签只传递非标签内容到UI层
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools) const extractionResults = tagExtractor.processText(chunk.text || '')
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
// 如果没有有效的工具调用,直接传递原始事件 for (const result of extractionResults) {
if (validToolUses.length === 0) { // 只传递非标签内容到UI层
controller.enqueue(chunk) if (!result.isTagContent && result.content) {
return // 如果还没有发送text-start且有pending的text-start先发送它
} if (!hasStartedText && pendingTextStart) {
controller.enqueue(pendingTextStart)
if (chunk.type === 'text-end') { hasStartedText = true
controller.enqueue({ pendingTextStart = null
type: 'text-end',
id: stepId,
providerMetadata: {
text: {
value: parsedContent
}
} }
})
return const filteredChunk = {
...chunk,
text: result.content
}
controller.enqueue(filteredChunk)
}
}
return
}
if (chunk.type === 'text-end') {
// 只有当已经发送了text-start时才发送text-end
if (hasStartedText) {
controller.enqueue(chunk)
}
return
}
if (chunk.type === 'finish-step') {
// 统一在finish-step阶段检查并执行工具调用
const tools = context.mcpTools
if (tools && Object.keys(tools).length > 0 && !context.hasExecutedToolsInCurrentStep) {
// 解析完整的textBuffer来检测工具调用
const { results: parsedTools } = parseToolUse(textBuffer, tools)
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
if (validToolUses.length > 0) {
context.hasExecutedToolsInCurrentStep = true
// 执行工具调用(不需要手动发送 start-step外部流已经处理
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
// 发送步骤完成事件,使用 tool-calls 作为 finishReason
streamEventManager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
// 处理递归调用
const toolResultsText = toolExecutor.formatToolResults(executedResults)
const recursiveParams = streamEventManager.buildRecursiveParams(
context,
textBuffer,
toolResultsText,
tools
)
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context)
return
}
} }
controller.enqueue({ // 如果没有执行工具调用直接传递原始finish-step事件
...chunk, controller.enqueue(chunk)
finishReason: 'tool-calls'
})
// 发送步骤开始事件
streamEventManager.sendStepStartEvent(controller)
// 执行工具调用
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
// 发送步骤完成事件
streamEventManager.sendStepFinishEvent(controller, chunk)
// 处理递归调用
if (validToolUses.length > 0) {
const toolResultsText = toolExecutor.formatToolResults(executedResults)
const recursiveParams = streamEventManager.buildRecursiveParams(
context,
textBuffer,
toolResultsText,
tools
)
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
}
// 清理状态 // 清理状态
textBuffer = '' textBuffer = ''
return return
} }
// 对于其他类型的事件,直接传递 // 处理 finish 类型,使用累加后的 totalUsage
controller.enqueue(chunk) if (chunk.type === 'finish') {
controller.enqueue({
...chunk,
totalUsage: context.accumulatedUsage
})
return
}
// 对于其他类型的事件直接传递不包括text-start已在上面处理
if ((chunk as any).type !== 'text-start') {
controller.enqueue(chunk)
}
}, },
flush() { flush() {
// 流结束时的清理工作 // 清理pending状态
console.log('[MCP Prompt] Stream ended, cleaning up...') pendingTextStart = null
hasStartedText = false
} }
}) })
} }

View File

@ -27,7 +27,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
case 'openai': { case 'openai': {
if (config.openai) { if (config.openai) {
if (!params.tools) params.tools = {} if (!params.tools) params.tools = {}
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai) params.tools.web_search = openai.tools.webSearch(config.openai)
} }
break break
} }

View File

@ -1,5 +1,8 @@
// 核心类型和接口 // 核心类型和接口
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types' export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
import type { ImageModelV2 } from '@ai-sdk/provider'
import type { LanguageModel } from 'ai'
import type { ProviderId } from '../providers' import type { ProviderId } from '../providers'
import type { AiPlugin, AiRequestContext } from './types' import type { AiPlugin, AiRequestContext } from './types'
@ -9,16 +12,16 @@ export { PluginManager } from './manager'
// 工具函数 // 工具函数
export function createContext<T extends ProviderId>( export function createContext<T extends ProviderId>(
providerId: T, providerId: T,
modelId: string, model: LanguageModel | ImageModelV2,
originalParams: any originalParams: any
): AiRequestContext { ): AiRequestContext {
return { return {
providerId, providerId,
modelId, model,
originalParams, originalParams,
metadata: {}, metadata: {},
startTime: Date.now(), startTime: Date.now(),
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`, requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
// 占位 // 占位
recursiveCall: () => Promise.resolve(null) recursiveCall: () => Promise.resolve(null)
} }

View File

@ -14,7 +14,7 @@ export type RecursiveCallFn = (newParams: any) => Promise<any>
*/ */
export interface AiRequestContext { export interface AiRequestContext {
providerId: ProviderId providerId: ProviderId
modelId: string model: LanguageModel | ImageModelV2
originalParams: any originalParams: any
metadata: Record<string, any> metadata: Record<string, any>
startTime: number startTime: number

View File

@ -83,7 +83,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
} }
// 使用正确的createContext创建请求上下文 // 使用正确的createContext创建请求上下文
const context = _context ? _context : createContext(this.providerId, modelId, params) const context = _context ? _context : createContext(this.providerId, model, params)
// 🔥 为上下文添加递归调用能力 // 🔥 为上下文添加递归调用能力
context.recursiveCall = async (newParams: any): Promise<TResult> => { context.recursiveCall = async (newParams: any): Promise<TResult> => {
@ -159,7 +159,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
} }
// 使用正确的createContext创建请求上下文 // 使用正确的createContext创建请求上下文
const context = _context ? _context : createContext(this.providerId, modelId, params) const context = _context ? _context : createContext(this.providerId, model, params)
// 🔥 为上下文添加递归调用能力 // 🔥 为上下文添加递归调用能力
context.recursiveCall = async (newParams: any): Promise<TResult> => { context.recursiveCall = async (newParams: any): Promise<TResult> => {
@ -235,7 +235,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
} }
// 创建请求上下文 // 创建请求上下文
const context = _context ? _context : createContext(this.providerId, modelId, params) const context = _context ? _context : createContext(this.providerId, model, params)
// 🔥 为上下文添加递归调用能力 // 🔥 为上下文添加递归调用能力
context.recursiveCall = async (newParams: any): Promise<TResult> => { context.recursiveCall = async (newParams: any): Promise<TResult> => {

View File

@ -152,12 +152,14 @@ export class AiSdkToChunkAdapter {
// this.toolCallHandler.handleToolCallCreated(chunk) // this.toolCallHandler.handleToolCallCreated(chunk)
// break // break
case 'tool-call': case 'tool-call':
// 原始的工具调用(未被中间件处理)
this.toolCallHandler.handleToolCall(chunk) this.toolCallHandler.handleToolCall(chunk)
break break
case 'tool-error':
this.toolCallHandler.handleToolError(chunk)
break
case 'tool-result': case 'tool-result':
// 原始的工具调用结果(未被中间件处理)
this.toolCallHandler.handleToolResult(chunk) this.toolCallHandler.handleToolResult(chunk)
break break
@ -167,7 +169,6 @@ export class AiSdkToChunkAdapter {
// type: ChunkType.LLM_RESPONSE_CREATED // type: ChunkType.LLM_RESPONSE_CREATED
// }) // })
// break // break
// TODO: 需要区分接口开始和步骤开始
// case 'start-step': // case 'start-step':
// this.onChunk({ // this.onChunk({
// type: ChunkType.BLOCK_CREATED // type: ChunkType.BLOCK_CREATED
@ -305,8 +306,6 @@ export class AiSdkToChunkAdapter {
break break
default: default:
// 其他类型的 chunk 可以忽略或记录日志
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
} }
} }
} }

View File

@ -8,34 +8,61 @@ import { loggerService } from '@logger'
import { processKnowledgeReferences } from '@renderer/services/KnowledgeService' import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types' import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk' import { Chunk, ChunkType } from '@renderer/types/chunk'
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai' import type { ToolSet, TypedToolCall, TypedToolError, TypedToolResult } from 'ai'
// import type {
// AnthropicSearchOutput,
// WebSearchPluginConfig
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
const logger = loggerService.withContext('ToolCallChunkHandler') const logger = loggerService.withContext('ToolCallChunkHandler')
export type ToolcallsMap = {
toolCallId: string
toolName: string
args: any
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
tool: BaseTool
}
/** /**
* *
*/ */
export class ToolCallChunkHandler { export class ToolCallChunkHandler {
// private onChunk: (chunk: Chunk) => void private static globalActiveToolCalls = new Map<string, ToolcallsMap>()
private activeToolCalls = new Map<
string, private activeToolCalls = ToolCallChunkHandler.globalActiveToolCalls
{
toolCallId: string
toolName: string
args: any
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
tool: BaseTool
}
>()
constructor( constructor(
private onChunk: (chunk: Chunk) => void, private onChunk: (chunk: Chunk) => void,
private mcpTools: MCPTool[] private mcpTools: MCPTool[]
) {} ) {}
/**
*
*/
private static addActiveToolCallImpl(toolCallId: string, map: ToolcallsMap): boolean {
if (!ToolCallChunkHandler.globalActiveToolCalls.has(toolCallId)) {
ToolCallChunkHandler.globalActiveToolCalls.set(toolCallId, map)
return true
}
return false
}
/**
*
*/
private addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
}
/**
*
*/
public static getActiveToolCalls() {
return ToolCallChunkHandler.globalActiveToolCalls
}
/**
* 访
*/
public static addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
}
// /** // /**
// * 设置 onChunk 回调 // * 设置 onChunk 回调
// */ // */
@ -43,103 +70,103 @@ export class ToolCallChunkHandler {
// this.onChunk = callback // this.onChunk = callback
// } // }
handleToolCallCreated( // handleToolCallCreated(
chunk: // chunk:
| { // | {
type: 'tool-input-start' // type: 'tool-input-start'
id: string // id: string
toolName: string // toolName: string
providerMetadata?: ProviderMetadata // providerMetadata?: ProviderMetadata
providerExecuted?: boolean // providerExecuted?: boolean
} // }
| { // | {
type: 'tool-input-end' // type: 'tool-input-end'
id: string // id: string
providerMetadata?: ProviderMetadata // providerMetadata?: ProviderMetadata
} // }
| { // | {
type: 'tool-input-delta' // type: 'tool-input-delta'
id: string // id: string
delta: string // delta: string
providerMetadata?: ProviderMetadata // providerMetadata?: ProviderMetadata
} // }
): void { // ): void {
switch (chunk.type) { // switch (chunk.type) {
case 'tool-input-start': { // case 'tool-input-start': {
// 能拿到说明是mcpTool // // 能拿到说明是mcpTool
// if (this.activeToolCalls.get(chunk.id)) return // // if (this.activeToolCalls.get(chunk.id)) return
const tool: BaseTool | MCPTool = { // const tool: BaseTool | MCPTool = {
id: chunk.id, // id: chunk.id,
name: chunk.toolName, // name: chunk.toolName,
description: chunk.toolName, // description: chunk.toolName,
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider' // type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
} // }
this.activeToolCalls.set(chunk.id, { // this.activeToolCalls.set(chunk.id, {
toolCallId: chunk.id, // toolCallId: chunk.id,
toolName: chunk.toolName, // toolName: chunk.toolName,
args: '', // args: '',
tool // tool
}) // })
const toolResponse: MCPToolResponse | NormalToolResponse = { // const toolResponse: MCPToolResponse | NormalToolResponse = {
id: chunk.id, // id: chunk.id,
tool: tool, // tool: tool,
arguments: {}, // arguments: {},
status: 'pending', // status: 'pending',
toolCallId: chunk.id // toolCallId: chunk.id
} // }
this.onChunk({ // this.onChunk({
type: ChunkType.MCP_TOOL_PENDING, // type: ChunkType.MCP_TOOL_PENDING,
responses: [toolResponse] // responses: [toolResponse]
}) // })
break // break
} // }
case 'tool-input-delta': { // case 'tool-input-delta': {
const toolCall = this.activeToolCalls.get(chunk.id) // const toolCall = this.activeToolCalls.get(chunk.id)
if (!toolCall) { // if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) // logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return // return
} // }
toolCall.args += chunk.delta // toolCall.args += chunk.delta
break // break
} // }
case 'tool-input-end': { // case 'tool-input-end': {
const toolCall = this.activeToolCalls.get(chunk.id) // const toolCall = this.activeToolCalls.get(chunk.id)
this.activeToolCalls.delete(chunk.id) // this.activeToolCalls.delete(chunk.id)
if (!toolCall) { // if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) // logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return // return
} // }
// const toolResponse: ToolCallResponse = { // // const toolResponse: ToolCallResponse = {
// id: toolCall.toolCallId, // // id: toolCall.toolCallId,
// tool: toolCall.tool, // // tool: toolCall.tool,
// arguments: toolCall.args, // // arguments: toolCall.args,
// status: 'pending', // // status: 'pending',
// toolCallId: toolCall.toolCallId // // toolCallId: toolCall.toolCallId
// } // // }
// logger.debug('toolResponse', toolResponse) // // logger.debug('toolResponse', toolResponse)
// this.onChunk({ // // this.onChunk({
// type: ChunkType.MCP_TOOL_PENDING, // // type: ChunkType.MCP_TOOL_PENDING,
// responses: [toolResponse] // // responses: [toolResponse]
// }) // // })
break // break
} // }
} // }
// if (!toolCall) { // // if (!toolCall) {
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`) // // Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// return // // return
// } // // }
// this.onChunk({ // // this.onChunk({
// type: ChunkType.MCP_TOOL_CREATED, // // type: ChunkType.MCP_TOOL_CREATED,
// tool_calls: [ // // tool_calls: [
// { // // {
// id: chunk.id, // // id: chunk.id,
// name: chunk.toolName, // // name: chunk.toolName,
// status: 'pending' // // status: 'pending'
// } // // }
// ] // // ]
// }) // // })
} // }
/** /**
* *
@ -158,7 +185,6 @@ export class ToolCallChunkHandler {
let tool: BaseTool let tool: BaseTool
let mcpTool: MCPTool | undefined let mcpTool: MCPTool | undefined
// 根据 providerExecuted 标志区分处理逻辑 // 根据 providerExecuted 标志区分处理逻辑
if (providerExecuted) { if (providerExecuted) {
// 如果是 Provider 执行的工具(如 web_search // 如果是 Provider 执行的工具(如 web_search
@ -196,27 +222,25 @@ export class ToolCallChunkHandler {
} }
} }
// 记录活跃的工具调用 this.addActiveToolCall(toolCallId, {
this.activeToolCalls.set(toolCallId, {
toolCallId, toolCallId,
toolName, toolName,
args, args,
tool tool
}) })
// 创建 MCPToolResponse 格式 // 创建 MCPToolResponse 格式
const toolResponse: MCPToolResponse | NormalToolResponse = { const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId, id: toolCallId,
tool: tool, tool: tool,
arguments: args, arguments: args,
status: 'pending', status: 'pending', // 统一使用 pending 状态
toolCallId: toolCallId toolCallId: toolCallId
} }
// 调用 onChunk // 调用 onChunk
if (this.onChunk) { if (this.onChunk) {
this.onChunk({ this.onChunk({
type: ChunkType.MCP_TOOL_PENDING, type: ChunkType.MCP_TOOL_PENDING, // 统一发送 pending 状态
responses: [toolResponse] responses: [toolResponse]
}) })
} }
@ -276,4 +300,33 @@ export class ToolCallChunkHandler {
}) })
} }
} }
handleToolError(
chunk: {
type: 'tool-error'
} & TypedToolError<ToolSet>
): void {
const { toolCallId, error, input } = chunk
const toolCallInfo = this.activeToolCalls.get(toolCallId)
if (!toolCallInfo) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
return
}
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: toolCallId,
tool: toolCallInfo.tool,
arguments: input,
status: 'error',
response: error,
toolCallId: toolCallId
}
this.activeToolCalls.delete(toolCallId)
if (this.onChunk) {
this.onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [toolResponse]
})
}
}
} }
export const addActiveToolCall = ToolCallChunkHandler.addActiveToolCall.bind(ToolCallChunkHandler)

View File

@ -265,15 +265,15 @@ export default class ModernAiProvider {
params: StreamTextParams, params: StreamTextParams,
config: ModernAiProviderConfig config: ModernAiProviderConfig
): Promise<CompletionsResult> { ): Promise<CompletionsResult> {
const modelId = this.model!.id // const modelId = this.model!.id
logger.info('Starting modernCompletions', { // logger.info('Starting modernCompletions', {
modelId, // modelId,
providerId: this.config!.providerId, // providerId: this.config!.providerId,
topicId: config.topicId, // topicId: config.topicId,
hasOnChunk: !!config.onChunk, // hasOnChunk: !!config.onChunk,
hasTools: !!params.tools && Object.keys(params.tools).length > 0, // hasTools: !!params.tools && Object.keys(params.tools).length > 0,
toolCount: params.tools ? Object.keys(params.tools).length : 0 // toolCount: params.tools ? Object.keys(params.tools).length : 0
}) // })
// 根据条件构建插件数组 // 根据条件构建插件数组
const plugins = buildPlugins(config) const plugins = buildPlugins(config)

View File

@ -35,7 +35,7 @@ export function buildPlugins(
plugins.push(webSearchPlugin()) plugins.push(webSearchPlugin())
} }
// 2. 支持工具调用时添加搜索插件 // 2. 支持工具调用时添加搜索插件
if (middlewareConfig.isSupportedToolUse) { if (middlewareConfig.isSupportedToolUse || middlewareConfig.isPromptToolUse) {
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant, middlewareConfig.topicId || '')) plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant, middlewareConfig.topicId || ''))
} }
@ -45,12 +45,13 @@ export function buildPlugins(
} }
// 4. 启用Prompt工具调用时添加工具插件 // 4. 启用Prompt工具调用时添加工具插件
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) { if (middlewareConfig.isPromptToolUse) {
plugins.push( plugins.push(
createPromptToolUsePlugin({ createPromptToolUsePlugin({
enabled: true, enabled: true,
createSystemMessage: (systemPrompt, params, context) => { createSystemMessage: (systemPrompt, params, context) => {
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) { const modelId = typeof context.model === 'string' ? context.model : context.model.modelId
if (modelId.includes('o1-mini') || modelId.includes('o1-preview')) {
if (context.isRecursiveCall) { if (context.isRecursiveCall) {
return null return null
} }

View File

@ -19,7 +19,8 @@ import store from '@renderer/store'
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory' import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
import type { Assistant } from '@renderer/types' import type { Assistant } from '@renderer/types'
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract' import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
import type { ModelMessage } from 'ai' import type { LanguageModel, ModelMessage } from 'ai'
import { generateText } from 'ai'
import { isEmpty } from 'lodash' import { isEmpty } from 'lodash'
import { MemoryProcessor } from '../../services/MemoryProcessor' import { MemoryProcessor } from '../../services/MemoryProcessor'
@ -76,9 +77,7 @@ async function analyzeSearchIntent(
shouldKnowledgeSearch?: boolean shouldKnowledgeSearch?: boolean
shouldMemorySearch?: boolean shouldMemorySearch?: boolean
lastAnswer?: ModelMessage lastAnswer?: ModelMessage
context: AiRequestContext & { context: AiRequestContext
isAnalyzing?: boolean
}
topicId: string topicId: string
} }
): Promise<ExtractResults | undefined> { ): Promise<ExtractResults | undefined> {
@ -122,9 +121,7 @@ async function analyzeSearchIntent(
logger.error('Provider not found or missing API key') logger.error('Provider not found or missing API key')
return getFallbackResult() return getFallbackResult()
} }
// console.log('formattedPrompt', schema)
try { try {
context.isAnalyzing = true
logger.info('Starting intent analysis generateText call', { logger.info('Starting intent analysis generateText call', {
modelId: model.id, modelId: model.id,
topicId: options.topicId, topicId: options.topicId,
@ -133,18 +130,16 @@ async function analyzeSearchIntent(
hasKnowledgeSearch: needKnowledgeExtract hasKnowledgeSearch: needKnowledgeExtract
}) })
const { text: result } = await context.executor const { text: result } = await generateText({
.generateText(model.id, { model: context.model as LanguageModel,
prompt: formattedPrompt prompt: formattedPrompt
}) }).finally(() => {
.finally(() => { logger.info('Intent analysis generateText call completed', {
context.isAnalyzing = false modelId: model.id,
logger.info('Intent analysis generateText call completed', { topicId: options.topicId,
modelId: model.id, requestId: context.requestId
topicId: options.topicId,
requestId: context.requestId
})
}) })
})
const parsedResult = extractInfoFromXML(result) const parsedResult = extractInfoFromXML(result)
logger.debug('Intent analysis result', { parsedResult }) logger.debug('Intent analysis result', { parsedResult })
@ -183,7 +178,6 @@ async function storeConversationMemory(
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled || !assistant.enableMemory) { if (!globalMemoryEnabled || !assistant.enableMemory) {
// console.log('Memory storage is disabled')
return return
} }
@ -245,25 +239,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
// 存储意图分析结果 // 存储意图分析结果
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {} const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
const userMessages: { [requestId: string]: ModelMessage } = {} const userMessages: { [requestId: string]: ModelMessage } = {}
let currentContext: AiRequestContext | null = null
return definePlugin({ return definePlugin({
name: 'search-orchestration', name: 'search-orchestration',
enforce: 'pre', // 确保在其他插件之前执行 enforce: 'pre', // 确保在其他插件之前执行
configureContext: (context: AiRequestContext) => {
if (currentContext) {
context.isAnalyzing = currentContext.isAnalyzing
}
currentContext = context
},
/** /**
* 🔍 Step 1: 意图识别阶段 * 🔍 Step 1: 意图识别阶段
*/ */
onRequestStart: async (context: AiRequestContext) => { onRequestStart: async (context: AiRequestContext) => {
if (context.isAnalyzing) return
// 没开启任何搜索则不进行意图分析 // 没开启任何搜索则不进行意图分析
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
@ -315,7 +298,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
* 🔧 Step 2: 工具配置阶段 * 🔧 Step 2: 工具配置阶段
*/ */
transformParams: async (params: any, context: AiRequestContext) => { transformParams: async (params: any, context: AiRequestContext) => {
if (context.isAnalyzing) return params
// logger.info('🔧 Configuring tools based on intent...', context.requestId) // logger.info('🔧 Configuring tools based on intent...', context.requestId)
try { try {
@ -409,7 +391,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
// context.isAnalyzing = false // context.isAnalyzing = false
// logger.info('context.isAnalyzing', context, result) // logger.info('context.isAnalyzing', context, result)
// logger.info('💾 Starting memory storage...', context.requestId) // logger.info('💾 Starting memory storage...', context.requestId)
if (context.isAnalyzing) return
try { try {
const messages = context.originalParams.messages const messages = context.originalParams.messages

View File

@ -19,8 +19,6 @@ export const memorySearchTool = () => {
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return') limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
}), }),
execute: async ({ query, limit = 5 }) => { execute: async ({ query, limit = 5 }) => {
// console.log('🧠 [memorySearchTool] Searching memories:', { query, limit })
try { try {
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled) { if (!globalMemoryEnabled) {
@ -29,7 +27,6 @@ export const memorySearchTool = () => {
const memoryConfig = selectMemoryConfig(store.getState()) const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) { if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
// console.warn('Memory search skipped: embedding or LLM model not configured')
return [] return []
} }
@ -40,12 +37,10 @@ export const memorySearchTool = () => {
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit) const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
if (relevantMemories?.length > 0) { if (relevantMemories?.length > 0) {
// console.log('🧠 [memorySearchTool] Found memories:', relevantMemories.length)
return relevantMemories return relevantMemories
} }
return [] return []
} catch (error) { } catch (error) {
// console.error('🧠 [memorySearchTool] Error:', error)
return [] return []
} }
} }
@ -84,8 +79,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
.optional() .optional()
}) satisfies z.ZodSchema<MemorySearchWithExtractionInput>, }) satisfies z.ZodSchema<MemorySearchWithExtractionInput>,
execute: async ({ userMessage }) => { execute: async ({ userMessage }) => {
// console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer })
try { try {
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
if (!globalMemoryEnabled || !assistant.enableMemory) { if (!globalMemoryEnabled || !assistant.enableMemory) {
@ -97,7 +90,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
const memoryConfig = selectMemoryConfig(store.getState()) const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) { if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
// console.warn('Memory search skipped: embedding or LLM model not configured')
return { return {
extractedKeywords: 'Memory models not configured', extractedKeywords: 'Memory models not configured',
searchResults: [] searchResults: []
@ -125,7 +117,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
) )
if (relevantMemories?.length > 0) { if (relevantMemories?.length > 0) {
// console.log('🧠 [memorySearchToolWithExtraction] Found memories:', relevantMemories.length)
return { return {
extractedKeywords: content, extractedKeywords: content,
searchResults: relevantMemories searchResults: relevantMemories
@ -137,7 +128,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
searchResults: [] searchResults: []
} }
} catch (error) { } catch (error) {
// console.error('🧠 [memorySearchToolWithExtraction] Error:', error)
return { return {
extractedKeywords: 'Search failed', extractedKeywords: 'Search failed',
searchResults: [] searchResults: []

View File

@ -1,7 +1,5 @@
import { loggerService } from '@logger' import { loggerService } from '@logger'
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
import { MCPTool, MCPToolResponse } from '@renderer/types' import { MCPTool, MCPToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools' import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools'
import { requestToolConfirmation } from '@renderer/utils/userConfirmation' import { requestToolConfirmation } from '@renderer/utils/userConfirmation'
import { type Tool, type ToolSet } from 'ai' import { type Tool, type ToolSet } from 'ai'
@ -33,8 +31,36 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
tools[mcpTool.name] = tool({ tools[mcpTool.name] = tool({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`, description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7), inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params, { toolCallId, experimental_context }) => { execute: async (params, { toolCallId }) => {
const { onChunk } = experimental_context as { onChunk: (chunk: Chunk) => void } // 检查是否启用自动批准
const server = getMcpServerByTool(mcpTool)
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
let confirmed = true
if (!isAutoApproveEnabled) {
// 请求用户确认
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
confirmed = await requestToolConfirmation(toolCallId)
}
if (!confirmed) {
// 用户拒绝执行工具
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
return {
content: [
{
type: 'text',
text: `User declined to execute tool "${mcpTool.name}".`
}
],
isError: false
}
}
// 用户确认或自动批准,执行工具
logger.debug(`Executing tool: ${mcpTool.name}`)
// 创建适配的 MCPToolResponse 对象 // 创建适配的 MCPToolResponse 对象
const toolResponse: MCPToolResponse = { const toolResponse: MCPToolResponse = {
id: toolCallId, id: toolCallId,
@ -44,53 +70,18 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
toolCallId toolCallId
} }
try { const result = await callMCPTool(toolResponse)
// 检查是否启用自动批准
const server = getMcpServerByTool(mcpTool)
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
let confirmed = true // 返回结果AI SDK 会处理序列化
if (!isAutoApproveEnabled) { if (result.isError) {
// 请求用户确认 // throw new Error(result.content?.[0]?.text || 'Tool execution failed')
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`) return Promise.reject(result)
confirmed = await requestToolConfirmation(toolResponse.id)
}
if (!confirmed) {
// 用户拒绝执行工具
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
return {
content: [
{
type: 'text',
text: `User declined to execute tool "${mcpTool.name}".`
}
],
isError: false
}
}
// 用户确认或自动批准,执行工具
toolResponse.status = 'invoking'
logger.debug(`Executing tool: ${mcpTool.name}`)
onChunk({
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [toolResponse]
})
const result = await callMCPTool(toolResponse)
// 返回结果AI SDK 会处理序列化
if (result.isError) {
throw new Error(result.content?.[0]?.text || 'Tool execution failed')
}
// 返回工具执行结果
return result
} catch (error) {
logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
throw error
} }
// 返回工具执行结果
return result
// } catch (error) {
// logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
// }
} }
}) })
} }

View File

@ -105,7 +105,6 @@ const ThinkingBlock: React.FC<Props> = ({ block }) => {
const ThinkingTimeSeconds = memo( const ThinkingTimeSeconds = memo(
({ blockThinkingTime, isThinking }: { blockThinkingTime: number; isThinking: boolean }) => { ({ blockThinkingTime, isThinking }: { blockThinkingTime: number; isThinking: boolean }) => {
const { t } = useTranslation() const { t } = useTranslation()
// console.log('blockThinkingTime', blockThinkingTime)
// const [thinkingTime, setThinkingTime] = useState(blockThinkingTime || 0) // const [thinkingTime, setThinkingTime] = useState(blockThinkingTime || 0)
// FIXME: 这里统计的和请求处统计的有一定误差 // FIXME: 这里统计的和请求处统计的有一定误差

View File

@ -186,7 +186,6 @@ const MessageMenubar: FC<Props> = (props) => {
try { try {
await translateText(mainTextContent, language, translationUpdater) await translateText(mainTextContent, language, translationUpdater)
} catch (error) { } catch (error) {
// console.error('Translation failed:', error)
window.message.error({ content: t('translate.error.failed'), key: 'translate-message' }) window.message.error({ content: t('translate.error.failed'), key: 'translate-message' })
// 理应只有一个 // 理应只有一个
const translationBlocks = findTranslationBlocksById(message.id) const translationBlocks = findTranslationBlocksById(message.id)

View File

@ -60,14 +60,29 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
const toolResponse = block.metadata?.rawMcpToolResponse const toolResponse = block.metadata?.rawMcpToolResponse
const { id, tool, status, response } = toolResponse! const { id, tool, status, response } = toolResponse!
const isPending = status === 'pending' const isPending = status === 'pending'
const isInvoking = status === 'invoking'
const isDone = status === 'done' const isDone = status === 'done'
const isError = status === 'error'
const isAutoApproved = useMemo(
() =>
isToolAutoApproved(
tool,
mcpServers.find((s) => s.id === tool.serverId)
),
[tool, mcpServers]
)
// 增加本地状态来跟踪用户确认
const [isConfirmed, setIsConfirmed] = useState(isAutoApproved)
// 判断不同的UI状态
const isWaitingConfirmation = isPending && !isAutoApproved && !isConfirmed
const isExecuting = isPending && (isAutoApproved || isConfirmed)
const timer = useRef<NodeJS.Timeout | null>(null) const timer = useRef<NodeJS.Timeout | null>(null)
useEffect(() => { useEffect(() => {
if (!isPending) return if (!isWaitingConfirmation) return
if (countdown > 0) { if (countdown > 0) {
timer.current = setTimeout(() => { timer.current = setTimeout(() => {
@ -75,6 +90,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
setCountdown((prev) => prev - 1) setCountdown((prev) => prev - 1)
}, 1000) }, 1000)
} else if (countdown === 0) { } else if (countdown === 0) {
setIsConfirmed(true)
confirmToolAction(id) confirmToolAction(id)
} }
@ -83,7 +99,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
clearTimeout(timer.current) clearTimeout(timer.current)
} }
} }
}, [countdown, id, isPending]) }, [countdown, id, isWaitingConfirmation])
useEffect(() => { useEffect(() => {
const removeListener = window.electron.ipcRenderer.on( const removeListener = window.electron.ipcRenderer.on(
@ -146,6 +162,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
const handleConfirmTool = () => { const handleConfirmTool = () => {
cancelCountdown() cancelCountdown()
setIsConfirmed(true)
confirmToolAction(id) confirmToolAction(id)
} }
@ -195,6 +212,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
updateMCPServer(updatedServer) updateMCPServer(updatedServer)
// Also confirm the current tool // Also confirm the current tool
setIsConfirmed(true)
confirmToolAction(id) confirmToolAction(id)
window.message.success({ window.message.success({
@ -206,32 +224,31 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
const renderStatusIndicator = (status: string, hasError: boolean) => { const renderStatusIndicator = (status: string, hasError: boolean) => {
let label = '' let label = ''
let icon: React.ReactNode | null = null let icon: React.ReactNode | null = null
switch (status) {
case 'pending': if (status === 'pending') {
if (isWaitingConfirmation) {
label = t('message.tools.pending', 'Awaiting Approval') label = t('message.tools.pending', 'Awaiting Approval')
icon = <LoadingIcon style={{ marginLeft: 6, color: 'var(--status-color-warning)' }} /> icon = <LoadingIcon style={{ marginLeft: 6, color: 'var(--status-color-warning)' }} />
break } else if (isExecuting) {
case 'invoking':
label = t('message.tools.invoking') label = t('message.tools.invoking')
icon = <LoadingIcon style={{ marginLeft: 6 }} /> icon = <LoadingIcon style={{ marginLeft: 6 }} />
break }
case 'cancelled': } else if (status === 'cancelled') {
label = t('message.tools.cancelled') label = t('message.tools.cancelled')
icon = <X size={13} style={{ marginLeft: 6 }} className="lucide-custom" /> icon = <X size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
break } else if (status === 'done') {
case 'done': if (hasError) {
if (hasError) { label = t('message.tools.error')
label = t('message.tools.error') icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" /> } else {
} else { label = t('message.tools.completed')
label = t('message.tools.completed') icon = <Check size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
icon = <Check size={13} style={{ marginLeft: 6 }} className="lucide-custom" /> }
} } else if (status === 'error') {
break label = t('message.tools.error')
default: icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
label = ''
icon = null
} }
return ( return (
<StatusIndicator status={status} hasError={hasError}> <StatusIndicator status={status} hasError={hasError}>
{label} {label}
@ -248,7 +265,6 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
params: toolResponse.arguments, params: toolResponse.arguments,
response: toolResponse.response response: toolResponse.response
} }
items.push({ items.push({
key: id, key: id,
label: ( label: (
@ -283,7 +299,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
<Maximize size={14} /> <Maximize size={14} />
</ActionButton> </ActionButton>
</Tooltip> </Tooltip>
{!isPending && !isInvoking && ( {!isPending && (
<Tooltip title={t('common.copy')} mouseEnterDelay={0.5}> <Tooltip title={t('common.copy')} mouseEnterDelay={0.5}>
<ActionButton <ActionButton
className="message-action-button" className="message-action-button"
@ -301,7 +317,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
</MessageTitleLabel> </MessageTitleLabel>
), ),
children: children:
isDone && result ? ( (isDone || isError) && result ? (
<ToolResponseContainer <ToolResponseContainer
style={{ style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)', fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
@ -370,7 +386,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
} }
}}> }}>
<ToolContainer> <ToolContainer>
<ToolContentWrapper className={status}> <ToolContentWrapper className={isPending ? 'pending' : status}>
<CollapseContainer <CollapseContainer
ghost ghost
activeKey={activeKeys} activeKey={activeKeys}
@ -383,14 +399,16 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
<ExpandIcon $isActive={isActive} size={18} color="var(--color-text-3)" strokeWidth={1.5} /> <ExpandIcon $isActive={isActive} size={18} color="var(--color-text-3)" strokeWidth={1.5} />
)} )}
/> />
{(isPending || isInvoking) && ( {isPending && (
<ActionsBar> <ActionsBar>
<ActionLabel> <ActionLabel>
{isPending ? t('settings.mcp.tools.autoApprove.tooltip.confirm') : t('message.tools.invoking')} {isWaitingConfirmation
? t('settings.mcp.tools.autoApprove.tooltip.confirm')
: t('message.tools.invoking')}
</ActionLabel> </ActionLabel>
<ActionButtonsGroup> <ActionButtonsGroup>
{isPending && ( {isWaitingConfirmation && (
<Button <Button
color="danger" color="danger"
variant="filled" variant="filled"
@ -402,7 +420,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
{t('common.cancel')} {t('common.cancel')}
</Button> </Button>
)} )}
{isInvoking && toolResponse?.id ? ( {isExecuting && toolResponse?.id ? (
<Button <Button
size="small" size="small"
color="danger" color="danger"
@ -416,29 +434,31 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
{t('chat.input.pause')} {t('chat.input.pause')}
</Button> </Button>
) : ( ) : (
<StyledDropdownButton isWaitingConfirmation && (
size="small" <StyledDropdownButton
type="primary" size="small"
icon={<ChevronDown size={14} />} type="primary"
onClick={() => { icon={<ChevronDown size={14} />}
handleConfirmTool() onClick={() => {
}} handleConfirmTool()
menu={{ }}
items: [ menu={{
{ items: [
key: 'autoApprove', {
label: t('settings.mcp.tools.autoApprove.label'), key: 'autoApprove',
onClick: () => { label: t('settings.mcp.tools.autoApprove.label'),
handleAutoApprove() onClick: () => {
handleAutoApprove()
}
} }
} ]
] }}>
}}> <CirclePlay size={15} className="lucide-custom" />
<CirclePlay size={15} className="lucide-custom" /> <CountdownText>
<CountdownText> {t('settings.mcp.tools.run', 'Run')} ({countdown}s)
{t('settings.mcp.tools.run', 'Run')} ({countdown}s) </CountdownText>
</CountdownText> </StyledDropdownButton>
</StyledDropdownButton> )
)} )}
</ActionButtonsGroup> </ActionButtonsGroup>
</ActionsBar> </ActionsBar>
@ -542,8 +562,7 @@ const ToolContentWrapper = styled.div`
border: 1px solid var(--color-border); border: 1px solid var(--color-border);
} }
&.pending, &.pending {
&.invoking {
background-color: var(--color-background-soft); background-color: var(--color-background-soft);
.ant-collapse { .ant-collapse {
border: none; border: none;
@ -663,6 +682,8 @@ const StatusIndicator = styled.span<{ status: string; hasError?: boolean }>`
return 'var(--status-color-error)' return 'var(--status-color-error)'
case 'done': case 'done':
return props.hasError ? 'var(--status-color-error)' : 'var(--status-color-success)' return props.hasError ? 'var(--status-color-error)' : 'var(--status-color-success)'
case 'error':
return 'var(--status-color-error)'
default: default:
return 'var(--color-text)' return 'var(--color-text)'
} }

View File

@ -99,6 +99,8 @@ export async function fetchChatCompletion({
const provider = AI.getActualProvider() const provider = AI.getActualProvider()
const mcpTools: MCPTool[] = [] const mcpTools: MCPTool[] = []
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
if (isPromptToolUse(assistant) || isSupportedToolUse(assistant)) { if (isPromptToolUse(assistant) || isSupportedToolUse(assistant)) {
mcpTools.push(...(await fetchMcpTools(assistant))) mcpTools.push(...(await fetchMcpTools(assistant)))
} }
@ -137,7 +139,6 @@ export async function fetchChatCompletion({
} }
// --- Call AI Completions --- // --- Call AI Completions ---
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
await AI.completions(modelId, aiSdkParams, { await AI.completions(modelId, aiSdkParams, {
...middlewareConfig, ...middlewareConfig,
assistant, assistant,

View File

@ -178,12 +178,7 @@ class WebSearchService {
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}` formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
} }
// try {
return await webSearchEngine.search(formattedQuery, websearch, httpOptions) return await webSearchEngine.search(formattedQuery, websearch, httpOptions)
// } catch (error) {
// console.error('Search failed:', error)
// throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
// }
} }
/** /**

View File

@ -69,16 +69,6 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
}) })
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN) await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
}, },
// onBlockCreated: async () => {
// if (blockManager.hasInitialPlaceholder) {
// return
// }
// console.log('onBlockCreated')
// const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
// status: MessageBlockStatus.PROCESSING
// })
// await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
// },
onError: async (error: AISDKError) => { onError: async (error: AISDKError) => {
logger.debug('onError', error) logger.debug('onError', error)

View File

@ -51,28 +51,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
} }
}, },
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
// 根据 toolResponse.id 查找对应的块ID
const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
if (targetBlockId && toolResponse.status === 'invoking') {
const changes = {
status: MessageBlockStatus.PROCESSING,
metadata: { rawMcpToolResponse: toolResponse }
}
blockManager.smartBlockUpdate(targetBlockId, changes, MessageBlockType.TOOL)
} else if (!targetBlockId) {
logger.warn(
`[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`,
Array.from(toolCallIdToBlockIdMap.entries())
)
} else {
logger.warn(
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
)
}
},
onToolCallComplete: (toolResponse: MCPToolResponse) => { onToolCallComplete: (toolResponse: MCPToolResponse) => {
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id) const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
toolCallIdToBlockIdMap.delete(toolResponse.id) toolCallIdToBlockIdMap.delete(toolResponse.id)
@ -104,7 +82,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
stack: null stack: null
} }
} }
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true) blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
// Handle citation block creation for web search results // Handle citation block creation for web search results
@ -132,7 +109,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
citationBlockId = citationBlock.id citationBlockId = citationBlock.id
blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION) blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
} }
// TODO: 处理 memory 引用
} else { } else {
logger.warn( logger.warn(
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}` `[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`

View File

@ -458,18 +458,18 @@ describe('streamCallback Integration Tests', () => {
} }
] ]
}, },
{ // {
type: ChunkType.MCP_TOOL_IN_PROGRESS, // type: ChunkType.MCP_TOOL_PENDING,
responses: [ // responses: [
{ // {
id: 'tool-call-1', // id: 'tool-call-1',
tool: mockTool, // tool: mockTool,
arguments: { testArg: 'value' }, // arguments: { testArg: 'value' },
status: 'invoking' as const, // status: 'invoking' as const,
response: '' // response: ''
} // }
] // ]
}, // },
{ {
type: ChunkType.MCP_TOOL_COMPLETE, type: ChunkType.MCP_TOOL_COMPLETE,
responses: [ responses: [
@ -611,18 +611,18 @@ describe('streamCallback Integration Tests', () => {
} }
] ]
}, },
{ // {
type: ChunkType.MCP_TOOL_IN_PROGRESS, // type: ChunkType.MCP_TOOL_PENDING,
responses: [ // responses: [
{ // {
id: 'tool-call-1', // id: 'tool-call-1',
tool: mockCalculatorTool, // tool: mockCalculatorTool,
arguments: { operation: 'add', a: 1, b: 2 }, // arguments: { operation: 'add', a: 1, b: 2 },
status: 'invoking' as const, // status: 'invoking' as const,
response: '' // response: ''
} // }
] // ]
}, // },
{ {
type: ChunkType.MCP_TOOL_COMPLETE, type: ChunkType.MCP_TOOL_COMPLETE,
responses: [ responses: [

View File

@ -826,9 +826,7 @@ export function mcpToolCallResponseToAwsBedrockMessage(
} }
/** /**
* 使 * 使(function call)
* 1. 使
* 2. 使 prompt使
* @param assistant * @param assistant
* @returns 使 * @returns 使
*/ */

View File

@ -179,15 +179,15 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@ai-sdk/google@npm:^2.0.7": "@ai-sdk/google@npm:^2.0.13":
version: 2.0.7 version: 2.0.13
resolution: "@ai-sdk/google@npm:2.0.7" resolution: "@ai-sdk/google@npm:2.0.13"
dependencies: dependencies:
"@ai-sdk/provider": "npm:2.0.0" "@ai-sdk/provider": "npm:2.0.0"
"@ai-sdk/provider-utils": "npm:3.0.4" "@ai-sdk/provider-utils": "npm:3.0.8"
peerDependencies: peerDependencies:
zod: ^3.25.76 || ^4 zod: ^3.25.76 || ^4
checksum: 10c0/bde4c95a2a167355cda18de9d5b273d562d2a724f650ca69016daa8df2766280487e143cf0cdd96f6654c255d587a680c6a937b280eb734ca2c35d6f9b9e943c checksum: 10c0/a05210de11d7ab41d49bcd0330c37f4116441b149d8ccc9b6bc5eaa12ea42bae82364dc2cd09502734b15115071f07395525806ea4998930b285b1ce74102186
languageName: node languageName: node
linkType: hard linkType: hard
@ -227,15 +227,15 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@ai-sdk/openai@npm:^2.0.19": "@ai-sdk/openai@npm:^2.0.26":
version: 2.0.19 version: 2.0.26
resolution: "@ai-sdk/openai@npm:2.0.19" resolution: "@ai-sdk/openai@npm:2.0.26"
dependencies: dependencies:
"@ai-sdk/provider": "npm:2.0.0" "@ai-sdk/provider": "npm:2.0.0"
"@ai-sdk/provider-utils": "npm:3.0.5" "@ai-sdk/provider-utils": "npm:3.0.8"
peerDependencies: peerDependencies:
zod: ^3.25.76 || ^4 zod: ^3.25.76 || ^4
checksum: 10c0/04db695669d783a810b80283e0cd48f6e7654667fd76ca2d35c7cffae6fdd68fb0473118e4e097ef1352f4432dd7c15c07f873d712b940c72495e5839b0ede98 checksum: 10c0/b8cb01c0c38525c38901f41f1693cd15589932a2aceddea14bed30f44719532a5e74615fb0e974eff1a0513048ac204c27456ff8829a9c811d1461cc635c9cc5
languageName: node languageName: node
linkType: hard linkType: hard
@ -267,20 +267,6 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@ai-sdk/provider-utils@npm:3.0.5":
version: 3.0.5
resolution: "@ai-sdk/provider-utils@npm:3.0.5"
dependencies:
"@ai-sdk/provider": "npm:2.0.0"
"@standard-schema/spec": "npm:^1.0.0"
eventsource-parser: "npm:^3.0.3"
zod-to-json-schema: "npm:^3.24.1"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/4057810b320bda149a178dc1bfc9cdd592ca88b736c3c22bd0c1f8111c75ef69beec4a523f363e5d0d120348b876942fd66c0bb4965864da4c12c5cfddee15a3
languageName: node
linkType: hard
"@ai-sdk/provider-utils@npm:3.0.7": "@ai-sdk/provider-utils@npm:3.0.7":
version: 3.0.7 version: 3.0.7
resolution: "@ai-sdk/provider-utils@npm:3.0.7" resolution: "@ai-sdk/provider-utils@npm:3.0.7"
@ -294,6 +280,19 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@ai-sdk/provider-utils@npm:3.0.8":
version: 3.0.8
resolution: "@ai-sdk/provider-utils@npm:3.0.8"
dependencies:
"@ai-sdk/provider": "npm:2.0.0"
"@standard-schema/spec": "npm:^1.0.0"
eventsource-parser: "npm:^3.0.5"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/f466657c886cbb9f7ecbcd2dd1abc51a88af9d3f1cff030f7e97e70a4790a99f3338ad886e9c0dccf04dacdcc84522c7d57119b9a4e8e1d84f2dae9c893c397e
languageName: node
linkType: hard
"@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0": "@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0":
version: 2.0.0 version: 2.0.0
resolution: "@ai-sdk/provider@npm:2.0.0" resolution: "@ai-sdk/provider@npm:2.0.0"
@ -2249,8 +2248,8 @@ __metadata:
"@ai-sdk/anthropic": "npm:^2.0.5" "@ai-sdk/anthropic": "npm:^2.0.5"
"@ai-sdk/azure": "npm:^2.0.16" "@ai-sdk/azure": "npm:^2.0.16"
"@ai-sdk/deepseek": "npm:^1.0.9" "@ai-sdk/deepseek": "npm:^1.0.9"
"@ai-sdk/google": "npm:^2.0.7" "@ai-sdk/google": "npm:^2.0.13"
"@ai-sdk/openai": "npm:^2.0.19" "@ai-sdk/openai": "npm:^2.0.26"
"@ai-sdk/openai-compatible": "npm:^1.0.9" "@ai-sdk/openai-compatible": "npm:^1.0.9"
"@ai-sdk/provider": "npm:^2.0.0" "@ai-sdk/provider": "npm:^2.0.0"
"@ai-sdk/provider-utils": "npm:^3.0.4" "@ai-sdk/provider-utils": "npm:^3.0.4"