mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-29 23:12:38 +08:00
refactor(ToolCall): refactor:mcp-tool-state-management (#10028)
This commit is contained in:
parent
7df1060370
commit
f6ffd574bf
@ -1,103 +0,0 @@
|
||||
/**
|
||||
* Hub Provider 使用示例
|
||||
*
|
||||
* 演示如何使用简化后的Hub Provider功能来路由到多个底层provider
|
||||
*/
|
||||
|
||||
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
|
||||
|
||||
async function demonstrateHubProvider() {
|
||||
try {
|
||||
// 1. 初始化底层providers
|
||||
console.log('📦 初始化底层providers...')
|
||||
|
||||
initializeProvider('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
|
||||
})
|
||||
|
||||
initializeProvider('anthropic', {
|
||||
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
|
||||
})
|
||||
|
||||
// 2. 创建Hub Provider(自动包含所有已初始化的providers)
|
||||
console.log('🌐 创建Hub Provider...')
|
||||
|
||||
const aihubmixProvider = createHubProvider({
|
||||
hubId: 'aihubmix',
|
||||
debug: true
|
||||
})
|
||||
|
||||
// 3. 注册Hub Provider
|
||||
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
|
||||
|
||||
console.log('✅ Hub Provider "aihubmix" 注册成功')
|
||||
|
||||
// 4. 使用Hub Provider访问不同的模型
|
||||
console.log('\n🚀 使用Hub模型...')
|
||||
|
||||
// 通过Hub路由到OpenAI
|
||||
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
console.log('✓ OpenAI模型已获取:', openaiModel.modelId)
|
||||
|
||||
// 通过Hub路由到Anthropic
|
||||
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)
|
||||
|
||||
// 5. 演示错误处理
|
||||
console.log('\n❌ 演示错误处理...')
|
||||
|
||||
try {
|
||||
// 尝试访问未初始化的provider
|
||||
providerRegistry.languageModel('aihubmix:google:gemini-pro')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
try {
|
||||
// 尝试使用错误的模型ID格式
|
||||
providerRegistry.languageModel('aihubmix:invalid-format')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
// 6. 多个Hub Provider示例
|
||||
console.log('\n🔄 创建多个Hub Provider...')
|
||||
|
||||
const localHubProvider = createHubProvider({
|
||||
hubId: 'local-ai'
|
||||
})
|
||||
|
||||
providerRegistry.registerProvider('local-ai', localHubProvider)
|
||||
console.log('✅ Hub Provider "local-ai" 注册成功')
|
||||
|
||||
console.log('\n🎉 Hub Provider演示完成!')
|
||||
} catch (error) {
|
||||
console.error('💥 演示过程中发生错误:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 演示简化的使用方式
|
||||
function simplifiedUsageExample() {
|
||||
console.log('\n📝 简化使用示例:')
|
||||
console.log(`
|
||||
// 1. 初始化providers
|
||||
initializeProvider('openai', { apiKey: 'sk-xxx' })
|
||||
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
|
||||
|
||||
// 2. 创建并注册Hub Provider
|
||||
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
|
||||
providerRegistry.registerProvider('aihubmix', hubProvider)
|
||||
|
||||
// 3. 直接使用
|
||||
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
`)
|
||||
}
|
||||
|
||||
// 运行演示
|
||||
if (require.main === module) {
|
||||
demonstrateHubProvider()
|
||||
simplifiedUsageExample()
|
||||
}
|
||||
|
||||
export { demonstrateHubProvider, simplifiedUsageExample }
|
||||
@ -1,167 +0,0 @@
|
||||
/**
|
||||
* Image Generation Example
|
||||
* 演示如何使用 aiCore 的文生图功能
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage } from '../src/index'
|
||||
|
||||
async function main() {
|
||||
// 方式1: 使用执行器实例
|
||||
console.log('📸 创建 OpenAI 图像生成执行器...')
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
})
|
||||
|
||||
try {
|
||||
console.log('🎨 使用执行器生成图像...')
|
||||
const result1 = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape at sunset with flying cars',
|
||||
size: '1024x1024',
|
||||
n: 1
|
||||
})
|
||||
|
||||
console.log('✅ 图像生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result1.images.length,
|
||||
mediaType: result1.image.mediaType,
|
||||
hasBase64: !!result1.image.base64,
|
||||
providerMetadata: result1.providerMetadata
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 执行器生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式2: 使用直接调用 API
|
||||
try {
|
||||
console.log('🎨 使用直接 API 生成图像...')
|
||||
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
|
||||
prompt: 'A magical forest with glowing mushrooms and fairy lights',
|
||||
aspectRatio: '16:9',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
console.log('✅ 直接 API 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result2.images.length,
|
||||
mediaType: result2.image.mediaType,
|
||||
hasBase64: !!result2.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 直接 API 生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式3: 支持其他提供商 (Google Imagen)
|
||||
if (process.env.GOOGLE_API_KEY) {
|
||||
try {
|
||||
console.log('🎨 使用 Google Imagen 生成图像...')
|
||||
const googleExecutor = createExecutor('google', {
|
||||
apiKey: process.env.GOOGLE_API_KEY!
|
||||
})
|
||||
|
||||
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||
prompt: 'A serene mountain lake at dawn with mist rising from the water',
|
||||
aspectRatio: '1:1'
|
||||
})
|
||||
|
||||
console.log('✅ Google Imagen 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result3.images.length,
|
||||
mediaType: result3.image.mediaType,
|
||||
hasBase64: !!result3.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ Google Imagen 生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 方式4: 支持插件系统
|
||||
const pluginExample = async () => {
|
||||
console.log('🔌 演示插件系统...')
|
||||
|
||||
// 创建一个示例插件,用于修改提示词
|
||||
const promptEnhancerPlugin = {
|
||||
name: 'prompt-enhancer',
|
||||
transformParams: async (params: any) => {
|
||||
console.log('🔧 插件: 增强提示词...')
|
||||
return {
|
||||
...params,
|
||||
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
|
||||
}
|
||||
},
|
||||
transformResult: async (result: any) => {
|
||||
console.log('🔧 插件: 处理结果...')
|
||||
return {
|
||||
...result,
|
||||
enhanced: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const executorWithPlugin = createExecutor(
|
||||
'openai',
|
||||
{
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
},
|
||||
[promptEnhancerPlugin]
|
||||
)
|
||||
|
||||
try {
|
||||
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A cute robot playing in a garden'
|
||||
})
|
||||
|
||||
console.log('✅ 插件系统生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result4.images.length,
|
||||
enhanced: (result4 as any).enhanced,
|
||||
mediaType: result4.image.mediaType
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 插件系统生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
await pluginExample()
|
||||
}
|
||||
|
||||
// 错误处理演示
|
||||
async function errorHandlingExample() {
|
||||
console.log('⚠️ 演示错误处理...')
|
||||
|
||||
try {
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: 'invalid-key'
|
||||
})
|
||||
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'Test image'
|
||||
})
|
||||
} catch (error: any) {
|
||||
console.log('✅ 成功捕获错误:', error.constructor.name)
|
||||
console.log('📋 错误信息:', error.message)
|
||||
console.log('🏷️ 提供商ID:', error.providerId)
|
||||
console.log('🏷️ 模型ID:', error.modelId)
|
||||
}
|
||||
}
|
||||
|
||||
// 运行示例
|
||||
if (require.main === module) {
|
||||
main()
|
||||
.then(() => {
|
||||
console.log('🎉 所有示例完成!')
|
||||
return errorHandlingExample()
|
||||
})
|
||||
.then(() => {
|
||||
console.log('🎯 示例程序结束')
|
||||
process.exit(0)
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('💥 程序执行出错:', error)
|
||||
process.exit(1)
|
||||
})
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@cherrystudio/ai-core",
|
||||
"version": "1.0.0-alpha.12",
|
||||
"version": "1.0.0-alpha.13",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.mjs",
|
||||
@ -39,8 +39,8 @@
|
||||
"@ai-sdk/anthropic": "^2.0.5",
|
||||
"@ai-sdk/azure": "^2.0.16",
|
||||
"@ai-sdk/deepseek": "^1.0.9",
|
||||
"@ai-sdk/google": "^2.0.7",
|
||||
"@ai-sdk/openai": "^2.0.19",
|
||||
"@ai-sdk/google": "^2.0.13",
|
||||
"@ai-sdk/openai": "^2.0.26",
|
||||
"@ai-sdk/openai-compatible": "^1.0.9",
|
||||
"@ai-sdk/provider": "^2.0.0",
|
||||
"@ai-sdk/provider-utils": "^3.0.4",
|
||||
|
||||
@ -84,7 +84,6 @@ export class ModelResolver {
|
||||
*/
|
||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
console.log('fullModelId', fullModelId)
|
||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||
}
|
||||
|
||||
|
||||
@ -27,10 +27,20 @@ export class StreamEventManager {
|
||||
/**
|
||||
* 发送步骤完成事件
|
||||
*/
|
||||
sendStepFinishEvent(controller: StreamController, chunk: any): void {
|
||||
sendStepFinishEvent(
|
||||
controller: StreamController,
|
||||
chunk: any,
|
||||
context: AiRequestContext,
|
||||
finishReason: string = 'stop'
|
||||
): void {
|
||||
// 累加当前步骤的 usage
|
||||
if (chunk.usage && context.accumulatedUsage) {
|
||||
this.accumulateUsage(context.accumulatedUsage, chunk.usage)
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
type: 'finish-step',
|
||||
finishReason: 'stop',
|
||||
finishReason,
|
||||
response: chunk.response,
|
||||
usage: chunk.usage,
|
||||
providerMetadata: chunk.providerMetadata
|
||||
@ -43,28 +53,32 @@ export class StreamEventManager {
|
||||
async handleRecursiveCall(
|
||||
controller: StreamController,
|
||||
recursiveParams: any,
|
||||
context: AiRequestContext,
|
||||
stepId: string
|
||||
context: AiRequestContext
|
||||
): Promise<void> {
|
||||
try {
|
||||
console.log('[MCP Prompt] Starting recursive call after tool execution...')
|
||||
// try {
|
||||
// 重置工具执行状态,准备处理新的步骤
|
||||
context.hasExecutedToolsInCurrentStep = false
|
||||
|
||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
|
||||
if (recursiveResult && recursiveResult.fullStream) {
|
||||
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
|
||||
} else {
|
||||
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||
}
|
||||
} catch (error) {
|
||||
this.handleRecursiveCallError(controller, error, stepId)
|
||||
if (recursiveResult && recursiveResult.fullStream) {
|
||||
await this.pipeRecursiveStream(controller, recursiveResult.fullStream, context)
|
||||
} else {
|
||||
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||
}
|
||||
// } catch (error) {
|
||||
// this.handleRecursiveCallError(controller, error, stepId)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
* 将递归流的数据传递到当前流
|
||||
*/
|
||||
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
|
||||
private async pipeRecursiveStream(
|
||||
controller: StreamController,
|
||||
recursiveStream: ReadableStream,
|
||||
context?: AiRequestContext
|
||||
): Promise<void> {
|
||||
const reader = recursiveStream.getReader()
|
||||
try {
|
||||
while (true) {
|
||||
@ -73,9 +87,16 @@ export class StreamEventManager {
|
||||
break
|
||||
}
|
||||
if (value.type === 'finish') {
|
||||
// 迭代的流不发finish
|
||||
// 迭代的流不发finish,但需要累加其 usage
|
||||
if (value.usage && context?.accumulatedUsage) {
|
||||
this.accumulateUsage(context.accumulatedUsage, value.usage)
|
||||
}
|
||||
break
|
||||
}
|
||||
// 对于 finish-step 类型,累加其 usage
|
||||
if (value.type === 'finish-step' && value.usage && context?.accumulatedUsage) {
|
||||
this.accumulateUsage(context.accumulatedUsage, value.usage)
|
||||
}
|
||||
// 将递归流的数据传递到当前流
|
||||
controller.enqueue(value)
|
||||
}
|
||||
@ -87,25 +108,25 @@ export class StreamEventManager {
|
||||
/**
|
||||
* 处理递归调用错误
|
||||
*/
|
||||
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
|
||||
console.error('[MCP Prompt] Recursive call failed:', error)
|
||||
// private handleRecursiveCallError(controller: StreamController, error: unknown): void {
|
||||
// console.error('[MCP Prompt] Recursive call failed:', error)
|
||||
|
||||
// 使用 AI SDK 标准错误格式,但不中断流
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: error instanceof Error ? error.message : String(error),
|
||||
name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||
}
|
||||
})
|
||||
// // 使用 AI SDK 标准错误格式,但不中断流
|
||||
// controller.enqueue({
|
||||
// type: 'error',
|
||||
// error: {
|
||||
// message: error instanceof Error ? error.message : String(error),
|
||||
// name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||
// }
|
||||
// })
|
||||
|
||||
// 继续发送文本增量,保持流的连续性
|
||||
controller.enqueue({
|
||||
type: 'text-delta',
|
||||
id: stepId,
|
||||
text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
})
|
||||
}
|
||||
// // // 继续发送文本增量,保持流的连续性
|
||||
// // controller.enqueue({
|
||||
// // type: 'text-delta',
|
||||
// // id: stepId,
|
||||
// // text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
// // })
|
||||
// }
|
||||
|
||||
/**
|
||||
* 构建递归调用的参数
|
||||
@ -136,4 +157,18 @@ export class StreamEventManager {
|
||||
|
||||
return recursiveParams
|
||||
}
|
||||
|
||||
/**
|
||||
* 累加 usage 数据
|
||||
*/
|
||||
private accumulateUsage(target: any, source: any): void {
|
||||
if (!target || !source) return
|
||||
|
||||
// 累加各种 token 类型
|
||||
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
|
||||
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
|
||||
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
|
||||
target.reasoningTokens = (target.reasoningTokens || 0) + (source.reasoningTokens || 0)
|
||||
target.cachedInputTokens = (target.cachedInputTokens || 0) + (source.cachedInputTokens || 0)
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
* 负责工具的执行、结果格式化和相关事件发送
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ToolSet } from 'ai'
|
||||
import type { ToolSet, TypedToolError } from 'ai'
|
||||
|
||||
import type { ToolUseResult } from './type'
|
||||
|
||||
@ -38,7 +38,6 @@ export class ToolExecutor {
|
||||
controller: StreamController
|
||||
): Promise<ExecutedResult[]> {
|
||||
const executedResults: ExecutedResult[] = []
|
||||
|
||||
for (const toolUse of toolUses) {
|
||||
try {
|
||||
const tool = tools[toolUse.toolName]
|
||||
@ -46,17 +45,12 @@ export class ToolExecutor {
|
||||
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||
}
|
||||
|
||||
// 发送工具调用开始事件
|
||||
this.sendToolStartEvents(controller, toolUse)
|
||||
|
||||
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
||||
|
||||
// 发送 tool-call 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-call',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: tool.inputSchema
|
||||
input: toolUse.arguments
|
||||
})
|
||||
|
||||
const result = await tool.execute(toolUse.arguments, {
|
||||
@ -111,45 +105,46 @@ export class ToolExecutor {
|
||||
/**
|
||||
* 发送工具调用开始相关事件
|
||||
*/
|
||||
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||
// 发送 tool-input-start 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-input-start',
|
||||
id: toolUse.id,
|
||||
toolName: toolUse.toolName
|
||||
})
|
||||
}
|
||||
// private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||
// // 发送 tool-input-start 事件
|
||||
// controller.enqueue({
|
||||
// type: 'tool-input-start',
|
||||
// id: toolUse.id,
|
||||
// toolName: toolUse.toolName
|
||||
// })
|
||||
// }
|
||||
|
||||
/**
|
||||
* 处理工具执行错误
|
||||
*/
|
||||
private handleToolError(
|
||||
private handleToolError<T extends ToolSet>(
|
||||
toolUse: ToolUseResult,
|
||||
error: unknown,
|
||||
controller: StreamController
|
||||
// _tools: ToolSet
|
||||
): ExecutedResult {
|
||||
// 使用 AI SDK 标准错误格式
|
||||
// const toolError: TypedToolError<typeof _tools> = {
|
||||
// type: 'tool-error',
|
||||
// toolCallId: toolUse.id,
|
||||
// toolName: toolUse.toolName,
|
||||
// input: toolUse.arguments,
|
||||
// error: error instanceof Error ? error.message : String(error)
|
||||
// }
|
||||
const toolError: TypedToolError<T> = {
|
||||
type: 'tool-error',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: toolUse.arguments,
|
||||
error
|
||||
}
|
||||
|
||||
// controller.enqueue(toolError)
|
||||
controller.enqueue(toolError)
|
||||
|
||||
// 发送标准错误事件
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
// controller.enqueue({
|
||||
// type: 'tool-error',
|
||||
// toolCallId: toolUse.id,
|
||||
// error: error instanceof Error ? error.message : String(error),
|
||||
// input: toolUse.arguments
|
||||
// })
|
||||
|
||||
return {
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result: error instanceof Error ? error.message : String(error),
|
||||
result: error,
|
||||
isError: true
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,9 +8,19 @@ import type { TextStreamPart, ToolSet } from 'ai'
|
||||
import { definePlugin } from '../../index'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { StreamEventManager } from './StreamEventManager'
|
||||
import { type TagConfig, TagExtractor } from './tagExtraction'
|
||||
import { ToolExecutor } from './ToolExecutor'
|
||||
import { PromptToolUseConfig, ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 工具使用标签配置
|
||||
*/
|
||||
const TOOL_USE_TAG_CONFIG: TagConfig = {
|
||||
openingTag: '<tool_use>',
|
||||
closingTag: '</tool_use>',
|
||||
separator: '\n'
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认系统提示符模板(提取自 Cherry Studio)
|
||||
*/
|
||||
@ -249,13 +259,11 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
}
|
||||
|
||||
context.mcpTools = params.tools
|
||||
console.log('tools stored in context', params.tools)
|
||||
|
||||
// 构建系统提示符
|
||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
|
||||
let systemMessage: string | null = systemPrompt
|
||||
console.log('config.context', context)
|
||||
if (config.createSystemMessage) {
|
||||
// 🎯 如果用户提供了自定义处理函数,使用它
|
||||
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||
@ -268,20 +276,40 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
tools: undefined
|
||||
}
|
||||
context.originalParams = transformedParams
|
||||
console.log('transformedParams', transformedParams)
|
||||
return transformedParams
|
||||
},
|
||||
transformStream: (_: any, context: AiRequestContext) => () => {
|
||||
let textBuffer = ''
|
||||
let stepId = ''
|
||||
// let stepId = ''
|
||||
|
||||
if (!context.mcpTools) {
|
||||
throw new Error('No tools available')
|
||||
}
|
||||
|
||||
// 创建工具执行器和流事件管理器
|
||||
// 从 context 中获取或初始化 usage 累加器
|
||||
if (!context.accumulatedUsage) {
|
||||
context.accumulatedUsage = {
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0,
|
||||
reasoningTokens: 0,
|
||||
cachedInputTokens: 0
|
||||
}
|
||||
}
|
||||
|
||||
// 创建工具执行器、流事件管理器和标签提取器
|
||||
const toolExecutor = new ToolExecutor()
|
||||
const streamEventManager = new StreamEventManager()
|
||||
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||
|
||||
// 在context中初始化工具执行状态,避免递归调用时状态丢失
|
||||
if (!context.hasExecutedToolsInCurrentStep) {
|
||||
context.hasExecutedToolsInCurrentStep = false
|
||||
}
|
||||
|
||||
// 用于hold text-start事件,直到确认有非工具标签内容
|
||||
let pendingTextStart: TextStreamPart<TOOLS> | null = null
|
||||
let hasStartedText = false
|
||||
|
||||
type TOOLS = NonNullable<typeof context.mcpTools>
|
||||
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
|
||||
@ -289,83 +317,106 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
chunk: TextStreamPart<TOOLS>,
|
||||
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
|
||||
) {
|
||||
// 收集文本内容
|
||||
if (chunk.type === 'text-delta') {
|
||||
textBuffer += chunk.text || ''
|
||||
stepId = chunk.id || ''
|
||||
controller.enqueue(chunk)
|
||||
// Hold住text-start事件,直到确认有非工具标签内容
|
||||
if ((chunk as any).type === 'text-start') {
|
||||
pendingTextStart = chunk
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
|
||||
const tools = context.mcpTools
|
||||
if (!tools || Object.keys(tools).length === 0) {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
// text-delta阶段:收集文本内容并过滤工具标签
|
||||
if (chunk.type === 'text-delta') {
|
||||
textBuffer += chunk.text || ''
|
||||
// stepId = chunk.id || ''
|
||||
|
||||
// 解析工具调用
|
||||
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
|
||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
// 使用TagExtractor过滤工具标签,只传递非标签内容到UI层
|
||||
const extractionResults = tagExtractor.processText(chunk.text || '')
|
||||
|
||||
// 如果没有有效的工具调用,直接传递原始事件
|
||||
if (validToolUses.length === 0) {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end') {
|
||||
controller.enqueue({
|
||||
type: 'text-end',
|
||||
id: stepId,
|
||||
providerMetadata: {
|
||||
text: {
|
||||
value: parsedContent
|
||||
}
|
||||
for (const result of extractionResults) {
|
||||
// 只传递非标签内容到UI层
|
||||
if (!result.isTagContent && result.content) {
|
||||
// 如果还没有发送text-start且有pending的text-start,先发送它
|
||||
if (!hasStartedText && pendingTextStart) {
|
||||
controller.enqueue(pendingTextStart)
|
||||
hasStartedText = true
|
||||
pendingTextStart = null
|
||||
}
|
||||
})
|
||||
return
|
||||
|
||||
const filteredChunk = {
|
||||
...chunk,
|
||||
text: result.content
|
||||
}
|
||||
controller.enqueue(filteredChunk)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end') {
|
||||
// 只有当已经发送了text-start时才发送text-end
|
||||
if (hasStartedText) {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'finish-step') {
|
||||
// 统一在finish-step阶段检查并执行工具调用
|
||||
const tools = context.mcpTools
|
||||
if (tools && Object.keys(tools).length > 0 && !context.hasExecutedToolsInCurrentStep) {
|
||||
// 解析完整的textBuffer来检测工具调用
|
||||
const { results: parsedTools } = parseToolUse(textBuffer, tools)
|
||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
|
||||
if (validToolUses.length > 0) {
|
||||
context.hasExecutedToolsInCurrentStep = true
|
||||
|
||||
// 执行工具调用(不需要手动发送 start-step,外部流已经处理)
|
||||
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
||||
|
||||
// 发送步骤完成事件,使用 tool-calls 作为 finishReason
|
||||
streamEventManager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
|
||||
|
||||
// 处理递归调用
|
||||
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
||||
const recursiveParams = streamEventManager.buildRecursiveParams(
|
||||
context,
|
||||
textBuffer,
|
||||
toolResultsText,
|
||||
tools
|
||||
)
|
||||
|
||||
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
finishReason: 'tool-calls'
|
||||
})
|
||||
|
||||
// 发送步骤开始事件
|
||||
streamEventManager.sendStepStartEvent(controller)
|
||||
|
||||
// 执行工具调用
|
||||
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
||||
|
||||
// 发送步骤完成事件
|
||||
streamEventManager.sendStepFinishEvent(controller, chunk)
|
||||
|
||||
// 处理递归调用
|
||||
if (validToolUses.length > 0) {
|
||||
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
||||
const recursiveParams = streamEventManager.buildRecursiveParams(
|
||||
context,
|
||||
textBuffer,
|
||||
toolResultsText,
|
||||
tools
|
||||
)
|
||||
|
||||
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
|
||||
}
|
||||
// 如果没有执行工具调用,直接传递原始finish-step事件
|
||||
controller.enqueue(chunk)
|
||||
|
||||
// 清理状态
|
||||
textBuffer = ''
|
||||
return
|
||||
}
|
||||
|
||||
// 对于其他类型的事件,直接传递
|
||||
controller.enqueue(chunk)
|
||||
// 处理 finish 类型,使用累加后的 totalUsage
|
||||
if (chunk.type === 'finish') {
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
totalUsage: context.accumulatedUsage
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 对于其他类型的事件,直接传递(不包括text-start,已在上面处理)
|
||||
if ((chunk as any).type !== 'text-start') {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
},
|
||||
|
||||
flush() {
|
||||
// 流结束时的清理工作
|
||||
console.log('[MCP Prompt] Stream ended, cleaning up...')
|
||||
// 清理pending状态
|
||||
pendingTextStart = null
|
||||
hasStartedText = false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -27,7 +27,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
||||
case 'openai': {
|
||||
if (config.openai) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
|
||||
params.tools.web_search = openai.tools.webSearch(config.openai)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
// 核心类型和接口
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||
import type { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import type { LanguageModel } from 'ai'
|
||||
|
||||
import type { ProviderId } from '../providers'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
@ -9,16 +12,16 @@ export { PluginManager } from './manager'
|
||||
// 工具函数
|
||||
export function createContext<T extends ProviderId>(
|
||||
providerId: T,
|
||||
modelId: string,
|
||||
model: LanguageModel | ImageModelV2,
|
||||
originalParams: any
|
||||
): AiRequestContext {
|
||||
return {
|
||||
providerId,
|
||||
modelId,
|
||||
model,
|
||||
originalParams,
|
||||
metadata: {},
|
||||
startTime: Date.now(),
|
||||
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
// 占位
|
||||
recursiveCall: () => Promise.resolve(null)
|
||||
}
|
||||
|
||||
@ -14,7 +14,7 @@ export type RecursiveCallFn = (newParams: any) => Promise<any>
|
||||
*/
|
||||
export interface AiRequestContext {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
model: LanguageModel | ImageModelV2
|
||||
originalParams: any
|
||||
metadata: Record<string, any>
|
||||
startTime: number
|
||||
|
||||
@ -83,7 +83,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
@ -159,7 +159,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
@ -235,7 +235,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
|
||||
// 创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
|
||||
@ -152,12 +152,14 @@ export class AiSdkToChunkAdapter {
|
||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
// break
|
||||
case 'tool-call':
|
||||
// 原始的工具调用(未被中间件处理)
|
||||
this.toolCallHandler.handleToolCall(chunk)
|
||||
break
|
||||
|
||||
case 'tool-error':
|
||||
this.toolCallHandler.handleToolError(chunk)
|
||||
break
|
||||
|
||||
case 'tool-result':
|
||||
// 原始的工具调用结果(未被中间件处理)
|
||||
this.toolCallHandler.handleToolResult(chunk)
|
||||
break
|
||||
|
||||
@ -167,7 +169,6 @@ export class AiSdkToChunkAdapter {
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
// TODO: 需要区分接口开始和步骤开始
|
||||
// case 'start-step':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.BLOCK_CREATED
|
||||
@ -305,8 +306,6 @@ export class AiSdkToChunkAdapter {
|
||||
break
|
||||
|
||||
default:
|
||||
// 其他类型的 chunk 可以忽略或记录日志
|
||||
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,34 +8,61 @@ import { loggerService } from '@logger'
|
||||
import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
|
||||
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
|
||||
// import type {
|
||||
// AnthropicSearchOutput,
|
||||
// WebSearchPluginConfig
|
||||
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
|
||||
import type { ToolSet, TypedToolCall, TypedToolError, TypedToolResult } from 'ai'
|
||||
|
||||
const logger = loggerService.withContext('ToolCallChunkHandler')
|
||||
|
||||
export type ToolcallsMap = {
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||
tool: BaseTool
|
||||
}
|
||||
/**
|
||||
* 工具调用处理器类
|
||||
*/
|
||||
export class ToolCallChunkHandler {
|
||||
// private onChunk: (chunk: Chunk) => void
|
||||
private activeToolCalls = new Map<
|
||||
string,
|
||||
{
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||
tool: BaseTool
|
||||
}
|
||||
>()
|
||||
private static globalActiveToolCalls = new Map<string, ToolcallsMap>()
|
||||
|
||||
private activeToolCalls = ToolCallChunkHandler.globalActiveToolCalls
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
private mcpTools: MCPTool[]
|
||||
) {}
|
||||
|
||||
/**
|
||||
* 内部静态方法:添加活跃工具调用的核心逻辑
|
||||
*/
|
||||
private static addActiveToolCallImpl(toolCallId: string, map: ToolcallsMap): boolean {
|
||||
if (!ToolCallChunkHandler.globalActiveToolCalls.has(toolCallId)) {
|
||||
ToolCallChunkHandler.globalActiveToolCalls.set(toolCallId, map)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* 实例方法:添加活跃工具调用
|
||||
*/
|
||||
private addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
|
||||
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取全局活跃的工具调用
|
||||
*/
|
||||
public static getActiveToolCalls() {
|
||||
return ToolCallChunkHandler.globalActiveToolCalls
|
||||
}
|
||||
|
||||
/**
|
||||
* 静态方法:添加活跃工具调用(外部访问)
|
||||
*/
|
||||
public static addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
|
||||
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
|
||||
}
|
||||
|
||||
// /**
|
||||
// * 设置 onChunk 回调
|
||||
// */
|
||||
@ -43,103 +70,103 @@ export class ToolCallChunkHandler {
|
||||
// this.onChunk = callback
|
||||
// }
|
||||
|
||||
handleToolCallCreated(
|
||||
chunk:
|
||||
| {
|
||||
type: 'tool-input-start'
|
||||
id: string
|
||||
toolName: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
providerExecuted?: boolean
|
||||
}
|
||||
| {
|
||||
type: 'tool-input-end'
|
||||
id: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
}
|
||||
| {
|
||||
type: 'tool-input-delta'
|
||||
id: string
|
||||
delta: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
}
|
||||
): void {
|
||||
switch (chunk.type) {
|
||||
case 'tool-input-start': {
|
||||
// 能拿到说明是mcpTool
|
||||
// if (this.activeToolCalls.get(chunk.id)) return
|
||||
// handleToolCallCreated(
|
||||
// chunk:
|
||||
// | {
|
||||
// type: 'tool-input-start'
|
||||
// id: string
|
||||
// toolName: string
|
||||
// providerMetadata?: ProviderMetadata
|
||||
// providerExecuted?: boolean
|
||||
// }
|
||||
// | {
|
||||
// type: 'tool-input-end'
|
||||
// id: string
|
||||
// providerMetadata?: ProviderMetadata
|
||||
// }
|
||||
// | {
|
||||
// type: 'tool-input-delta'
|
||||
// id: string
|
||||
// delta: string
|
||||
// providerMetadata?: ProviderMetadata
|
||||
// }
|
||||
// ): void {
|
||||
// switch (chunk.type) {
|
||||
// case 'tool-input-start': {
|
||||
// // 能拿到说明是mcpTool
|
||||
// // if (this.activeToolCalls.get(chunk.id)) return
|
||||
|
||||
const tool: BaseTool | MCPTool = {
|
||||
id: chunk.id,
|
||||
name: chunk.toolName,
|
||||
description: chunk.toolName,
|
||||
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
||||
}
|
||||
this.activeToolCalls.set(chunk.id, {
|
||||
toolCallId: chunk.id,
|
||||
toolName: chunk.toolName,
|
||||
args: '',
|
||||
tool
|
||||
})
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: chunk.id,
|
||||
tool: tool,
|
||||
arguments: {},
|
||||
status: 'pending',
|
||||
toolCallId: chunk.id
|
||||
}
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'tool-input-delta': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
toolCall.args += chunk.delta
|
||||
break
|
||||
}
|
||||
case 'tool-input-end': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
this.activeToolCalls.delete(chunk.id)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
// const toolResponse: ToolCallResponse = {
|
||||
// id: toolCall.toolCallId,
|
||||
// tool: toolCall.tool,
|
||||
// arguments: toolCall.args,
|
||||
// status: 'pending',
|
||||
// toolCallId: toolCall.toolCallId
|
||||
// }
|
||||
// logger.debug('toolResponse', toolResponse)
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [toolResponse]
|
||||
// })
|
||||
break
|
||||
}
|
||||
}
|
||||
// if (!toolCall) {
|
||||
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// return
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_CREATED,
|
||||
// tool_calls: [
|
||||
// {
|
||||
// id: chunk.id,
|
||||
// name: chunk.toolName,
|
||||
// status: 'pending'
|
||||
// }
|
||||
// ]
|
||||
// })
|
||||
}
|
||||
// const tool: BaseTool | MCPTool = {
|
||||
// id: chunk.id,
|
||||
// name: chunk.toolName,
|
||||
// description: chunk.toolName,
|
||||
// type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
||||
// }
|
||||
// this.activeToolCalls.set(chunk.id, {
|
||||
// toolCallId: chunk.id,
|
||||
// toolName: chunk.toolName,
|
||||
// args: '',
|
||||
// tool
|
||||
// })
|
||||
// const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
// id: chunk.id,
|
||||
// tool: tool,
|
||||
// arguments: {},
|
||||
// status: 'pending',
|
||||
// toolCallId: chunk.id
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [toolResponse]
|
||||
// })
|
||||
// break
|
||||
// }
|
||||
// case 'tool-input-delta': {
|
||||
// const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
// if (!toolCall) {
|
||||
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// return
|
||||
// }
|
||||
// toolCall.args += chunk.delta
|
||||
// break
|
||||
// }
|
||||
// case 'tool-input-end': {
|
||||
// const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
// this.activeToolCalls.delete(chunk.id)
|
||||
// if (!toolCall) {
|
||||
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// return
|
||||
// }
|
||||
// // const toolResponse: ToolCallResponse = {
|
||||
// // id: toolCall.toolCallId,
|
||||
// // tool: toolCall.tool,
|
||||
// // arguments: toolCall.args,
|
||||
// // status: 'pending',
|
||||
// // toolCallId: toolCall.toolCallId
|
||||
// // }
|
||||
// // logger.debug('toolResponse', toolResponse)
|
||||
// // this.onChunk({
|
||||
// // type: ChunkType.MCP_TOOL_PENDING,
|
||||
// // responses: [toolResponse]
|
||||
// // })
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// // if (!toolCall) {
|
||||
// // Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// // return
|
||||
// // }
|
||||
// // this.onChunk({
|
||||
// // type: ChunkType.MCP_TOOL_CREATED,
|
||||
// // tool_calls: [
|
||||
// // {
|
||||
// // id: chunk.id,
|
||||
// // name: chunk.toolName,
|
||||
// // status: 'pending'
|
||||
// // }
|
||||
// // ]
|
||||
// // })
|
||||
// }
|
||||
|
||||
/**
|
||||
* 处理工具调用事件
|
||||
@ -158,7 +185,6 @@ export class ToolCallChunkHandler {
|
||||
|
||||
let tool: BaseTool
|
||||
let mcpTool: MCPTool | undefined
|
||||
|
||||
// 根据 providerExecuted 标志区分处理逻辑
|
||||
if (providerExecuted) {
|
||||
// 如果是 Provider 执行的工具(如 web_search)
|
||||
@ -196,27 +222,25 @@ export class ToolCallChunkHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// 记录活跃的工具调用
|
||||
this.activeToolCalls.set(toolCallId, {
|
||||
this.addActiveToolCall(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args,
|
||||
tool
|
||||
})
|
||||
|
||||
// 创建 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: tool,
|
||||
arguments: args,
|
||||
status: 'pending',
|
||||
status: 'pending', // 统一使用 pending 状态
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
type: ChunkType.MCP_TOOL_PENDING, // 统一发送 pending 状态
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
@ -276,4 +300,33 @@ export class ToolCallChunkHandler {
|
||||
})
|
||||
}
|
||||
}
|
||||
handleToolError(
|
||||
chunk: {
|
||||
type: 'tool-error'
|
||||
} & TypedToolError<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, error, input } = chunk
|
||||
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCallInfo) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: toolCallInfo.tool,
|
||||
arguments: input,
|
||||
status: 'error',
|
||||
response: error,
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
this.activeToolCalls.delete(toolCallId)
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const addActiveToolCall = ToolCallChunkHandler.addActiveToolCall.bind(ToolCallChunkHandler)
|
||||
|
||||
@ -265,15 +265,15 @@ export default class ModernAiProvider {
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const modelId = this.model!.id
|
||||
logger.info('Starting modernCompletions', {
|
||||
modelId,
|
||||
providerId: this.config!.providerId,
|
||||
topicId: config.topicId,
|
||||
hasOnChunk: !!config.onChunk,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolCount: params.tools ? Object.keys(params.tools).length : 0
|
||||
})
|
||||
// const modelId = this.model!.id
|
||||
// logger.info('Starting modernCompletions', {
|
||||
// modelId,
|
||||
// providerId: this.config!.providerId,
|
||||
// topicId: config.topicId,
|
||||
// hasOnChunk: !!config.onChunk,
|
||||
// hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
// toolCount: params.tools ? Object.keys(params.tools).length : 0
|
||||
// })
|
||||
|
||||
// 根据条件构建插件数组
|
||||
const plugins = buildPlugins(config)
|
||||
|
||||
@ -35,7 +35,7 @@ export function buildPlugins(
|
||||
plugins.push(webSearchPlugin())
|
||||
}
|
||||
// 2. 支持工具调用时添加搜索插件
|
||||
if (middlewareConfig.isSupportedToolUse) {
|
||||
if (middlewareConfig.isSupportedToolUse || middlewareConfig.isPromptToolUse) {
|
||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant, middlewareConfig.topicId || ''))
|
||||
}
|
||||
|
||||
@ -45,12 +45,13 @@ export function buildPlugins(
|
||||
}
|
||||
|
||||
// 4. 启用Prompt工具调用时添加工具插件
|
||||
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
if (middlewareConfig.isPromptToolUse) {
|
||||
plugins.push(
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
const modelId = typeof context.model === 'string' ? context.model : context.model.modelId
|
||||
if (modelId.includes('o1-mini') || modelId.includes('o1-preview')) {
|
||||
if (context.isRecursiveCall) {
|
||||
return null
|
||||
}
|
||||
|
||||
@ -19,7 +19,8 @@ import store from '@renderer/store'
|
||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
|
||||
import type { ModelMessage } from 'ai'
|
||||
import type { LanguageModel, ModelMessage } from 'ai'
|
||||
import { generateText } from 'ai'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
@ -76,9 +77,7 @@ async function analyzeSearchIntent(
|
||||
shouldKnowledgeSearch?: boolean
|
||||
shouldMemorySearch?: boolean
|
||||
lastAnswer?: ModelMessage
|
||||
context: AiRequestContext & {
|
||||
isAnalyzing?: boolean
|
||||
}
|
||||
context: AiRequestContext
|
||||
topicId: string
|
||||
}
|
||||
): Promise<ExtractResults | undefined> {
|
||||
@ -122,9 +121,7 @@ async function analyzeSearchIntent(
|
||||
logger.error('Provider not found or missing API key')
|
||||
return getFallbackResult()
|
||||
}
|
||||
// console.log('formattedPrompt', schema)
|
||||
try {
|
||||
context.isAnalyzing = true
|
||||
logger.info('Starting intent analysis generateText call', {
|
||||
modelId: model.id,
|
||||
topicId: options.topicId,
|
||||
@ -133,18 +130,16 @@ async function analyzeSearchIntent(
|
||||
hasKnowledgeSearch: needKnowledgeExtract
|
||||
})
|
||||
|
||||
const { text: result } = await context.executor
|
||||
.generateText(model.id, {
|
||||
prompt: formattedPrompt
|
||||
})
|
||||
.finally(() => {
|
||||
context.isAnalyzing = false
|
||||
logger.info('Intent analysis generateText call completed', {
|
||||
modelId: model.id,
|
||||
topicId: options.topicId,
|
||||
requestId: context.requestId
|
||||
})
|
||||
const { text: result } = await generateText({
|
||||
model: context.model as LanguageModel,
|
||||
prompt: formattedPrompt
|
||||
}).finally(() => {
|
||||
logger.info('Intent analysis generateText call completed', {
|
||||
modelId: model.id,
|
||||
topicId: options.topicId,
|
||||
requestId: context.requestId
|
||||
})
|
||||
})
|
||||
const parsedResult = extractInfoFromXML(result)
|
||||
logger.debug('Intent analysis result', { parsedResult })
|
||||
|
||||
@ -183,7 +178,6 @@ async function storeConversationMemory(
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
|
||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||
// console.log('Memory storage is disabled')
|
||||
return
|
||||
}
|
||||
|
||||
@ -245,25 +239,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
// 存储意图分析结果
|
||||
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
|
||||
const userMessages: { [requestId: string]: ModelMessage } = {}
|
||||
let currentContext: AiRequestContext | null = null
|
||||
|
||||
return definePlugin({
|
||||
name: 'search-orchestration',
|
||||
enforce: 'pre', // 确保在其他插件之前执行
|
||||
|
||||
configureContext: (context: AiRequestContext) => {
|
||||
if (currentContext) {
|
||||
context.isAnalyzing = currentContext.isAnalyzing
|
||||
}
|
||||
currentContext = context
|
||||
},
|
||||
|
||||
/**
|
||||
* 🔍 Step 1: 意图识别阶段
|
||||
*/
|
||||
onRequestStart: async (context: AiRequestContext) => {
|
||||
if (context.isAnalyzing) return
|
||||
|
||||
// 没开启任何搜索则不进行意图分析
|
||||
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
|
||||
|
||||
@ -315,7 +298,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
* 🔧 Step 2: 工具配置阶段
|
||||
*/
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
if (context.isAnalyzing) return params
|
||||
// logger.info('🔧 Configuring tools based on intent...', context.requestId)
|
||||
|
||||
try {
|
||||
@ -409,7 +391,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
// context.isAnalyzing = false
|
||||
// logger.info('context.isAnalyzing', context, result)
|
||||
// logger.info('💾 Starting memory storage...', context.requestId)
|
||||
if (context.isAnalyzing) return
|
||||
try {
|
||||
const messages = context.originalParams.messages
|
||||
|
||||
|
||||
@ -19,8 +19,6 @@ export const memorySearchTool = () => {
|
||||
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
|
||||
}),
|
||||
execute: async ({ query, limit = 5 }) => {
|
||||
// console.log('🧠 [memorySearchTool] Searching memories:', { query, limit })
|
||||
|
||||
try {
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
if (!globalMemoryEnabled) {
|
||||
@ -29,7 +27,6 @@ export const memorySearchTool = () => {
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||
// console.warn('Memory search skipped: embedding or LLM model not configured')
|
||||
return []
|
||||
}
|
||||
|
||||
@ -40,12 +37,10 @@ export const memorySearchTool = () => {
|
||||
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
|
||||
|
||||
if (relevantMemories?.length > 0) {
|
||||
// console.log('🧠 [memorySearchTool] Found memories:', relevantMemories.length)
|
||||
return relevantMemories
|
||||
}
|
||||
return []
|
||||
} catch (error) {
|
||||
// console.error('🧠 [memorySearchTool] Error:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
@ -84,8 +79,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
||||
.optional()
|
||||
}) satisfies z.ZodSchema<MemorySearchWithExtractionInput>,
|
||||
execute: async ({ userMessage }) => {
|
||||
// console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer })
|
||||
|
||||
try {
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||
@ -97,7 +90,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||
// console.warn('Memory search skipped: embedding or LLM model not configured')
|
||||
return {
|
||||
extractedKeywords: 'Memory models not configured',
|
||||
searchResults: []
|
||||
@ -125,7 +117,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
||||
)
|
||||
|
||||
if (relevantMemories?.length > 0) {
|
||||
// console.log('🧠 [memorySearchToolWithExtraction] Found memories:', relevantMemories.length)
|
||||
return {
|
||||
extractedKeywords: content,
|
||||
searchResults: relevantMemories
|
||||
@ -137,7 +128,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
||||
searchResults: []
|
||||
}
|
||||
} catch (error) {
|
||||
// console.error('🧠 [memorySearchToolWithExtraction] Error:', error)
|
||||
return {
|
||||
extractedKeywords: 'Search failed',
|
||||
searchResults: []
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools'
|
||||
import { requestToolConfirmation } from '@renderer/utils/userConfirmation'
|
||||
import { type Tool, type ToolSet } from 'ai'
|
||||
@ -33,8 +31,36 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
|
||||
tools[mcpTool.name] = tool({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
execute: async (params, { toolCallId, experimental_context }) => {
|
||||
const { onChunk } = experimental_context as { onChunk: (chunk: Chunk) => void }
|
||||
execute: async (params, { toolCallId }) => {
|
||||
// 检查是否启用自动批准
|
||||
const server = getMcpServerByTool(mcpTool)
|
||||
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
|
||||
|
||||
let confirmed = true
|
||||
|
||||
if (!isAutoApproveEnabled) {
|
||||
// 请求用户确认
|
||||
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
|
||||
confirmed = await requestToolConfirmation(toolCallId)
|
||||
}
|
||||
|
||||
if (!confirmed) {
|
||||
// 用户拒绝执行工具
|
||||
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `User declined to execute tool "${mcpTool.name}".`
|
||||
}
|
||||
],
|
||||
isError: false
|
||||
}
|
||||
}
|
||||
|
||||
// 用户确认或自动批准,执行工具
|
||||
logger.debug(`Executing tool: ${mcpTool.name}`)
|
||||
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: toolCallId,
|
||||
@ -44,53 +70,18 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
|
||||
toolCallId
|
||||
}
|
||||
|
||||
try {
|
||||
// 检查是否启用自动批准
|
||||
const server = getMcpServerByTool(mcpTool)
|
||||
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
|
||||
const result = await callMCPTool(toolResponse)
|
||||
|
||||
let confirmed = true
|
||||
if (!isAutoApproveEnabled) {
|
||||
// 请求用户确认
|
||||
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
|
||||
confirmed = await requestToolConfirmation(toolResponse.id)
|
||||
}
|
||||
|
||||
if (!confirmed) {
|
||||
// 用户拒绝执行工具
|
||||
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `User declined to execute tool "${mcpTool.name}".`
|
||||
}
|
||||
],
|
||||
isError: false
|
||||
}
|
||||
}
|
||||
|
||||
// 用户确认或自动批准,执行工具
|
||||
toolResponse.status = 'invoking'
|
||||
logger.debug(`Executing tool: ${mcpTool.name}`)
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
|
||||
const result = await callMCPTool(toolResponse)
|
||||
|
||||
// 返回结果,AI SDK 会处理序列化
|
||||
if (result.isError) {
|
||||
throw new Error(result.content?.[0]?.text || 'Tool execution failed')
|
||||
}
|
||||
// 返回工具执行结果
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
|
||||
throw error
|
||||
// 返回结果,AI SDK 会处理序列化
|
||||
if (result.isError) {
|
||||
// throw new Error(result.content?.[0]?.text || 'Tool execution failed')
|
||||
return Promise.reject(result)
|
||||
}
|
||||
// 返回工具执行结果
|
||||
return result
|
||||
// } catch (error) {
|
||||
// logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
|
||||
// }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -105,7 +105,6 @@ const ThinkingBlock: React.FC<Props> = ({ block }) => {
|
||||
const ThinkingTimeSeconds = memo(
|
||||
({ blockThinkingTime, isThinking }: { blockThinkingTime: number; isThinking: boolean }) => {
|
||||
const { t } = useTranslation()
|
||||
// console.log('blockThinkingTime', blockThinkingTime)
|
||||
// const [thinkingTime, setThinkingTime] = useState(blockThinkingTime || 0)
|
||||
|
||||
// FIXME: 这里统计的和请求处统计的有一定误差
|
||||
|
||||
@ -186,7 +186,6 @@ const MessageMenubar: FC<Props> = (props) => {
|
||||
try {
|
||||
await translateText(mainTextContent, language, translationUpdater)
|
||||
} catch (error) {
|
||||
// console.error('Translation failed:', error)
|
||||
window.message.error({ content: t('translate.error.failed'), key: 'translate-message' })
|
||||
// 理应只有一个
|
||||
const translationBlocks = findTranslationBlocksById(message.id)
|
||||
|
||||
@ -60,14 +60,29 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
const toolResponse = block.metadata?.rawMcpToolResponse
|
||||
|
||||
const { id, tool, status, response } = toolResponse!
|
||||
|
||||
const isPending = status === 'pending'
|
||||
const isInvoking = status === 'invoking'
|
||||
const isDone = status === 'done'
|
||||
const isError = status === 'error'
|
||||
|
||||
const isAutoApproved = useMemo(
|
||||
() =>
|
||||
isToolAutoApproved(
|
||||
tool,
|
||||
mcpServers.find((s) => s.id === tool.serverId)
|
||||
),
|
||||
[tool, mcpServers]
|
||||
)
|
||||
|
||||
// 增加本地状态来跟踪用户确认
|
||||
const [isConfirmed, setIsConfirmed] = useState(isAutoApproved)
|
||||
|
||||
// 判断不同的UI状态
|
||||
const isWaitingConfirmation = isPending && !isAutoApproved && !isConfirmed
|
||||
const isExecuting = isPending && (isAutoApproved || isConfirmed)
|
||||
|
||||
const timer = useRef<NodeJS.Timeout | null>(null)
|
||||
useEffect(() => {
|
||||
if (!isPending) return
|
||||
if (!isWaitingConfirmation) return
|
||||
|
||||
if (countdown > 0) {
|
||||
timer.current = setTimeout(() => {
|
||||
@ -75,6 +90,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
setCountdown((prev) => prev - 1)
|
||||
}, 1000)
|
||||
} else if (countdown === 0) {
|
||||
setIsConfirmed(true)
|
||||
confirmToolAction(id)
|
||||
}
|
||||
|
||||
@ -83,7 +99,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
clearTimeout(timer.current)
|
||||
}
|
||||
}
|
||||
}, [countdown, id, isPending])
|
||||
}, [countdown, id, isWaitingConfirmation])
|
||||
|
||||
useEffect(() => {
|
||||
const removeListener = window.electron.ipcRenderer.on(
|
||||
@ -146,6 +162,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
|
||||
const handleConfirmTool = () => {
|
||||
cancelCountdown()
|
||||
setIsConfirmed(true)
|
||||
confirmToolAction(id)
|
||||
}
|
||||
|
||||
@ -195,6 +212,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
updateMCPServer(updatedServer)
|
||||
|
||||
// Also confirm the current tool
|
||||
setIsConfirmed(true)
|
||||
confirmToolAction(id)
|
||||
|
||||
window.message.success({
|
||||
@ -206,32 +224,31 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
const renderStatusIndicator = (status: string, hasError: boolean) => {
|
||||
let label = ''
|
||||
let icon: React.ReactNode | null = null
|
||||
switch (status) {
|
||||
case 'pending':
|
||||
|
||||
if (status === 'pending') {
|
||||
if (isWaitingConfirmation) {
|
||||
label = t('message.tools.pending', 'Awaiting Approval')
|
||||
icon = <LoadingIcon style={{ marginLeft: 6, color: 'var(--status-color-warning)' }} />
|
||||
break
|
||||
case 'invoking':
|
||||
} else if (isExecuting) {
|
||||
label = t('message.tools.invoking')
|
||||
icon = <LoadingIcon style={{ marginLeft: 6 }} />
|
||||
break
|
||||
case 'cancelled':
|
||||
label = t('message.tools.cancelled')
|
||||
icon = <X size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||
break
|
||||
case 'done':
|
||||
if (hasError) {
|
||||
label = t('message.tools.error')
|
||||
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||
} else {
|
||||
label = t('message.tools.completed')
|
||||
icon = <Check size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||
}
|
||||
break
|
||||
default:
|
||||
label = ''
|
||||
icon = null
|
||||
}
|
||||
} else if (status === 'cancelled') {
|
||||
label = t('message.tools.cancelled')
|
||||
icon = <X size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||
} else if (status === 'done') {
|
||||
if (hasError) {
|
||||
label = t('message.tools.error')
|
||||
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||
} else {
|
||||
label = t('message.tools.completed')
|
||||
icon = <Check size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||
}
|
||||
} else if (status === 'error') {
|
||||
label = t('message.tools.error')
|
||||
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||
}
|
||||
|
||||
return (
|
||||
<StatusIndicator status={status} hasError={hasError}>
|
||||
{label}
|
||||
@ -248,7 +265,6 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
params: toolResponse.arguments,
|
||||
response: toolResponse.response
|
||||
}
|
||||
|
||||
items.push({
|
||||
key: id,
|
||||
label: (
|
||||
@ -283,7 +299,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
<Maximize size={14} />
|
||||
</ActionButton>
|
||||
</Tooltip>
|
||||
{!isPending && !isInvoking && (
|
||||
{!isPending && (
|
||||
<Tooltip title={t('common.copy')} mouseEnterDelay={0.5}>
|
||||
<ActionButton
|
||||
className="message-action-button"
|
||||
@ -301,7 +317,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
</MessageTitleLabel>
|
||||
),
|
||||
children:
|
||||
isDone && result ? (
|
||||
(isDone || isError) && result ? (
|
||||
<ToolResponseContainer
|
||||
style={{
|
||||
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
||||
@ -370,7 +386,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
}
|
||||
}}>
|
||||
<ToolContainer>
|
||||
<ToolContentWrapper className={status}>
|
||||
<ToolContentWrapper className={isPending ? 'pending' : status}>
|
||||
<CollapseContainer
|
||||
ghost
|
||||
activeKey={activeKeys}
|
||||
@ -383,14 +399,16 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
<ExpandIcon $isActive={isActive} size={18} color="var(--color-text-3)" strokeWidth={1.5} />
|
||||
)}
|
||||
/>
|
||||
{(isPending || isInvoking) && (
|
||||
{isPending && (
|
||||
<ActionsBar>
|
||||
<ActionLabel>
|
||||
{isPending ? t('settings.mcp.tools.autoApprove.tooltip.confirm') : t('message.tools.invoking')}
|
||||
{isWaitingConfirmation
|
||||
? t('settings.mcp.tools.autoApprove.tooltip.confirm')
|
||||
: t('message.tools.invoking')}
|
||||
</ActionLabel>
|
||||
|
||||
<ActionButtonsGroup>
|
||||
{isPending && (
|
||||
{isWaitingConfirmation && (
|
||||
<Button
|
||||
color="danger"
|
||||
variant="filled"
|
||||
@ -402,7 +420,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
)}
|
||||
{isInvoking && toolResponse?.id ? (
|
||||
{isExecuting && toolResponse?.id ? (
|
||||
<Button
|
||||
size="small"
|
||||
color="danger"
|
||||
@ -416,29 +434,31 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
||||
{t('chat.input.pause')}
|
||||
</Button>
|
||||
) : (
|
||||
<StyledDropdownButton
|
||||
size="small"
|
||||
type="primary"
|
||||
icon={<ChevronDown size={14} />}
|
||||
onClick={() => {
|
||||
handleConfirmTool()
|
||||
}}
|
||||
menu={{
|
||||
items: [
|
||||
{
|
||||
key: 'autoApprove',
|
||||
label: t('settings.mcp.tools.autoApprove.label'),
|
||||
onClick: () => {
|
||||
handleAutoApprove()
|
||||
isWaitingConfirmation && (
|
||||
<StyledDropdownButton
|
||||
size="small"
|
||||
type="primary"
|
||||
icon={<ChevronDown size={14} />}
|
||||
onClick={() => {
|
||||
handleConfirmTool()
|
||||
}}
|
||||
menu={{
|
||||
items: [
|
||||
{
|
||||
key: 'autoApprove',
|
||||
label: t('settings.mcp.tools.autoApprove.label'),
|
||||
onClick: () => {
|
||||
handleAutoApprove()
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}}>
|
||||
<CirclePlay size={15} className="lucide-custom" />
|
||||
<CountdownText>
|
||||
{t('settings.mcp.tools.run', 'Run')} ({countdown}s)
|
||||
</CountdownText>
|
||||
</StyledDropdownButton>
|
||||
]
|
||||
}}>
|
||||
<CirclePlay size={15} className="lucide-custom" />
|
||||
<CountdownText>
|
||||
{t('settings.mcp.tools.run', 'Run')} ({countdown}s)
|
||||
</CountdownText>
|
||||
</StyledDropdownButton>
|
||||
)
|
||||
)}
|
||||
</ActionButtonsGroup>
|
||||
</ActionsBar>
|
||||
@ -542,8 +562,7 @@ const ToolContentWrapper = styled.div`
|
||||
border: 1px solid var(--color-border);
|
||||
}
|
||||
|
||||
&.pending,
|
||||
&.invoking {
|
||||
&.pending {
|
||||
background-color: var(--color-background-soft);
|
||||
.ant-collapse {
|
||||
border: none;
|
||||
@ -663,6 +682,8 @@ const StatusIndicator = styled.span<{ status: string; hasError?: boolean }>`
|
||||
return 'var(--status-color-error)'
|
||||
case 'done':
|
||||
return props.hasError ? 'var(--status-color-error)' : 'var(--status-color-success)'
|
||||
case 'error':
|
||||
return 'var(--status-color-error)'
|
||||
default:
|
||||
return 'var(--color-text)'
|
||||
}
|
||||
|
||||
@ -99,6 +99,8 @@ export async function fetchChatCompletion({
|
||||
const provider = AI.getActualProvider()
|
||||
|
||||
const mcpTools: MCPTool[] = []
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
|
||||
if (isPromptToolUse(assistant) || isSupportedToolUse(assistant)) {
|
||||
mcpTools.push(...(await fetchMcpTools(assistant)))
|
||||
}
|
||||
@ -137,7 +139,6 @@ export async function fetchChatCompletion({
|
||||
}
|
||||
|
||||
// --- Call AI Completions ---
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
await AI.completions(modelId, aiSdkParams, {
|
||||
...middlewareConfig,
|
||||
assistant,
|
||||
|
||||
@ -178,12 +178,7 @@ class WebSearchService {
|
||||
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
|
||||
}
|
||||
|
||||
// try {
|
||||
return await webSearchEngine.search(formattedQuery, websearch, httpOptions)
|
||||
// } catch (error) {
|
||||
// console.error('Search failed:', error)
|
||||
// throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -69,16 +69,6 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
})
|
||||
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
||||
},
|
||||
// onBlockCreated: async () => {
|
||||
// if (blockManager.hasInitialPlaceholder) {
|
||||
// return
|
||||
// }
|
||||
// console.log('onBlockCreated')
|
||||
// const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
|
||||
// status: MessageBlockStatus.PROCESSING
|
||||
// })
|
||||
// await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
||||
// },
|
||||
|
||||
onError: async (error: AISDKError) => {
|
||||
logger.debug('onError', error)
|
||||
|
||||
@ -51,28 +51,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
}
|
||||
},
|
||||
|
||||
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
|
||||
// 根据 toolResponse.id 查找对应的块ID
|
||||
const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||
|
||||
if (targetBlockId && toolResponse.status === 'invoking') {
|
||||
const changes = {
|
||||
status: MessageBlockStatus.PROCESSING,
|
||||
metadata: { rawMcpToolResponse: toolResponse }
|
||||
}
|
||||
blockManager.smartBlockUpdate(targetBlockId, changes, MessageBlockType.TOOL)
|
||||
} else if (!targetBlockId) {
|
||||
logger.warn(
|
||||
`[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`,
|
||||
Array.from(toolCallIdToBlockIdMap.entries())
|
||||
)
|
||||
} else {
|
||||
logger.warn(
|
||||
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||
)
|
||||
}
|
||||
},
|
||||
|
||||
onToolCallComplete: (toolResponse: MCPToolResponse) => {
|
||||
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||
toolCallIdToBlockIdMap.delete(toolResponse.id)
|
||||
@ -104,7 +82,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
stack: null
|
||||
}
|
||||
}
|
||||
|
||||
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
|
||||
|
||||
// Handle citation block creation for web search results
|
||||
@ -132,7 +109,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
citationBlockId = citationBlock.id
|
||||
blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
}
|
||||
// TODO: 处理 memory 引用
|
||||
} else {
|
||||
logger.warn(
|
||||
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||
|
||||
@ -458,18 +458,18 @@ describe('streamCallback Integration Tests', () => {
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||
responses: [
|
||||
{
|
||||
id: 'tool-call-1',
|
||||
tool: mockTool,
|
||||
arguments: { testArg: 'value' },
|
||||
status: 'invoking' as const,
|
||||
response: ''
|
||||
}
|
||||
]
|
||||
},
|
||||
// {
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [
|
||||
// {
|
||||
// id: 'tool-call-1',
|
||||
// tool: mockTool,
|
||||
// arguments: { testArg: 'value' },
|
||||
// status: 'invoking' as const,
|
||||
// response: ''
|
||||
// }
|
||||
// ]
|
||||
// },
|
||||
{
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [
|
||||
@ -611,18 +611,18 @@ describe('streamCallback Integration Tests', () => {
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
||||
responses: [
|
||||
{
|
||||
id: 'tool-call-1',
|
||||
tool: mockCalculatorTool,
|
||||
arguments: { operation: 'add', a: 1, b: 2 },
|
||||
status: 'invoking' as const,
|
||||
response: ''
|
||||
}
|
||||
]
|
||||
},
|
||||
// {
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [
|
||||
// {
|
||||
// id: 'tool-call-1',
|
||||
// tool: mockCalculatorTool,
|
||||
// arguments: { operation: 'add', a: 1, b: 2 },
|
||||
// status: 'invoking' as const,
|
||||
// response: ''
|
||||
// }
|
||||
// ]
|
||||
// },
|
||||
{
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [
|
||||
|
||||
@ -826,9 +826,7 @@ export function mcpToolCallResponseToAwsBedrockMessage(
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否启用工具使用
|
||||
* 1. 如果模型支持函数调用,则启用工具使用
|
||||
* 2. 如果工具使用模式为 prompt,则启用工具使用
|
||||
* 是否启用工具使用(function call)
|
||||
* @param assistant
|
||||
* @returns 是否启用工具使用
|
||||
*/
|
||||
|
||||
51
yarn.lock
51
yarn.lock
@ -179,15 +179,15 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/google@npm:^2.0.7":
|
||||
version: 2.0.7
|
||||
resolution: "@ai-sdk/google@npm:2.0.7"
|
||||
"@ai-sdk/google@npm:^2.0.13":
|
||||
version: 2.0.13
|
||||
resolution: "@ai-sdk/google@npm:2.0.13"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.4"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.8"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/bde4c95a2a167355cda18de9d5b273d562d2a724f650ca69016daa8df2766280487e143cf0cdd96f6654c255d587a680c6a937b280eb734ca2c35d6f9b9e943c
|
||||
checksum: 10c0/a05210de11d7ab41d49bcd0330c37f4116441b149d8ccc9b6bc5eaa12ea42bae82364dc2cd09502734b15115071f07395525806ea4998930b285b1ce74102186
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
@ -227,15 +227,15 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/openai@npm:^2.0.19":
|
||||
version: 2.0.19
|
||||
resolution: "@ai-sdk/openai@npm:2.0.19"
|
||||
"@ai-sdk/openai@npm:^2.0.26":
|
||||
version: 2.0.26
|
||||
resolution: "@ai-sdk/openai@npm:2.0.26"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.5"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.8"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/04db695669d783a810b80283e0cd48f6e7654667fd76ca2d35c7cffae6fdd68fb0473118e4e097ef1352f4432dd7c15c07f873d712b940c72495e5839b0ede98
|
||||
checksum: 10c0/b8cb01c0c38525c38901f41f1693cd15589932a2aceddea14bed30f44719532a5e74615fb0e974eff1a0513048ac204c27456ff8829a9c811d1461cc635c9cc5
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
@ -267,20 +267,6 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider-utils@npm:3.0.5":
|
||||
version: 3.0.5
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.5"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@standard-schema/spec": "npm:^1.0.0"
|
||||
eventsource-parser: "npm:^3.0.3"
|
||||
zod-to-json-schema: "npm:^3.24.1"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/4057810b320bda149a178dc1bfc9cdd592ca88b736c3c22bd0c1f8111c75ef69beec4a523f363e5d0d120348b876942fd66c0bb4965864da4c12c5cfddee15a3
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider-utils@npm:3.0.7":
|
||||
version: 3.0.7
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.7"
|
||||
@ -294,6 +280,19 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider-utils@npm:3.0.8":
|
||||
version: 3.0.8
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.8"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@standard-schema/spec": "npm:^1.0.0"
|
||||
eventsource-parser: "npm:^3.0.5"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/f466657c886cbb9f7ecbcd2dd1abc51a88af9d3f1cff030f7e97e70a4790a99f3338ad886e9c0dccf04dacdcc84522c7d57119b9a4e8e1d84f2dae9c893c397e
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0":
|
||||
version: 2.0.0
|
||||
resolution: "@ai-sdk/provider@npm:2.0.0"
|
||||
@ -2249,8 +2248,8 @@ __metadata:
|
||||
"@ai-sdk/anthropic": "npm:^2.0.5"
|
||||
"@ai-sdk/azure": "npm:^2.0.16"
|
||||
"@ai-sdk/deepseek": "npm:^1.0.9"
|
||||
"@ai-sdk/google": "npm:^2.0.7"
|
||||
"@ai-sdk/openai": "npm:^2.0.19"
|
||||
"@ai-sdk/google": "npm:^2.0.13"
|
||||
"@ai-sdk/openai": "npm:^2.0.26"
|
||||
"@ai-sdk/openai-compatible": "npm:^1.0.9"
|
||||
"@ai-sdk/provider": "npm:^2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:^3.0.4"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user