mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-05 04:19:02 +08:00
refactor(ToolCall): refactor:mcp-tool-state-management (#10028)
This commit is contained in:
parent
7df1060370
commit
f6ffd574bf
@ -1,103 +0,0 @@
|
|||||||
/**
|
|
||||||
* Hub Provider 使用示例
|
|
||||||
*
|
|
||||||
* 演示如何使用简化后的Hub Provider功能来路由到多个底层provider
|
|
||||||
*/
|
|
||||||
|
|
||||||
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
|
|
||||||
|
|
||||||
async function demonstrateHubProvider() {
|
|
||||||
try {
|
|
||||||
// 1. 初始化底层providers
|
|
||||||
console.log('📦 初始化底层providers...')
|
|
||||||
|
|
||||||
initializeProvider('openai', {
|
|
||||||
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
|
|
||||||
})
|
|
||||||
|
|
||||||
initializeProvider('anthropic', {
|
|
||||||
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
|
|
||||||
})
|
|
||||||
|
|
||||||
// 2. 创建Hub Provider(自动包含所有已初始化的providers)
|
|
||||||
console.log('🌐 创建Hub Provider...')
|
|
||||||
|
|
||||||
const aihubmixProvider = createHubProvider({
|
|
||||||
hubId: 'aihubmix',
|
|
||||||
debug: true
|
|
||||||
})
|
|
||||||
|
|
||||||
// 3. 注册Hub Provider
|
|
||||||
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
|
|
||||||
|
|
||||||
console.log('✅ Hub Provider "aihubmix" 注册成功')
|
|
||||||
|
|
||||||
// 4. 使用Hub Provider访问不同的模型
|
|
||||||
console.log('\n🚀 使用Hub模型...')
|
|
||||||
|
|
||||||
// 通过Hub路由到OpenAI
|
|
||||||
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
|
||||||
console.log('✓ OpenAI模型已获取:', openaiModel.modelId)
|
|
||||||
|
|
||||||
// 通过Hub路由到Anthropic
|
|
||||||
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
|
||||||
console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)
|
|
||||||
|
|
||||||
// 5. 演示错误处理
|
|
||||||
console.log('\n❌ 演示错误处理...')
|
|
||||||
|
|
||||||
try {
|
|
||||||
// 尝试访问未初始化的provider
|
|
||||||
providerRegistry.languageModel('aihubmix:google:gemini-pro')
|
|
||||||
} catch (error) {
|
|
||||||
console.log('预期错误:', error.message)
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
// 尝试使用错误的模型ID格式
|
|
||||||
providerRegistry.languageModel('aihubmix:invalid-format')
|
|
||||||
} catch (error) {
|
|
||||||
console.log('预期错误:', error.message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 6. 多个Hub Provider示例
|
|
||||||
console.log('\n🔄 创建多个Hub Provider...')
|
|
||||||
|
|
||||||
const localHubProvider = createHubProvider({
|
|
||||||
hubId: 'local-ai'
|
|
||||||
})
|
|
||||||
|
|
||||||
providerRegistry.registerProvider('local-ai', localHubProvider)
|
|
||||||
console.log('✅ Hub Provider "local-ai" 注册成功')
|
|
||||||
|
|
||||||
console.log('\n🎉 Hub Provider演示完成!')
|
|
||||||
} catch (error) {
|
|
||||||
console.error('💥 演示过程中发生错误:', error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 演示简化的使用方式
|
|
||||||
function simplifiedUsageExample() {
|
|
||||||
console.log('\n📝 简化使用示例:')
|
|
||||||
console.log(`
|
|
||||||
// 1. 初始化providers
|
|
||||||
initializeProvider('openai', { apiKey: 'sk-xxx' })
|
|
||||||
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
|
|
||||||
|
|
||||||
// 2. 创建并注册Hub Provider
|
|
||||||
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
|
|
||||||
providerRegistry.registerProvider('aihubmix', hubProvider)
|
|
||||||
|
|
||||||
// 3. 直接使用
|
|
||||||
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
|
||||||
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 运行演示
|
|
||||||
if (require.main === module) {
|
|
||||||
demonstrateHubProvider()
|
|
||||||
simplifiedUsageExample()
|
|
||||||
}
|
|
||||||
|
|
||||||
export { demonstrateHubProvider, simplifiedUsageExample }
|
|
||||||
@ -1,167 +0,0 @@
|
|||||||
/**
|
|
||||||
* Image Generation Example
|
|
||||||
* 演示如何使用 aiCore 的文生图功能
|
|
||||||
*/
|
|
||||||
|
|
||||||
import { createExecutor, generateImage } from '../src/index'
|
|
||||||
|
|
||||||
async function main() {
|
|
||||||
// 方式1: 使用执行器实例
|
|
||||||
console.log('📸 创建 OpenAI 图像生成执行器...')
|
|
||||||
const executor = createExecutor('openai', {
|
|
||||||
apiKey: process.env.OPENAI_API_KEY!
|
|
||||||
})
|
|
||||||
|
|
||||||
try {
|
|
||||||
console.log('🎨 使用执行器生成图像...')
|
|
||||||
const result1 = await executor.generateImage('dall-e-3', {
|
|
||||||
prompt: 'A futuristic cityscape at sunset with flying cars',
|
|
||||||
size: '1024x1024',
|
|
||||||
n: 1
|
|
||||||
})
|
|
||||||
|
|
||||||
console.log('✅ 图像生成成功!')
|
|
||||||
console.log('📊 结果:', {
|
|
||||||
imagesCount: result1.images.length,
|
|
||||||
mediaType: result1.image.mediaType,
|
|
||||||
hasBase64: !!result1.image.base64,
|
|
||||||
providerMetadata: result1.providerMetadata
|
|
||||||
})
|
|
||||||
} catch (error) {
|
|
||||||
console.error('❌ 执行器生成失败:', error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 方式2: 使用直接调用 API
|
|
||||||
try {
|
|
||||||
console.log('🎨 使用直接 API 生成图像...')
|
|
||||||
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
|
|
||||||
prompt: 'A magical forest with glowing mushrooms and fairy lights',
|
|
||||||
aspectRatio: '16:9',
|
|
||||||
providerOptions: {
|
|
||||||
openai: {
|
|
||||||
quality: 'hd',
|
|
||||||
style: 'vivid'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
console.log('✅ 直接 API 生成成功!')
|
|
||||||
console.log('📊 结果:', {
|
|
||||||
imagesCount: result2.images.length,
|
|
||||||
mediaType: result2.image.mediaType,
|
|
||||||
hasBase64: !!result2.image.base64
|
|
||||||
})
|
|
||||||
} catch (error) {
|
|
||||||
console.error('❌ 直接 API 生成失败:', error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 方式3: 支持其他提供商 (Google Imagen)
|
|
||||||
if (process.env.GOOGLE_API_KEY) {
|
|
||||||
try {
|
|
||||||
console.log('🎨 使用 Google Imagen 生成图像...')
|
|
||||||
const googleExecutor = createExecutor('google', {
|
|
||||||
apiKey: process.env.GOOGLE_API_KEY!
|
|
||||||
})
|
|
||||||
|
|
||||||
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
|
||||||
prompt: 'A serene mountain lake at dawn with mist rising from the water',
|
|
||||||
aspectRatio: '1:1'
|
|
||||||
})
|
|
||||||
|
|
||||||
console.log('✅ Google Imagen 生成成功!')
|
|
||||||
console.log('📊 结果:', {
|
|
||||||
imagesCount: result3.images.length,
|
|
||||||
mediaType: result3.image.mediaType,
|
|
||||||
hasBase64: !!result3.image.base64
|
|
||||||
})
|
|
||||||
} catch (error) {
|
|
||||||
console.error('❌ Google Imagen 生成失败:', error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 方式4: 支持插件系统
|
|
||||||
const pluginExample = async () => {
|
|
||||||
console.log('🔌 演示插件系统...')
|
|
||||||
|
|
||||||
// 创建一个示例插件,用于修改提示词
|
|
||||||
const promptEnhancerPlugin = {
|
|
||||||
name: 'prompt-enhancer',
|
|
||||||
transformParams: async (params: any) => {
|
|
||||||
console.log('🔧 插件: 增强提示词...')
|
|
||||||
return {
|
|
||||||
...params,
|
|
||||||
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
|
|
||||||
}
|
|
||||||
},
|
|
||||||
transformResult: async (result: any) => {
|
|
||||||
console.log('🔧 插件: 处理结果...')
|
|
||||||
return {
|
|
||||||
...result,
|
|
||||||
enhanced: true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const executorWithPlugin = createExecutor(
|
|
||||||
'openai',
|
|
||||||
{
|
|
||||||
apiKey: process.env.OPENAI_API_KEY!
|
|
||||||
},
|
|
||||||
[promptEnhancerPlugin]
|
|
||||||
)
|
|
||||||
|
|
||||||
try {
|
|
||||||
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
|
|
||||||
prompt: 'A cute robot playing in a garden'
|
|
||||||
})
|
|
||||||
|
|
||||||
console.log('✅ 插件系统生成成功!')
|
|
||||||
console.log('📊 结果:', {
|
|
||||||
imagesCount: result4.images.length,
|
|
||||||
enhanced: (result4 as any).enhanced,
|
|
||||||
mediaType: result4.image.mediaType
|
|
||||||
})
|
|
||||||
} catch (error) {
|
|
||||||
console.error('❌ 插件系统生成失败:', error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
await pluginExample()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 错误处理演示
|
|
||||||
async function errorHandlingExample() {
|
|
||||||
console.log('⚠️ 演示错误处理...')
|
|
||||||
|
|
||||||
try {
|
|
||||||
const executor = createExecutor('openai', {
|
|
||||||
apiKey: 'invalid-key'
|
|
||||||
})
|
|
||||||
|
|
||||||
await executor.generateImage('dall-e-3', {
|
|
||||||
prompt: 'Test image'
|
|
||||||
})
|
|
||||||
} catch (error: any) {
|
|
||||||
console.log('✅ 成功捕获错误:', error.constructor.name)
|
|
||||||
console.log('📋 错误信息:', error.message)
|
|
||||||
console.log('🏷️ 提供商ID:', error.providerId)
|
|
||||||
console.log('🏷️ 模型ID:', error.modelId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 运行示例
|
|
||||||
if (require.main === module) {
|
|
||||||
main()
|
|
||||||
.then(() => {
|
|
||||||
console.log('🎉 所有示例完成!')
|
|
||||||
return errorHandlingExample()
|
|
||||||
})
|
|
||||||
.then(() => {
|
|
||||||
console.log('🎯 示例程序结束')
|
|
||||||
process.exit(0)
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
console.error('💥 程序执行出错:', error)
|
|
||||||
process.exit(1)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@cherrystudio/ai-core",
|
"name": "@cherrystudio/ai-core",
|
||||||
"version": "1.0.0-alpha.12",
|
"version": "1.0.0-alpha.13",
|
||||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"module": "dist/index.mjs",
|
"module": "dist/index.mjs",
|
||||||
@ -39,8 +39,8 @@
|
|||||||
"@ai-sdk/anthropic": "^2.0.5",
|
"@ai-sdk/anthropic": "^2.0.5",
|
||||||
"@ai-sdk/azure": "^2.0.16",
|
"@ai-sdk/azure": "^2.0.16",
|
||||||
"@ai-sdk/deepseek": "^1.0.9",
|
"@ai-sdk/deepseek": "^1.0.9",
|
||||||
"@ai-sdk/google": "^2.0.7",
|
"@ai-sdk/google": "^2.0.13",
|
||||||
"@ai-sdk/openai": "^2.0.19",
|
"@ai-sdk/openai": "^2.0.26",
|
||||||
"@ai-sdk/openai-compatible": "^1.0.9",
|
"@ai-sdk/openai-compatible": "^1.0.9",
|
||||||
"@ai-sdk/provider": "^2.0.0",
|
"@ai-sdk/provider": "^2.0.0",
|
||||||
"@ai-sdk/provider-utils": "^3.0.4",
|
"@ai-sdk/provider-utils": "^3.0.4",
|
||||||
|
|||||||
@ -84,7 +84,6 @@ export class ModelResolver {
|
|||||||
*/
|
*/
|
||||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||||
console.log('fullModelId', fullModelId)
|
|
||||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -27,10 +27,20 @@ export class StreamEventManager {
|
|||||||
/**
|
/**
|
||||||
* 发送步骤完成事件
|
* 发送步骤完成事件
|
||||||
*/
|
*/
|
||||||
sendStepFinishEvent(controller: StreamController, chunk: any): void {
|
sendStepFinishEvent(
|
||||||
|
controller: StreamController,
|
||||||
|
chunk: any,
|
||||||
|
context: AiRequestContext,
|
||||||
|
finishReason: string = 'stop'
|
||||||
|
): void {
|
||||||
|
// 累加当前步骤的 usage
|
||||||
|
if (chunk.usage && context.accumulatedUsage) {
|
||||||
|
this.accumulateUsage(context.accumulatedUsage, chunk.usage)
|
||||||
|
}
|
||||||
|
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'finish-step',
|
type: 'finish-step',
|
||||||
finishReason: 'stop',
|
finishReason,
|
||||||
response: chunk.response,
|
response: chunk.response,
|
||||||
usage: chunk.usage,
|
usage: chunk.usage,
|
||||||
providerMetadata: chunk.providerMetadata
|
providerMetadata: chunk.providerMetadata
|
||||||
@ -43,28 +53,32 @@ export class StreamEventManager {
|
|||||||
async handleRecursiveCall(
|
async handleRecursiveCall(
|
||||||
controller: StreamController,
|
controller: StreamController,
|
||||||
recursiveParams: any,
|
recursiveParams: any,
|
||||||
context: AiRequestContext,
|
context: AiRequestContext
|
||||||
stepId: string
|
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
try {
|
// try {
|
||||||
console.log('[MCP Prompt] Starting recursive call after tool execution...')
|
// 重置工具执行状态,准备处理新的步骤
|
||||||
|
context.hasExecutedToolsInCurrentStep = false
|
||||||
|
|
||||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||||
|
|
||||||
if (recursiveResult && recursiveResult.fullStream) {
|
if (recursiveResult && recursiveResult.fullStream) {
|
||||||
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
|
await this.pipeRecursiveStream(controller, recursiveResult.fullStream, context)
|
||||||
} else {
|
} else {
|
||||||
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
this.handleRecursiveCallError(controller, error, stepId)
|
|
||||||
}
|
}
|
||||||
|
// } catch (error) {
|
||||||
|
// this.handleRecursiveCallError(controller, error, stepId)
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 将递归流的数据传递到当前流
|
* 将递归流的数据传递到当前流
|
||||||
*/
|
*/
|
||||||
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
|
private async pipeRecursiveStream(
|
||||||
|
controller: StreamController,
|
||||||
|
recursiveStream: ReadableStream,
|
||||||
|
context?: AiRequestContext
|
||||||
|
): Promise<void> {
|
||||||
const reader = recursiveStream.getReader()
|
const reader = recursiveStream.getReader()
|
||||||
try {
|
try {
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -73,9 +87,16 @@ export class StreamEventManager {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
if (value.type === 'finish') {
|
if (value.type === 'finish') {
|
||||||
// 迭代的流不发finish
|
// 迭代的流不发finish,但需要累加其 usage
|
||||||
|
if (value.usage && context?.accumulatedUsage) {
|
||||||
|
this.accumulateUsage(context.accumulatedUsage, value.usage)
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
// 对于 finish-step 类型,累加其 usage
|
||||||
|
if (value.type === 'finish-step' && value.usage && context?.accumulatedUsage) {
|
||||||
|
this.accumulateUsage(context.accumulatedUsage, value.usage)
|
||||||
|
}
|
||||||
// 将递归流的数据传递到当前流
|
// 将递归流的数据传递到当前流
|
||||||
controller.enqueue(value)
|
controller.enqueue(value)
|
||||||
}
|
}
|
||||||
@ -87,25 +108,25 @@ export class StreamEventManager {
|
|||||||
/**
|
/**
|
||||||
* 处理递归调用错误
|
* 处理递归调用错误
|
||||||
*/
|
*/
|
||||||
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
|
// private handleRecursiveCallError(controller: StreamController, error: unknown): void {
|
||||||
console.error('[MCP Prompt] Recursive call failed:', error)
|
// console.error('[MCP Prompt] Recursive call failed:', error)
|
||||||
|
|
||||||
// 使用 AI SDK 标准错误格式,但不中断流
|
// // 使用 AI SDK 标准错误格式,但不中断流
|
||||||
controller.enqueue({
|
// controller.enqueue({
|
||||||
type: 'error',
|
// type: 'error',
|
||||||
error: {
|
// error: {
|
||||||
message: error instanceof Error ? error.message : String(error),
|
// message: error instanceof Error ? error.message : String(error),
|
||||||
name: error instanceof Error ? error.name : 'RecursiveCallError'
|
// name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||||
}
|
// }
|
||||||
})
|
// })
|
||||||
|
|
||||||
// 继续发送文本增量,保持流的连续性
|
// // // 继续发送文本增量,保持流的连续性
|
||||||
controller.enqueue({
|
// // controller.enqueue({
|
||||||
type: 'text-delta',
|
// // type: 'text-delta',
|
||||||
id: stepId,
|
// // id: stepId,
|
||||||
text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
// // text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||||
})
|
// // })
|
||||||
}
|
// }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 构建递归调用的参数
|
* 构建递归调用的参数
|
||||||
@ -136,4 +157,18 @@ export class StreamEventManager {
|
|||||||
|
|
||||||
return recursiveParams
|
return recursiveParams
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 累加 usage 数据
|
||||||
|
*/
|
||||||
|
private accumulateUsage(target: any, source: any): void {
|
||||||
|
if (!target || !source) return
|
||||||
|
|
||||||
|
// 累加各种 token 类型
|
||||||
|
target.inputTokens = (target.inputTokens || 0) + (source.inputTokens || 0)
|
||||||
|
target.outputTokens = (target.outputTokens || 0) + (source.outputTokens || 0)
|
||||||
|
target.totalTokens = (target.totalTokens || 0) + (source.totalTokens || 0)
|
||||||
|
target.reasoningTokens = (target.reasoningTokens || 0) + (source.reasoningTokens || 0)
|
||||||
|
target.cachedInputTokens = (target.cachedInputTokens || 0) + (source.cachedInputTokens || 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
* 负责工具的执行、结果格式化和相关事件发送
|
* 负责工具的执行、结果格式化和相关事件发送
|
||||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||||
*/
|
*/
|
||||||
import type { ToolSet } from 'ai'
|
import type { ToolSet, TypedToolError } from 'ai'
|
||||||
|
|
||||||
import type { ToolUseResult } from './type'
|
import type { ToolUseResult } from './type'
|
||||||
|
|
||||||
@ -38,7 +38,6 @@ export class ToolExecutor {
|
|||||||
controller: StreamController
|
controller: StreamController
|
||||||
): Promise<ExecutedResult[]> {
|
): Promise<ExecutedResult[]> {
|
||||||
const executedResults: ExecutedResult[] = []
|
const executedResults: ExecutedResult[] = []
|
||||||
|
|
||||||
for (const toolUse of toolUses) {
|
for (const toolUse of toolUses) {
|
||||||
try {
|
try {
|
||||||
const tool = tools[toolUse.toolName]
|
const tool = tools[toolUse.toolName]
|
||||||
@ -46,17 +45,12 @@ export class ToolExecutor {
|
|||||||
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送工具调用开始事件
|
|
||||||
this.sendToolStartEvents(controller, toolUse)
|
|
||||||
|
|
||||||
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
|
||||||
|
|
||||||
// 发送 tool-call 事件
|
// 发送 tool-call 事件
|
||||||
controller.enqueue({
|
controller.enqueue({
|
||||||
type: 'tool-call',
|
type: 'tool-call',
|
||||||
toolCallId: toolUse.id,
|
toolCallId: toolUse.id,
|
||||||
toolName: toolUse.toolName,
|
toolName: toolUse.toolName,
|
||||||
input: tool.inputSchema
|
input: toolUse.arguments
|
||||||
})
|
})
|
||||||
|
|
||||||
const result = await tool.execute(toolUse.arguments, {
|
const result = await tool.execute(toolUse.arguments, {
|
||||||
@ -111,45 +105,46 @@ export class ToolExecutor {
|
|||||||
/**
|
/**
|
||||||
* 发送工具调用开始相关事件
|
* 发送工具调用开始相关事件
|
||||||
*/
|
*/
|
||||||
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
// private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||||
// 发送 tool-input-start 事件
|
// // 发送 tool-input-start 事件
|
||||||
controller.enqueue({
|
// controller.enqueue({
|
||||||
type: 'tool-input-start',
|
// type: 'tool-input-start',
|
||||||
id: toolUse.id,
|
// id: toolUse.id,
|
||||||
toolName: toolUse.toolName
|
// toolName: toolUse.toolName
|
||||||
})
|
// })
|
||||||
}
|
// }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 处理工具执行错误
|
* 处理工具执行错误
|
||||||
*/
|
*/
|
||||||
private handleToolError(
|
private handleToolError<T extends ToolSet>(
|
||||||
toolUse: ToolUseResult,
|
toolUse: ToolUseResult,
|
||||||
error: unknown,
|
error: unknown,
|
||||||
controller: StreamController
|
controller: StreamController
|
||||||
// _tools: ToolSet
|
|
||||||
): ExecutedResult {
|
): ExecutedResult {
|
||||||
// 使用 AI SDK 标准错误格式
|
// 使用 AI SDK 标准错误格式
|
||||||
// const toolError: TypedToolError<typeof _tools> = {
|
const toolError: TypedToolError<T> = {
|
||||||
// type: 'tool-error',
|
type: 'tool-error',
|
||||||
// toolCallId: toolUse.id,
|
toolCallId: toolUse.id,
|
||||||
// toolName: toolUse.toolName,
|
toolName: toolUse.toolName,
|
||||||
// input: toolUse.arguments,
|
input: toolUse.arguments,
|
||||||
// error: error instanceof Error ? error.message : String(error)
|
error
|
||||||
// }
|
}
|
||||||
|
|
||||||
// controller.enqueue(toolError)
|
controller.enqueue(toolError)
|
||||||
|
|
||||||
// 发送标准错误事件
|
// 发送标准错误事件
|
||||||
controller.enqueue({
|
// controller.enqueue({
|
||||||
type: 'error',
|
// type: 'tool-error',
|
||||||
error: error instanceof Error ? error.message : String(error)
|
// toolCallId: toolUse.id,
|
||||||
})
|
// error: error instanceof Error ? error.message : String(error),
|
||||||
|
// input: toolUse.arguments
|
||||||
|
// })
|
||||||
|
|
||||||
return {
|
return {
|
||||||
toolCallId: toolUse.id,
|
toolCallId: toolUse.id,
|
||||||
toolName: toolUse.toolName,
|
toolName: toolUse.toolName,
|
||||||
result: error instanceof Error ? error.message : String(error),
|
result: error,
|
||||||
isError: true
|
isError: true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,9 +8,19 @@ import type { TextStreamPart, ToolSet } from 'ai'
|
|||||||
import { definePlugin } from '../../index'
|
import { definePlugin } from '../../index'
|
||||||
import type { AiRequestContext } from '../../types'
|
import type { AiRequestContext } from '../../types'
|
||||||
import { StreamEventManager } from './StreamEventManager'
|
import { StreamEventManager } from './StreamEventManager'
|
||||||
|
import { type TagConfig, TagExtractor } from './tagExtraction'
|
||||||
import { ToolExecutor } from './ToolExecutor'
|
import { ToolExecutor } from './ToolExecutor'
|
||||||
import { PromptToolUseConfig, ToolUseResult } from './type'
|
import { PromptToolUseConfig, ToolUseResult } from './type'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 工具使用标签配置
|
||||||
|
*/
|
||||||
|
const TOOL_USE_TAG_CONFIG: TagConfig = {
|
||||||
|
openingTag: '<tool_use>',
|
||||||
|
closingTag: '</tool_use>',
|
||||||
|
separator: '\n'
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 默认系统提示符模板(提取自 Cherry Studio)
|
* 默认系统提示符模板(提取自 Cherry Studio)
|
||||||
*/
|
*/
|
||||||
@ -249,13 +259,11 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
context.mcpTools = params.tools
|
context.mcpTools = params.tools
|
||||||
console.log('tools stored in context', params.tools)
|
|
||||||
|
|
||||||
// 构建系统提示符
|
// 构建系统提示符
|
||||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||||
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
|
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
|
||||||
let systemMessage: string | null = systemPrompt
|
let systemMessage: string | null = systemPrompt
|
||||||
console.log('config.context', context)
|
|
||||||
if (config.createSystemMessage) {
|
if (config.createSystemMessage) {
|
||||||
// 🎯 如果用户提供了自定义处理函数,使用它
|
// 🎯 如果用户提供了自定义处理函数,使用它
|
||||||
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||||
@ -268,20 +276,40 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
|||||||
tools: undefined
|
tools: undefined
|
||||||
}
|
}
|
||||||
context.originalParams = transformedParams
|
context.originalParams = transformedParams
|
||||||
console.log('transformedParams', transformedParams)
|
|
||||||
return transformedParams
|
return transformedParams
|
||||||
},
|
},
|
||||||
transformStream: (_: any, context: AiRequestContext) => () => {
|
transformStream: (_: any, context: AiRequestContext) => () => {
|
||||||
let textBuffer = ''
|
let textBuffer = ''
|
||||||
let stepId = ''
|
// let stepId = ''
|
||||||
|
|
||||||
if (!context.mcpTools) {
|
if (!context.mcpTools) {
|
||||||
throw new Error('No tools available')
|
throw new Error('No tools available')
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建工具执行器和流事件管理器
|
// 从 context 中获取或初始化 usage 累加器
|
||||||
|
if (!context.accumulatedUsage) {
|
||||||
|
context.accumulatedUsage = {
|
||||||
|
inputTokens: 0,
|
||||||
|
outputTokens: 0,
|
||||||
|
totalTokens: 0,
|
||||||
|
reasoningTokens: 0,
|
||||||
|
cachedInputTokens: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建工具执行器、流事件管理器和标签提取器
|
||||||
const toolExecutor = new ToolExecutor()
|
const toolExecutor = new ToolExecutor()
|
||||||
const streamEventManager = new StreamEventManager()
|
const streamEventManager = new StreamEventManager()
|
||||||
|
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||||
|
|
||||||
|
// 在context中初始化工具执行状态,避免递归调用时状态丢失
|
||||||
|
if (!context.hasExecutedToolsInCurrentStep) {
|
||||||
|
context.hasExecutedToolsInCurrentStep = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用于hold text-start事件,直到确认有非工具标签内容
|
||||||
|
let pendingTextStart: TextStreamPart<TOOLS> | null = null
|
||||||
|
let hasStartedText = false
|
||||||
|
|
||||||
type TOOLS = NonNullable<typeof context.mcpTools>
|
type TOOLS = NonNullable<typeof context.mcpTools>
|
||||||
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
|
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
|
||||||
@ -289,83 +317,106 @@ export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
|||||||
chunk: TextStreamPart<TOOLS>,
|
chunk: TextStreamPart<TOOLS>,
|
||||||
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
|
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
|
||||||
) {
|
) {
|
||||||
// 收集文本内容
|
// Hold住text-start事件,直到确认有非工具标签内容
|
||||||
if (chunk.type === 'text-delta') {
|
if ((chunk as any).type === 'text-start') {
|
||||||
textBuffer += chunk.text || ''
|
pendingTextStart = chunk
|
||||||
stepId = chunk.id || ''
|
|
||||||
controller.enqueue(chunk)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
|
// text-delta阶段:收集文本内容并过滤工具标签
|
||||||
const tools = context.mcpTools
|
if (chunk.type === 'text-delta') {
|
||||||
if (!tools || Object.keys(tools).length === 0) {
|
textBuffer += chunk.text || ''
|
||||||
controller.enqueue(chunk)
|
// stepId = chunk.id || ''
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析工具调用
|
// 使用TagExtractor过滤工具标签,只传递非标签内容到UI层
|
||||||
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
|
const extractionResults = tagExtractor.processText(chunk.text || '')
|
||||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
|
||||||
|
|
||||||
// 如果没有有效的工具调用,直接传递原始事件
|
for (const result of extractionResults) {
|
||||||
if (validToolUses.length === 0) {
|
// 只传递非标签内容到UI层
|
||||||
controller.enqueue(chunk)
|
if (!result.isTagContent && result.content) {
|
||||||
return
|
// 如果还没有发送text-start且有pending的text-start,先发送它
|
||||||
}
|
if (!hasStartedText && pendingTextStart) {
|
||||||
|
controller.enqueue(pendingTextStart)
|
||||||
if (chunk.type === 'text-end') {
|
hasStartedText = true
|
||||||
controller.enqueue({
|
pendingTextStart = null
|
||||||
type: 'text-end',
|
|
||||||
id: stepId,
|
|
||||||
providerMetadata: {
|
|
||||||
text: {
|
|
||||||
value: parsedContent
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
|
||||||
return
|
const filteredChunk = {
|
||||||
|
...chunk,
|
||||||
|
text: result.content
|
||||||
|
}
|
||||||
|
controller.enqueue(filteredChunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chunk.type === 'text-end') {
|
||||||
|
// 只有当已经发送了text-start时才发送text-end
|
||||||
|
if (hasStartedText) {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chunk.type === 'finish-step') {
|
||||||
|
// 统一在finish-step阶段检查并执行工具调用
|
||||||
|
const tools = context.mcpTools
|
||||||
|
if (tools && Object.keys(tools).length > 0 && !context.hasExecutedToolsInCurrentStep) {
|
||||||
|
// 解析完整的textBuffer来检测工具调用
|
||||||
|
const { results: parsedTools } = parseToolUse(textBuffer, tools)
|
||||||
|
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||||
|
|
||||||
|
if (validToolUses.length > 0) {
|
||||||
|
context.hasExecutedToolsInCurrentStep = true
|
||||||
|
|
||||||
|
// 执行工具调用(不需要手动发送 start-step,外部流已经处理)
|
||||||
|
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
||||||
|
|
||||||
|
// 发送步骤完成事件,使用 tool-calls 作为 finishReason
|
||||||
|
streamEventManager.sendStepFinishEvent(controller, chunk, context, 'tool-calls')
|
||||||
|
|
||||||
|
// 处理递归调用
|
||||||
|
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
||||||
|
const recursiveParams = streamEventManager.buildRecursiveParams(
|
||||||
|
context,
|
||||||
|
textBuffer,
|
||||||
|
toolResultsText,
|
||||||
|
tools
|
||||||
|
)
|
||||||
|
|
||||||
|
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.enqueue({
|
// 如果没有执行工具调用,直接传递原始finish-step事件
|
||||||
...chunk,
|
controller.enqueue(chunk)
|
||||||
finishReason: 'tool-calls'
|
|
||||||
})
|
|
||||||
|
|
||||||
// 发送步骤开始事件
|
|
||||||
streamEventManager.sendStepStartEvent(controller)
|
|
||||||
|
|
||||||
// 执行工具调用
|
|
||||||
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
|
||||||
|
|
||||||
// 发送步骤完成事件
|
|
||||||
streamEventManager.sendStepFinishEvent(controller, chunk)
|
|
||||||
|
|
||||||
// 处理递归调用
|
|
||||||
if (validToolUses.length > 0) {
|
|
||||||
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
|
||||||
const recursiveParams = streamEventManager.buildRecursiveParams(
|
|
||||||
context,
|
|
||||||
textBuffer,
|
|
||||||
toolResultsText,
|
|
||||||
tools
|
|
||||||
)
|
|
||||||
|
|
||||||
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 清理状态
|
// 清理状态
|
||||||
textBuffer = ''
|
textBuffer = ''
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 对于其他类型的事件,直接传递
|
// 处理 finish 类型,使用累加后的 totalUsage
|
||||||
controller.enqueue(chunk)
|
if (chunk.type === 'finish') {
|
||||||
|
controller.enqueue({
|
||||||
|
...chunk,
|
||||||
|
totalUsage: context.accumulatedUsage
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于其他类型的事件,直接传递(不包括text-start,已在上面处理)
|
||||||
|
if ((chunk as any).type !== 'text-start') {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
flush() {
|
flush() {
|
||||||
// 流结束时的清理工作
|
// 清理pending状态
|
||||||
console.log('[MCP Prompt] Stream ended, cleaning up...')
|
pendingTextStart = null
|
||||||
|
hasStartedText = false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,7 +27,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
|||||||
case 'openai': {
|
case 'openai': {
|
||||||
if (config.openai) {
|
if (config.openai) {
|
||||||
if (!params.tools) params.tools = {}
|
if (!params.tools) params.tools = {}
|
||||||
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
|
params.tools.web_search = openai.tools.webSearch(config.openai)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
// 核心类型和接口
|
// 核心类型和接口
|
||||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||||
|
import type { ImageModelV2 } from '@ai-sdk/provider'
|
||||||
|
import type { LanguageModel } from 'ai'
|
||||||
|
|
||||||
import type { ProviderId } from '../providers'
|
import type { ProviderId } from '../providers'
|
||||||
import type { AiPlugin, AiRequestContext } from './types'
|
import type { AiPlugin, AiRequestContext } from './types'
|
||||||
|
|
||||||
@ -9,16 +12,16 @@ export { PluginManager } from './manager'
|
|||||||
// 工具函数
|
// 工具函数
|
||||||
export function createContext<T extends ProviderId>(
|
export function createContext<T extends ProviderId>(
|
||||||
providerId: T,
|
providerId: T,
|
||||||
modelId: string,
|
model: LanguageModel | ImageModelV2,
|
||||||
originalParams: any
|
originalParams: any
|
||||||
): AiRequestContext {
|
): AiRequestContext {
|
||||||
return {
|
return {
|
||||||
providerId,
|
providerId,
|
||||||
modelId,
|
model,
|
||||||
originalParams,
|
originalParams,
|
||||||
metadata: {},
|
metadata: {},
|
||||||
startTime: Date.now(),
|
startTime: Date.now(),
|
||||||
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
requestId: `${providerId}-${typeof model === 'string' ? model : model?.modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||||
// 占位
|
// 占位
|
||||||
recursiveCall: () => Promise.resolve(null)
|
recursiveCall: () => Promise.resolve(null)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,7 +14,7 @@ export type RecursiveCallFn = (newParams: any) => Promise<any>
|
|||||||
*/
|
*/
|
||||||
export interface AiRequestContext {
|
export interface AiRequestContext {
|
||||||
providerId: ProviderId
|
providerId: ProviderId
|
||||||
modelId: string
|
model: LanguageModel | ImageModelV2
|
||||||
originalParams: any
|
originalParams: any
|
||||||
metadata: Record<string, any>
|
metadata: Record<string, any>
|
||||||
startTime: number
|
startTime: number
|
||||||
|
|||||||
@ -83,7 +83,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 使用正确的createContext创建请求上下文
|
// 使用正确的createContext创建请求上下文
|
||||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||||
|
|
||||||
// 🔥 为上下文添加递归调用能力
|
// 🔥 为上下文添加递归调用能力
|
||||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||||
@ -159,7 +159,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 使用正确的createContext创建请求上下文
|
// 使用正确的createContext创建请求上下文
|
||||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||||
|
|
||||||
// 🔥 为上下文添加递归调用能力
|
// 🔥 为上下文添加递归调用能力
|
||||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||||
@ -235,7 +235,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建请求上下文
|
// 创建请求上下文
|
||||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
const context = _context ? _context : createContext(this.providerId, model, params)
|
||||||
|
|
||||||
// 🔥 为上下文添加递归调用能力
|
// 🔥 为上下文添加递归调用能力
|
||||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||||
|
|||||||
@ -152,12 +152,14 @@ export class AiSdkToChunkAdapter {
|
|||||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||||
// break
|
// break
|
||||||
case 'tool-call':
|
case 'tool-call':
|
||||||
// 原始的工具调用(未被中间件处理)
|
|
||||||
this.toolCallHandler.handleToolCall(chunk)
|
this.toolCallHandler.handleToolCall(chunk)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
case 'tool-error':
|
||||||
|
this.toolCallHandler.handleToolError(chunk)
|
||||||
|
break
|
||||||
|
|
||||||
case 'tool-result':
|
case 'tool-result':
|
||||||
// 原始的工具调用结果(未被中间件处理)
|
|
||||||
this.toolCallHandler.handleToolResult(chunk)
|
this.toolCallHandler.handleToolResult(chunk)
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -167,7 +169,6 @@ export class AiSdkToChunkAdapter {
|
|||||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||||
// })
|
// })
|
||||||
// break
|
// break
|
||||||
// TODO: 需要区分接口开始和步骤开始
|
|
||||||
// case 'start-step':
|
// case 'start-step':
|
||||||
// this.onChunk({
|
// this.onChunk({
|
||||||
// type: ChunkType.BLOCK_CREATED
|
// type: ChunkType.BLOCK_CREATED
|
||||||
@ -305,8 +306,6 @@ export class AiSdkToChunkAdapter {
|
|||||||
break
|
break
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// 其他类型的 chunk 可以忽略或记录日志
|
|
||||||
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,34 +8,61 @@ import { loggerService } from '@logger'
|
|||||||
import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
|
import { processKnowledgeReferences } from '@renderer/services/KnowledgeService'
|
||||||
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
|
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
|
||||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||||
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
|
import type { ToolSet, TypedToolCall, TypedToolError, TypedToolResult } from 'ai'
|
||||||
// import type {
|
|
||||||
// AnthropicSearchOutput,
|
|
||||||
// WebSearchPluginConfig
|
|
||||||
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('ToolCallChunkHandler')
|
const logger = loggerService.withContext('ToolCallChunkHandler')
|
||||||
|
|
||||||
|
export type ToolcallsMap = {
|
||||||
|
toolCallId: string
|
||||||
|
toolName: string
|
||||||
|
args: any
|
||||||
|
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||||
|
tool: BaseTool
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
* 工具调用处理器类
|
* 工具调用处理器类
|
||||||
*/
|
*/
|
||||||
export class ToolCallChunkHandler {
|
export class ToolCallChunkHandler {
|
||||||
// private onChunk: (chunk: Chunk) => void
|
private static globalActiveToolCalls = new Map<string, ToolcallsMap>()
|
||||||
private activeToolCalls = new Map<
|
|
||||||
string,
|
private activeToolCalls = ToolCallChunkHandler.globalActiveToolCalls
|
||||||
{
|
|
||||||
toolCallId: string
|
|
||||||
toolName: string
|
|
||||||
args: any
|
|
||||||
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
|
||||||
tool: BaseTool
|
|
||||||
}
|
|
||||||
>()
|
|
||||||
constructor(
|
constructor(
|
||||||
private onChunk: (chunk: Chunk) => void,
|
private onChunk: (chunk: Chunk) => void,
|
||||||
private mcpTools: MCPTool[]
|
private mcpTools: MCPTool[]
|
||||||
) {}
|
) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 内部静态方法:添加活跃工具调用的核心逻辑
|
||||||
|
*/
|
||||||
|
private static addActiveToolCallImpl(toolCallId: string, map: ToolcallsMap): boolean {
|
||||||
|
if (!ToolCallChunkHandler.globalActiveToolCalls.has(toolCallId)) {
|
||||||
|
ToolCallChunkHandler.globalActiveToolCalls.set(toolCallId, map)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 实例方法:添加活跃工具调用
|
||||||
|
*/
|
||||||
|
private addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
|
||||||
|
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取全局活跃的工具调用
|
||||||
|
*/
|
||||||
|
public static getActiveToolCalls() {
|
||||||
|
return ToolCallChunkHandler.globalActiveToolCalls
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 静态方法:添加活跃工具调用(外部访问)
|
||||||
|
*/
|
||||||
|
public static addActiveToolCall(toolCallId: string, map: ToolcallsMap): boolean {
|
||||||
|
return ToolCallChunkHandler.addActiveToolCallImpl(toolCallId, map)
|
||||||
|
}
|
||||||
|
|
||||||
// /**
|
// /**
|
||||||
// * 设置 onChunk 回调
|
// * 设置 onChunk 回调
|
||||||
// */
|
// */
|
||||||
@ -43,103 +70,103 @@ export class ToolCallChunkHandler {
|
|||||||
// this.onChunk = callback
|
// this.onChunk = callback
|
||||||
// }
|
// }
|
||||||
|
|
||||||
handleToolCallCreated(
|
// handleToolCallCreated(
|
||||||
chunk:
|
// chunk:
|
||||||
| {
|
// | {
|
||||||
type: 'tool-input-start'
|
// type: 'tool-input-start'
|
||||||
id: string
|
// id: string
|
||||||
toolName: string
|
// toolName: string
|
||||||
providerMetadata?: ProviderMetadata
|
// providerMetadata?: ProviderMetadata
|
||||||
providerExecuted?: boolean
|
// providerExecuted?: boolean
|
||||||
}
|
// }
|
||||||
| {
|
// | {
|
||||||
type: 'tool-input-end'
|
// type: 'tool-input-end'
|
||||||
id: string
|
// id: string
|
||||||
providerMetadata?: ProviderMetadata
|
// providerMetadata?: ProviderMetadata
|
||||||
}
|
// }
|
||||||
| {
|
// | {
|
||||||
type: 'tool-input-delta'
|
// type: 'tool-input-delta'
|
||||||
id: string
|
// id: string
|
||||||
delta: string
|
// delta: string
|
||||||
providerMetadata?: ProviderMetadata
|
// providerMetadata?: ProviderMetadata
|
||||||
}
|
// }
|
||||||
): void {
|
// ): void {
|
||||||
switch (chunk.type) {
|
// switch (chunk.type) {
|
||||||
case 'tool-input-start': {
|
// case 'tool-input-start': {
|
||||||
// 能拿到说明是mcpTool
|
// // 能拿到说明是mcpTool
|
||||||
// if (this.activeToolCalls.get(chunk.id)) return
|
// // if (this.activeToolCalls.get(chunk.id)) return
|
||||||
|
|
||||||
const tool: BaseTool | MCPTool = {
|
// const tool: BaseTool | MCPTool = {
|
||||||
id: chunk.id,
|
// id: chunk.id,
|
||||||
name: chunk.toolName,
|
// name: chunk.toolName,
|
||||||
description: chunk.toolName,
|
// description: chunk.toolName,
|
||||||
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
// type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
||||||
}
|
// }
|
||||||
this.activeToolCalls.set(chunk.id, {
|
// this.activeToolCalls.set(chunk.id, {
|
||||||
toolCallId: chunk.id,
|
// toolCallId: chunk.id,
|
||||||
toolName: chunk.toolName,
|
// toolName: chunk.toolName,
|
||||||
args: '',
|
// args: '',
|
||||||
tool
|
// tool
|
||||||
})
|
// })
|
||||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
// const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||||
id: chunk.id,
|
// id: chunk.id,
|
||||||
tool: tool,
|
// tool: tool,
|
||||||
arguments: {},
|
// arguments: {},
|
||||||
status: 'pending',
|
// status: 'pending',
|
||||||
toolCallId: chunk.id
|
// toolCallId: chunk.id
|
||||||
}
|
// }
|
||||||
this.onChunk({
|
// this.onChunk({
|
||||||
type: ChunkType.MCP_TOOL_PENDING,
|
// type: ChunkType.MCP_TOOL_PENDING,
|
||||||
responses: [toolResponse]
|
// responses: [toolResponse]
|
||||||
})
|
// })
|
||||||
break
|
// break
|
||||||
}
|
// }
|
||||||
case 'tool-input-delta': {
|
// case 'tool-input-delta': {
|
||||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
// const toolCall = this.activeToolCalls.get(chunk.id)
|
||||||
if (!toolCall) {
|
// if (!toolCall) {
|
||||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
toolCall.args += chunk.delta
|
// toolCall.args += chunk.delta
|
||||||
break
|
// break
|
||||||
}
|
// }
|
||||||
case 'tool-input-end': {
|
// case 'tool-input-end': {
|
||||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
// const toolCall = this.activeToolCalls.get(chunk.id)
|
||||||
this.activeToolCalls.delete(chunk.id)
|
// this.activeToolCalls.delete(chunk.id)
|
||||||
if (!toolCall) {
|
// if (!toolCall) {
|
||||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
// logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
// const toolResponse: ToolCallResponse = {
|
// // const toolResponse: ToolCallResponse = {
|
||||||
// id: toolCall.toolCallId,
|
// // id: toolCall.toolCallId,
|
||||||
// tool: toolCall.tool,
|
// // tool: toolCall.tool,
|
||||||
// arguments: toolCall.args,
|
// // arguments: toolCall.args,
|
||||||
// status: 'pending',
|
// // status: 'pending',
|
||||||
// toolCallId: toolCall.toolCallId
|
// // toolCallId: toolCall.toolCallId
|
||||||
// }
|
// // }
|
||||||
// logger.debug('toolResponse', toolResponse)
|
// // logger.debug('toolResponse', toolResponse)
|
||||||
// this.onChunk({
|
// // this.onChunk({
|
||||||
// type: ChunkType.MCP_TOOL_PENDING,
|
// // type: ChunkType.MCP_TOOL_PENDING,
|
||||||
// responses: [toolResponse]
|
// // responses: [toolResponse]
|
||||||
// })
|
// // })
|
||||||
break
|
// break
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
// if (!toolCall) {
|
// // if (!toolCall) {
|
||||||
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
// // Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||||
// return
|
// // return
|
||||||
// }
|
// // }
|
||||||
// this.onChunk({
|
// // this.onChunk({
|
||||||
// type: ChunkType.MCP_TOOL_CREATED,
|
// // type: ChunkType.MCP_TOOL_CREATED,
|
||||||
// tool_calls: [
|
// // tool_calls: [
|
||||||
// {
|
// // {
|
||||||
// id: chunk.id,
|
// // id: chunk.id,
|
||||||
// name: chunk.toolName,
|
// // name: chunk.toolName,
|
||||||
// status: 'pending'
|
// // status: 'pending'
|
||||||
// }
|
// // }
|
||||||
// ]
|
// // ]
|
||||||
// })
|
// // })
|
||||||
}
|
// }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 处理工具调用事件
|
* 处理工具调用事件
|
||||||
@ -158,7 +185,6 @@ export class ToolCallChunkHandler {
|
|||||||
|
|
||||||
let tool: BaseTool
|
let tool: BaseTool
|
||||||
let mcpTool: MCPTool | undefined
|
let mcpTool: MCPTool | undefined
|
||||||
|
|
||||||
// 根据 providerExecuted 标志区分处理逻辑
|
// 根据 providerExecuted 标志区分处理逻辑
|
||||||
if (providerExecuted) {
|
if (providerExecuted) {
|
||||||
// 如果是 Provider 执行的工具(如 web_search)
|
// 如果是 Provider 执行的工具(如 web_search)
|
||||||
@ -196,27 +222,25 @@ export class ToolCallChunkHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 记录活跃的工具调用
|
this.addActiveToolCall(toolCallId, {
|
||||||
this.activeToolCalls.set(toolCallId, {
|
|
||||||
toolCallId,
|
toolCallId,
|
||||||
toolName,
|
toolName,
|
||||||
args,
|
args,
|
||||||
tool
|
tool
|
||||||
})
|
})
|
||||||
|
|
||||||
// 创建 MCPToolResponse 格式
|
// 创建 MCPToolResponse 格式
|
||||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||||
id: toolCallId,
|
id: toolCallId,
|
||||||
tool: tool,
|
tool: tool,
|
||||||
arguments: args,
|
arguments: args,
|
||||||
status: 'pending',
|
status: 'pending', // 统一使用 pending 状态
|
||||||
toolCallId: toolCallId
|
toolCallId: toolCallId
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调用 onChunk
|
// 调用 onChunk
|
||||||
if (this.onChunk) {
|
if (this.onChunk) {
|
||||||
this.onChunk({
|
this.onChunk({
|
||||||
type: ChunkType.MCP_TOOL_PENDING,
|
type: ChunkType.MCP_TOOL_PENDING, // 统一发送 pending 状态
|
||||||
responses: [toolResponse]
|
responses: [toolResponse]
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -276,4 +300,33 @@ export class ToolCallChunkHandler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
handleToolError(
|
||||||
|
chunk: {
|
||||||
|
type: 'tool-error'
|
||||||
|
} & TypedToolError<ToolSet>
|
||||||
|
): void {
|
||||||
|
const { toolCallId, error, input } = chunk
|
||||||
|
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||||
|
if (!toolCallInfo) {
|
||||||
|
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||||
|
id: toolCallId,
|
||||||
|
tool: toolCallInfo.tool,
|
||||||
|
arguments: input,
|
||||||
|
status: 'error',
|
||||||
|
response: error,
|
||||||
|
toolCallId: toolCallId
|
||||||
|
}
|
||||||
|
this.activeToolCalls.delete(toolCallId)
|
||||||
|
if (this.onChunk) {
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||||
|
responses: [toolResponse]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const addActiveToolCall = ToolCallChunkHandler.addActiveToolCall.bind(ToolCallChunkHandler)
|
||||||
|
|||||||
@ -265,15 +265,15 @@ export default class ModernAiProvider {
|
|||||||
params: StreamTextParams,
|
params: StreamTextParams,
|
||||||
config: ModernAiProviderConfig
|
config: ModernAiProviderConfig
|
||||||
): Promise<CompletionsResult> {
|
): Promise<CompletionsResult> {
|
||||||
const modelId = this.model!.id
|
// const modelId = this.model!.id
|
||||||
logger.info('Starting modernCompletions', {
|
// logger.info('Starting modernCompletions', {
|
||||||
modelId,
|
// modelId,
|
||||||
providerId: this.config!.providerId,
|
// providerId: this.config!.providerId,
|
||||||
topicId: config.topicId,
|
// topicId: config.topicId,
|
||||||
hasOnChunk: !!config.onChunk,
|
// hasOnChunk: !!config.onChunk,
|
||||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
// hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||||
toolCount: params.tools ? Object.keys(params.tools).length : 0
|
// toolCount: params.tools ? Object.keys(params.tools).length : 0
|
||||||
})
|
// })
|
||||||
|
|
||||||
// 根据条件构建插件数组
|
// 根据条件构建插件数组
|
||||||
const plugins = buildPlugins(config)
|
const plugins = buildPlugins(config)
|
||||||
|
|||||||
@ -35,7 +35,7 @@ export function buildPlugins(
|
|||||||
plugins.push(webSearchPlugin())
|
plugins.push(webSearchPlugin())
|
||||||
}
|
}
|
||||||
// 2. 支持工具调用时添加搜索插件
|
// 2. 支持工具调用时添加搜索插件
|
||||||
if (middlewareConfig.isSupportedToolUse) {
|
if (middlewareConfig.isSupportedToolUse || middlewareConfig.isPromptToolUse) {
|
||||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant, middlewareConfig.topicId || ''))
|
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant, middlewareConfig.topicId || ''))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,12 +45,13 @@ export function buildPlugins(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4. 启用Prompt工具调用时添加工具插件
|
// 4. 启用Prompt工具调用时添加工具插件
|
||||||
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
if (middlewareConfig.isPromptToolUse) {
|
||||||
plugins.push(
|
plugins.push(
|
||||||
createPromptToolUsePlugin({
|
createPromptToolUsePlugin({
|
||||||
enabled: true,
|
enabled: true,
|
||||||
createSystemMessage: (systemPrompt, params, context) => {
|
createSystemMessage: (systemPrompt, params, context) => {
|
||||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
const modelId = typeof context.model === 'string' ? context.model : context.model.modelId
|
||||||
|
if (modelId.includes('o1-mini') || modelId.includes('o1-preview')) {
|
||||||
if (context.isRecursiveCall) {
|
if (context.isRecursiveCall) {
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,7 +19,8 @@ import store from '@renderer/store'
|
|||||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||||
import type { Assistant } from '@renderer/types'
|
import type { Assistant } from '@renderer/types'
|
||||||
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
|
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
|
||||||
import type { ModelMessage } from 'ai'
|
import type { LanguageModel, ModelMessage } from 'ai'
|
||||||
|
import { generateText } from 'ai'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
|
|
||||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||||
@ -76,9 +77,7 @@ async function analyzeSearchIntent(
|
|||||||
shouldKnowledgeSearch?: boolean
|
shouldKnowledgeSearch?: boolean
|
||||||
shouldMemorySearch?: boolean
|
shouldMemorySearch?: boolean
|
||||||
lastAnswer?: ModelMessage
|
lastAnswer?: ModelMessage
|
||||||
context: AiRequestContext & {
|
context: AiRequestContext
|
||||||
isAnalyzing?: boolean
|
|
||||||
}
|
|
||||||
topicId: string
|
topicId: string
|
||||||
}
|
}
|
||||||
): Promise<ExtractResults | undefined> {
|
): Promise<ExtractResults | undefined> {
|
||||||
@ -122,9 +121,7 @@ async function analyzeSearchIntent(
|
|||||||
logger.error('Provider not found or missing API key')
|
logger.error('Provider not found or missing API key')
|
||||||
return getFallbackResult()
|
return getFallbackResult()
|
||||||
}
|
}
|
||||||
// console.log('formattedPrompt', schema)
|
|
||||||
try {
|
try {
|
||||||
context.isAnalyzing = true
|
|
||||||
logger.info('Starting intent analysis generateText call', {
|
logger.info('Starting intent analysis generateText call', {
|
||||||
modelId: model.id,
|
modelId: model.id,
|
||||||
topicId: options.topicId,
|
topicId: options.topicId,
|
||||||
@ -133,18 +130,16 @@ async function analyzeSearchIntent(
|
|||||||
hasKnowledgeSearch: needKnowledgeExtract
|
hasKnowledgeSearch: needKnowledgeExtract
|
||||||
})
|
})
|
||||||
|
|
||||||
const { text: result } = await context.executor
|
const { text: result } = await generateText({
|
||||||
.generateText(model.id, {
|
model: context.model as LanguageModel,
|
||||||
prompt: formattedPrompt
|
prompt: formattedPrompt
|
||||||
})
|
}).finally(() => {
|
||||||
.finally(() => {
|
logger.info('Intent analysis generateText call completed', {
|
||||||
context.isAnalyzing = false
|
modelId: model.id,
|
||||||
logger.info('Intent analysis generateText call completed', {
|
topicId: options.topicId,
|
||||||
modelId: model.id,
|
requestId: context.requestId
|
||||||
topicId: options.topicId,
|
|
||||||
requestId: context.requestId
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
})
|
||||||
const parsedResult = extractInfoFromXML(result)
|
const parsedResult = extractInfoFromXML(result)
|
||||||
logger.debug('Intent analysis result', { parsedResult })
|
logger.debug('Intent analysis result', { parsedResult })
|
||||||
|
|
||||||
@ -183,7 +178,6 @@ async function storeConversationMemory(
|
|||||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||||
|
|
||||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||||
// console.log('Memory storage is disabled')
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,25 +239,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
|||||||
// 存储意图分析结果
|
// 存储意图分析结果
|
||||||
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
|
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
|
||||||
const userMessages: { [requestId: string]: ModelMessage } = {}
|
const userMessages: { [requestId: string]: ModelMessage } = {}
|
||||||
let currentContext: AiRequestContext | null = null
|
|
||||||
|
|
||||||
return definePlugin({
|
return definePlugin({
|
||||||
name: 'search-orchestration',
|
name: 'search-orchestration',
|
||||||
enforce: 'pre', // 确保在其他插件之前执行
|
enforce: 'pre', // 确保在其他插件之前执行
|
||||||
|
|
||||||
configureContext: (context: AiRequestContext) => {
|
|
||||||
if (currentContext) {
|
|
||||||
context.isAnalyzing = currentContext.isAnalyzing
|
|
||||||
}
|
|
||||||
currentContext = context
|
|
||||||
},
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 🔍 Step 1: 意图识别阶段
|
* 🔍 Step 1: 意图识别阶段
|
||||||
*/
|
*/
|
||||||
onRequestStart: async (context: AiRequestContext) => {
|
onRequestStart: async (context: AiRequestContext) => {
|
||||||
if (context.isAnalyzing) return
|
|
||||||
|
|
||||||
// 没开启任何搜索则不进行意图分析
|
// 没开启任何搜索则不进行意图分析
|
||||||
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
|
if (!(assistant.webSearchProviderId || assistant.knowledge_bases?.length || assistant.enableMemory)) return
|
||||||
|
|
||||||
@ -315,7 +298,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
|||||||
* 🔧 Step 2: 工具配置阶段
|
* 🔧 Step 2: 工具配置阶段
|
||||||
*/
|
*/
|
||||||
transformParams: async (params: any, context: AiRequestContext) => {
|
transformParams: async (params: any, context: AiRequestContext) => {
|
||||||
if (context.isAnalyzing) return params
|
|
||||||
// logger.info('🔧 Configuring tools based on intent...', context.requestId)
|
// logger.info('🔧 Configuring tools based on intent...', context.requestId)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -409,7 +391,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
|||||||
// context.isAnalyzing = false
|
// context.isAnalyzing = false
|
||||||
// logger.info('context.isAnalyzing', context, result)
|
// logger.info('context.isAnalyzing', context, result)
|
||||||
// logger.info('💾 Starting memory storage...', context.requestId)
|
// logger.info('💾 Starting memory storage...', context.requestId)
|
||||||
if (context.isAnalyzing) return
|
|
||||||
try {
|
try {
|
||||||
const messages = context.originalParams.messages
|
const messages = context.originalParams.messages
|
||||||
|
|
||||||
|
|||||||
@ -19,8 +19,6 @@ export const memorySearchTool = () => {
|
|||||||
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
|
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
|
||||||
}),
|
}),
|
||||||
execute: async ({ query, limit = 5 }) => {
|
execute: async ({ query, limit = 5 }) => {
|
||||||
// console.log('🧠 [memorySearchTool] Searching memories:', { query, limit })
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||||
if (!globalMemoryEnabled) {
|
if (!globalMemoryEnabled) {
|
||||||
@ -29,7 +27,6 @@ export const memorySearchTool = () => {
|
|||||||
|
|
||||||
const memoryConfig = selectMemoryConfig(store.getState())
|
const memoryConfig = selectMemoryConfig(store.getState())
|
||||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||||
// console.warn('Memory search skipped: embedding or LLM model not configured')
|
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,12 +37,10 @@ export const memorySearchTool = () => {
|
|||||||
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
|
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
|
||||||
|
|
||||||
if (relevantMemories?.length > 0) {
|
if (relevantMemories?.length > 0) {
|
||||||
// console.log('🧠 [memorySearchTool] Found memories:', relevantMemories.length)
|
|
||||||
return relevantMemories
|
return relevantMemories
|
||||||
}
|
}
|
||||||
return []
|
return []
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// console.error('🧠 [memorySearchTool] Error:', error)
|
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -84,8 +79,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
|||||||
.optional()
|
.optional()
|
||||||
}) satisfies z.ZodSchema<MemorySearchWithExtractionInput>,
|
}) satisfies z.ZodSchema<MemorySearchWithExtractionInput>,
|
||||||
execute: async ({ userMessage }) => {
|
execute: async ({ userMessage }) => {
|
||||||
// console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer })
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||||
@ -97,7 +90,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
|||||||
|
|
||||||
const memoryConfig = selectMemoryConfig(store.getState())
|
const memoryConfig = selectMemoryConfig(store.getState())
|
||||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||||
// console.warn('Memory search skipped: embedding or LLM model not configured')
|
|
||||||
return {
|
return {
|
||||||
extractedKeywords: 'Memory models not configured',
|
extractedKeywords: 'Memory models not configured',
|
||||||
searchResults: []
|
searchResults: []
|
||||||
@ -125,7 +117,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (relevantMemories?.length > 0) {
|
if (relevantMemories?.length > 0) {
|
||||||
// console.log('🧠 [memorySearchToolWithExtraction] Found memories:', relevantMemories.length)
|
|
||||||
return {
|
return {
|
||||||
extractedKeywords: content,
|
extractedKeywords: content,
|
||||||
searchResults: relevantMemories
|
searchResults: relevantMemories
|
||||||
@ -137,7 +128,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
|||||||
searchResults: []
|
searchResults: []
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// console.error('🧠 [memorySearchToolWithExtraction] Error:', error)
|
|
||||||
return {
|
return {
|
||||||
extractedKeywords: 'Search failed',
|
extractedKeywords: 'Search failed',
|
||||||
searchResults: []
|
searchResults: []
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
|
||||||
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
|
||||||
import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools'
|
import { callMCPTool, getMcpServerByTool, isToolAutoApproved } from '@renderer/utils/mcp-tools'
|
||||||
import { requestToolConfirmation } from '@renderer/utils/userConfirmation'
|
import { requestToolConfirmation } from '@renderer/utils/userConfirmation'
|
||||||
import { type Tool, type ToolSet } from 'ai'
|
import { type Tool, type ToolSet } from 'ai'
|
||||||
@ -33,8 +31,36 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
|
|||||||
tools[mcpTool.name] = tool({
|
tools[mcpTool.name] = tool({
|
||||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||||
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||||
execute: async (params, { toolCallId, experimental_context }) => {
|
execute: async (params, { toolCallId }) => {
|
||||||
const { onChunk } = experimental_context as { onChunk: (chunk: Chunk) => void }
|
// 检查是否启用自动批准
|
||||||
|
const server = getMcpServerByTool(mcpTool)
|
||||||
|
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
|
||||||
|
|
||||||
|
let confirmed = true
|
||||||
|
|
||||||
|
if (!isAutoApproveEnabled) {
|
||||||
|
// 请求用户确认
|
||||||
|
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
|
||||||
|
confirmed = await requestToolConfirmation(toolCallId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!confirmed) {
|
||||||
|
// 用户拒绝执行工具
|
||||||
|
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: `User declined to execute tool "${mcpTool.name}".`
|
||||||
|
}
|
||||||
|
],
|
||||||
|
isError: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户确认或自动批准,执行工具
|
||||||
|
logger.debug(`Executing tool: ${mcpTool.name}`)
|
||||||
|
|
||||||
// 创建适配的 MCPToolResponse 对象
|
// 创建适配的 MCPToolResponse 对象
|
||||||
const toolResponse: MCPToolResponse = {
|
const toolResponse: MCPToolResponse = {
|
||||||
id: toolCallId,
|
id: toolCallId,
|
||||||
@ -44,53 +70,18 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
|
|||||||
toolCallId
|
toolCallId
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
const result = await callMCPTool(toolResponse)
|
||||||
// 检查是否启用自动批准
|
|
||||||
const server = getMcpServerByTool(mcpTool)
|
|
||||||
const isAutoApproveEnabled = isToolAutoApproved(mcpTool, server)
|
|
||||||
|
|
||||||
let confirmed = true
|
// 返回结果,AI SDK 会处理序列化
|
||||||
if (!isAutoApproveEnabled) {
|
if (result.isError) {
|
||||||
// 请求用户确认
|
// throw new Error(result.content?.[0]?.text || 'Tool execution failed')
|
||||||
logger.debug(`Requesting user confirmation for tool: ${mcpTool.name}`)
|
return Promise.reject(result)
|
||||||
confirmed = await requestToolConfirmation(toolResponse.id)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!confirmed) {
|
|
||||||
// 用户拒绝执行工具
|
|
||||||
logger.debug(`User cancelled tool execution: ${mcpTool.name}`)
|
|
||||||
return {
|
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: 'text',
|
|
||||||
text: `User declined to execute tool "${mcpTool.name}".`
|
|
||||||
}
|
|
||||||
],
|
|
||||||
isError: false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 用户确认或自动批准,执行工具
|
|
||||||
toolResponse.status = 'invoking'
|
|
||||||
logger.debug(`Executing tool: ${mcpTool.name}`)
|
|
||||||
|
|
||||||
onChunk({
|
|
||||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
|
||||||
responses: [toolResponse]
|
|
||||||
})
|
|
||||||
|
|
||||||
const result = await callMCPTool(toolResponse)
|
|
||||||
|
|
||||||
// 返回结果,AI SDK 会处理序列化
|
|
||||||
if (result.isError) {
|
|
||||||
throw new Error(result.content?.[0]?.text || 'Tool execution failed')
|
|
||||||
}
|
|
||||||
// 返回工具执行结果
|
|
||||||
return result
|
|
||||||
} catch (error) {
|
|
||||||
logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
|
|
||||||
throw error
|
|
||||||
}
|
}
|
||||||
|
// 返回工具执行结果
|
||||||
|
return result
|
||||||
|
// } catch (error) {
|
||||||
|
// logger.error(`MCP Tool execution failed: ${mcpTool.name}`, { error })
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -105,7 +105,6 @@ const ThinkingBlock: React.FC<Props> = ({ block }) => {
|
|||||||
const ThinkingTimeSeconds = memo(
|
const ThinkingTimeSeconds = memo(
|
||||||
({ blockThinkingTime, isThinking }: { blockThinkingTime: number; isThinking: boolean }) => {
|
({ blockThinkingTime, isThinking }: { blockThinkingTime: number; isThinking: boolean }) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
// console.log('blockThinkingTime', blockThinkingTime)
|
|
||||||
// const [thinkingTime, setThinkingTime] = useState(blockThinkingTime || 0)
|
// const [thinkingTime, setThinkingTime] = useState(blockThinkingTime || 0)
|
||||||
|
|
||||||
// FIXME: 这里统计的和请求处统计的有一定误差
|
// FIXME: 这里统计的和请求处统计的有一定误差
|
||||||
|
|||||||
@ -186,7 +186,6 @@ const MessageMenubar: FC<Props> = (props) => {
|
|||||||
try {
|
try {
|
||||||
await translateText(mainTextContent, language, translationUpdater)
|
await translateText(mainTextContent, language, translationUpdater)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// console.error('Translation failed:', error)
|
|
||||||
window.message.error({ content: t('translate.error.failed'), key: 'translate-message' })
|
window.message.error({ content: t('translate.error.failed'), key: 'translate-message' })
|
||||||
// 理应只有一个
|
// 理应只有一个
|
||||||
const translationBlocks = findTranslationBlocksById(message.id)
|
const translationBlocks = findTranslationBlocksById(message.id)
|
||||||
|
|||||||
@ -60,14 +60,29 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
const toolResponse = block.metadata?.rawMcpToolResponse
|
const toolResponse = block.metadata?.rawMcpToolResponse
|
||||||
|
|
||||||
const { id, tool, status, response } = toolResponse!
|
const { id, tool, status, response } = toolResponse!
|
||||||
|
|
||||||
const isPending = status === 'pending'
|
const isPending = status === 'pending'
|
||||||
const isInvoking = status === 'invoking'
|
|
||||||
const isDone = status === 'done'
|
const isDone = status === 'done'
|
||||||
|
const isError = status === 'error'
|
||||||
|
|
||||||
|
const isAutoApproved = useMemo(
|
||||||
|
() =>
|
||||||
|
isToolAutoApproved(
|
||||||
|
tool,
|
||||||
|
mcpServers.find((s) => s.id === tool.serverId)
|
||||||
|
),
|
||||||
|
[tool, mcpServers]
|
||||||
|
)
|
||||||
|
|
||||||
|
// 增加本地状态来跟踪用户确认
|
||||||
|
const [isConfirmed, setIsConfirmed] = useState(isAutoApproved)
|
||||||
|
|
||||||
|
// 判断不同的UI状态
|
||||||
|
const isWaitingConfirmation = isPending && !isAutoApproved && !isConfirmed
|
||||||
|
const isExecuting = isPending && (isAutoApproved || isConfirmed)
|
||||||
|
|
||||||
const timer = useRef<NodeJS.Timeout | null>(null)
|
const timer = useRef<NodeJS.Timeout | null>(null)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!isPending) return
|
if (!isWaitingConfirmation) return
|
||||||
|
|
||||||
if (countdown > 0) {
|
if (countdown > 0) {
|
||||||
timer.current = setTimeout(() => {
|
timer.current = setTimeout(() => {
|
||||||
@ -75,6 +90,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
setCountdown((prev) => prev - 1)
|
setCountdown((prev) => prev - 1)
|
||||||
}, 1000)
|
}, 1000)
|
||||||
} else if (countdown === 0) {
|
} else if (countdown === 0) {
|
||||||
|
setIsConfirmed(true)
|
||||||
confirmToolAction(id)
|
confirmToolAction(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,7 +99,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
clearTimeout(timer.current)
|
clearTimeout(timer.current)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, [countdown, id, isPending])
|
}, [countdown, id, isWaitingConfirmation])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const removeListener = window.electron.ipcRenderer.on(
|
const removeListener = window.electron.ipcRenderer.on(
|
||||||
@ -146,6 +162,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
|
|
||||||
const handleConfirmTool = () => {
|
const handleConfirmTool = () => {
|
||||||
cancelCountdown()
|
cancelCountdown()
|
||||||
|
setIsConfirmed(true)
|
||||||
confirmToolAction(id)
|
confirmToolAction(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,6 +212,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
updateMCPServer(updatedServer)
|
updateMCPServer(updatedServer)
|
||||||
|
|
||||||
// Also confirm the current tool
|
// Also confirm the current tool
|
||||||
|
setIsConfirmed(true)
|
||||||
confirmToolAction(id)
|
confirmToolAction(id)
|
||||||
|
|
||||||
window.message.success({
|
window.message.success({
|
||||||
@ -206,32 +224,31 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
const renderStatusIndicator = (status: string, hasError: boolean) => {
|
const renderStatusIndicator = (status: string, hasError: boolean) => {
|
||||||
let label = ''
|
let label = ''
|
||||||
let icon: React.ReactNode | null = null
|
let icon: React.ReactNode | null = null
|
||||||
switch (status) {
|
|
||||||
case 'pending':
|
if (status === 'pending') {
|
||||||
|
if (isWaitingConfirmation) {
|
||||||
label = t('message.tools.pending', 'Awaiting Approval')
|
label = t('message.tools.pending', 'Awaiting Approval')
|
||||||
icon = <LoadingIcon style={{ marginLeft: 6, color: 'var(--status-color-warning)' }} />
|
icon = <LoadingIcon style={{ marginLeft: 6, color: 'var(--status-color-warning)' }} />
|
||||||
break
|
} else if (isExecuting) {
|
||||||
case 'invoking':
|
|
||||||
label = t('message.tools.invoking')
|
label = t('message.tools.invoking')
|
||||||
icon = <LoadingIcon style={{ marginLeft: 6 }} />
|
icon = <LoadingIcon style={{ marginLeft: 6 }} />
|
||||||
break
|
}
|
||||||
case 'cancelled':
|
} else if (status === 'cancelled') {
|
||||||
label = t('message.tools.cancelled')
|
label = t('message.tools.cancelled')
|
||||||
icon = <X size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
icon = <X size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||||
break
|
} else if (status === 'done') {
|
||||||
case 'done':
|
if (hasError) {
|
||||||
if (hasError) {
|
label = t('message.tools.error')
|
||||||
label = t('message.tools.error')
|
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||||
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
} else {
|
||||||
} else {
|
label = t('message.tools.completed')
|
||||||
label = t('message.tools.completed')
|
icon = <Check size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||||
icon = <Check size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
}
|
||||||
}
|
} else if (status === 'error') {
|
||||||
break
|
label = t('message.tools.error')
|
||||||
default:
|
icon = <TriangleAlert size={13} style={{ marginLeft: 6 }} className="lucide-custom" />
|
||||||
label = ''
|
|
||||||
icon = null
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<StatusIndicator status={status} hasError={hasError}>
|
<StatusIndicator status={status} hasError={hasError}>
|
||||||
{label}
|
{label}
|
||||||
@ -248,7 +265,6 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
params: toolResponse.arguments,
|
params: toolResponse.arguments,
|
||||||
response: toolResponse.response
|
response: toolResponse.response
|
||||||
}
|
}
|
||||||
|
|
||||||
items.push({
|
items.push({
|
||||||
key: id,
|
key: id,
|
||||||
label: (
|
label: (
|
||||||
@ -283,7 +299,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
<Maximize size={14} />
|
<Maximize size={14} />
|
||||||
</ActionButton>
|
</ActionButton>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
{!isPending && !isInvoking && (
|
{!isPending && (
|
||||||
<Tooltip title={t('common.copy')} mouseEnterDelay={0.5}>
|
<Tooltip title={t('common.copy')} mouseEnterDelay={0.5}>
|
||||||
<ActionButton
|
<ActionButton
|
||||||
className="message-action-button"
|
className="message-action-button"
|
||||||
@ -301,7 +317,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
</MessageTitleLabel>
|
</MessageTitleLabel>
|
||||||
),
|
),
|
||||||
children:
|
children:
|
||||||
isDone && result ? (
|
(isDone || isError) && result ? (
|
||||||
<ToolResponseContainer
|
<ToolResponseContainer
|
||||||
style={{
|
style={{
|
||||||
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
||||||
@ -370,7 +386,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
}
|
}
|
||||||
}}>
|
}}>
|
||||||
<ToolContainer>
|
<ToolContainer>
|
||||||
<ToolContentWrapper className={status}>
|
<ToolContentWrapper className={isPending ? 'pending' : status}>
|
||||||
<CollapseContainer
|
<CollapseContainer
|
||||||
ghost
|
ghost
|
||||||
activeKey={activeKeys}
|
activeKey={activeKeys}
|
||||||
@ -383,14 +399,16 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
<ExpandIcon $isActive={isActive} size={18} color="var(--color-text-3)" strokeWidth={1.5} />
|
<ExpandIcon $isActive={isActive} size={18} color="var(--color-text-3)" strokeWidth={1.5} />
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
{(isPending || isInvoking) && (
|
{isPending && (
|
||||||
<ActionsBar>
|
<ActionsBar>
|
||||||
<ActionLabel>
|
<ActionLabel>
|
||||||
{isPending ? t('settings.mcp.tools.autoApprove.tooltip.confirm') : t('message.tools.invoking')}
|
{isWaitingConfirmation
|
||||||
|
? t('settings.mcp.tools.autoApprove.tooltip.confirm')
|
||||||
|
: t('message.tools.invoking')}
|
||||||
</ActionLabel>
|
</ActionLabel>
|
||||||
|
|
||||||
<ActionButtonsGroup>
|
<ActionButtonsGroup>
|
||||||
{isPending && (
|
{isWaitingConfirmation && (
|
||||||
<Button
|
<Button
|
||||||
color="danger"
|
color="danger"
|
||||||
variant="filled"
|
variant="filled"
|
||||||
@ -402,7 +420,7 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
{t('common.cancel')}
|
{t('common.cancel')}
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
{isInvoking && toolResponse?.id ? (
|
{isExecuting && toolResponse?.id ? (
|
||||||
<Button
|
<Button
|
||||||
size="small"
|
size="small"
|
||||||
color="danger"
|
color="danger"
|
||||||
@ -416,29 +434,31 @@ const MessageMcpTool: FC<Props> = ({ block }) => {
|
|||||||
{t('chat.input.pause')}
|
{t('chat.input.pause')}
|
||||||
</Button>
|
</Button>
|
||||||
) : (
|
) : (
|
||||||
<StyledDropdownButton
|
isWaitingConfirmation && (
|
||||||
size="small"
|
<StyledDropdownButton
|
||||||
type="primary"
|
size="small"
|
||||||
icon={<ChevronDown size={14} />}
|
type="primary"
|
||||||
onClick={() => {
|
icon={<ChevronDown size={14} />}
|
||||||
handleConfirmTool()
|
onClick={() => {
|
||||||
}}
|
handleConfirmTool()
|
||||||
menu={{
|
}}
|
||||||
items: [
|
menu={{
|
||||||
{
|
items: [
|
||||||
key: 'autoApprove',
|
{
|
||||||
label: t('settings.mcp.tools.autoApprove.label'),
|
key: 'autoApprove',
|
||||||
onClick: () => {
|
label: t('settings.mcp.tools.autoApprove.label'),
|
||||||
handleAutoApprove()
|
onClick: () => {
|
||||||
|
handleAutoApprove()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
]
|
||||||
]
|
}}>
|
||||||
}}>
|
<CirclePlay size={15} className="lucide-custom" />
|
||||||
<CirclePlay size={15} className="lucide-custom" />
|
<CountdownText>
|
||||||
<CountdownText>
|
{t('settings.mcp.tools.run', 'Run')} ({countdown}s)
|
||||||
{t('settings.mcp.tools.run', 'Run')} ({countdown}s)
|
</CountdownText>
|
||||||
</CountdownText>
|
</StyledDropdownButton>
|
||||||
</StyledDropdownButton>
|
)
|
||||||
)}
|
)}
|
||||||
</ActionButtonsGroup>
|
</ActionButtonsGroup>
|
||||||
</ActionsBar>
|
</ActionsBar>
|
||||||
@ -542,8 +562,7 @@ const ToolContentWrapper = styled.div`
|
|||||||
border: 1px solid var(--color-border);
|
border: 1px solid var(--color-border);
|
||||||
}
|
}
|
||||||
|
|
||||||
&.pending,
|
&.pending {
|
||||||
&.invoking {
|
|
||||||
background-color: var(--color-background-soft);
|
background-color: var(--color-background-soft);
|
||||||
.ant-collapse {
|
.ant-collapse {
|
||||||
border: none;
|
border: none;
|
||||||
@ -663,6 +682,8 @@ const StatusIndicator = styled.span<{ status: string; hasError?: boolean }>`
|
|||||||
return 'var(--status-color-error)'
|
return 'var(--status-color-error)'
|
||||||
case 'done':
|
case 'done':
|
||||||
return props.hasError ? 'var(--status-color-error)' : 'var(--status-color-success)'
|
return props.hasError ? 'var(--status-color-error)' : 'var(--status-color-success)'
|
||||||
|
case 'error':
|
||||||
|
return 'var(--status-color-error)'
|
||||||
default:
|
default:
|
||||||
return 'var(--color-text)'
|
return 'var(--color-text)'
|
||||||
}
|
}
|
||||||
|
|||||||
@ -99,6 +99,8 @@ export async function fetchChatCompletion({
|
|||||||
const provider = AI.getActualProvider()
|
const provider = AI.getActualProvider()
|
||||||
|
|
||||||
const mcpTools: MCPTool[] = []
|
const mcpTools: MCPTool[] = []
|
||||||
|
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||||
|
|
||||||
if (isPromptToolUse(assistant) || isSupportedToolUse(assistant)) {
|
if (isPromptToolUse(assistant) || isSupportedToolUse(assistant)) {
|
||||||
mcpTools.push(...(await fetchMcpTools(assistant)))
|
mcpTools.push(...(await fetchMcpTools(assistant)))
|
||||||
}
|
}
|
||||||
@ -137,7 +139,6 @@ export async function fetchChatCompletion({
|
|||||||
}
|
}
|
||||||
|
|
||||||
// --- Call AI Completions ---
|
// --- Call AI Completions ---
|
||||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
|
||||||
await AI.completions(modelId, aiSdkParams, {
|
await AI.completions(modelId, aiSdkParams, {
|
||||||
...middlewareConfig,
|
...middlewareConfig,
|
||||||
assistant,
|
assistant,
|
||||||
|
|||||||
@ -178,12 +178,7 @@ class WebSearchService {
|
|||||||
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
|
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
|
||||||
}
|
}
|
||||||
|
|
||||||
// try {
|
|
||||||
return await webSearchEngine.search(formattedQuery, websearch, httpOptions)
|
return await webSearchEngine.search(formattedQuery, websearch, httpOptions)
|
||||||
// } catch (error) {
|
|
||||||
// console.error('Search failed:', error)
|
|
||||||
// throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -69,16 +69,6 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
|||||||
})
|
})
|
||||||
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
||||||
},
|
},
|
||||||
// onBlockCreated: async () => {
|
|
||||||
// if (blockManager.hasInitialPlaceholder) {
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
// console.log('onBlockCreated')
|
|
||||||
// const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
|
|
||||||
// status: MessageBlockStatus.PROCESSING
|
|
||||||
// })
|
|
||||||
// await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
|
||||||
// },
|
|
||||||
|
|
||||||
onError: async (error: AISDKError) => {
|
onError: async (error: AISDKError) => {
|
||||||
logger.debug('onError', error)
|
logger.debug('onError', error)
|
||||||
|
|||||||
@ -51,28 +51,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
|
|
||||||
// 根据 toolResponse.id 查找对应的块ID
|
|
||||||
const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
|
||||||
|
|
||||||
if (targetBlockId && toolResponse.status === 'invoking') {
|
|
||||||
const changes = {
|
|
||||||
status: MessageBlockStatus.PROCESSING,
|
|
||||||
metadata: { rawMcpToolResponse: toolResponse }
|
|
||||||
}
|
|
||||||
blockManager.smartBlockUpdate(targetBlockId, changes, MessageBlockType.TOOL)
|
|
||||||
} else if (!targetBlockId) {
|
|
||||||
logger.warn(
|
|
||||||
`[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`,
|
|
||||||
Array.from(toolCallIdToBlockIdMap.entries())
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
logger.warn(
|
|
||||||
`[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
|
||||||
)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
onToolCallComplete: (toolResponse: MCPToolResponse) => {
|
onToolCallComplete: (toolResponse: MCPToolResponse) => {
|
||||||
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||||
toolCallIdToBlockIdMap.delete(toolResponse.id)
|
toolCallIdToBlockIdMap.delete(toolResponse.id)
|
||||||
@ -104,7 +82,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
|||||||
stack: null
|
stack: null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
|
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
|
||||||
|
|
||||||
// Handle citation block creation for web search results
|
// Handle citation block creation for web search results
|
||||||
@ -132,7 +109,6 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
|||||||
citationBlockId = citationBlock.id
|
citationBlockId = citationBlock.id
|
||||||
blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||||
}
|
}
|
||||||
// TODO: 处理 memory 引用
|
|
||||||
} else {
|
} else {
|
||||||
logger.warn(
|
logger.warn(
|
||||||
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
`[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}`
|
||||||
|
|||||||
@ -458,18 +458,18 @@ describe('streamCallback Integration Tests', () => {
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
// {
|
||||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
// type: ChunkType.MCP_TOOL_PENDING,
|
||||||
responses: [
|
// responses: [
|
||||||
{
|
// {
|
||||||
id: 'tool-call-1',
|
// id: 'tool-call-1',
|
||||||
tool: mockTool,
|
// tool: mockTool,
|
||||||
arguments: { testArg: 'value' },
|
// arguments: { testArg: 'value' },
|
||||||
status: 'invoking' as const,
|
// status: 'invoking' as const,
|
||||||
response: ''
|
// response: ''
|
||||||
}
|
// }
|
||||||
]
|
// ]
|
||||||
},
|
// },
|
||||||
{
|
{
|
||||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||||
responses: [
|
responses: [
|
||||||
@ -611,18 +611,18 @@ describe('streamCallback Integration Tests', () => {
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
// {
|
||||||
type: ChunkType.MCP_TOOL_IN_PROGRESS,
|
// type: ChunkType.MCP_TOOL_PENDING,
|
||||||
responses: [
|
// responses: [
|
||||||
{
|
// {
|
||||||
id: 'tool-call-1',
|
// id: 'tool-call-1',
|
||||||
tool: mockCalculatorTool,
|
// tool: mockCalculatorTool,
|
||||||
arguments: { operation: 'add', a: 1, b: 2 },
|
// arguments: { operation: 'add', a: 1, b: 2 },
|
||||||
status: 'invoking' as const,
|
// status: 'invoking' as const,
|
||||||
response: ''
|
// response: ''
|
||||||
}
|
// }
|
||||||
]
|
// ]
|
||||||
},
|
// },
|
||||||
{
|
{
|
||||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||||
responses: [
|
responses: [
|
||||||
|
|||||||
@ -826,9 +826,7 @@ export function mcpToolCallResponseToAwsBedrockMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 是否启用工具使用
|
* 是否启用工具使用(function call)
|
||||||
* 1. 如果模型支持函数调用,则启用工具使用
|
|
||||||
* 2. 如果工具使用模式为 prompt,则启用工具使用
|
|
||||||
* @param assistant
|
* @param assistant
|
||||||
* @returns 是否启用工具使用
|
* @returns 是否启用工具使用
|
||||||
*/
|
*/
|
||||||
|
|||||||
51
yarn.lock
51
yarn.lock
@ -179,15 +179,15 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
"@ai-sdk/google@npm:^2.0.7":
|
"@ai-sdk/google@npm:^2.0.13":
|
||||||
version: 2.0.7
|
version: 2.0.13
|
||||||
resolution: "@ai-sdk/google@npm:2.0.7"
|
resolution: "@ai-sdk/google@npm:2.0.13"
|
||||||
dependencies:
|
dependencies:
|
||||||
"@ai-sdk/provider": "npm:2.0.0"
|
"@ai-sdk/provider": "npm:2.0.0"
|
||||||
"@ai-sdk/provider-utils": "npm:3.0.4"
|
"@ai-sdk/provider-utils": "npm:3.0.8"
|
||||||
peerDependencies:
|
peerDependencies:
|
||||||
zod: ^3.25.76 || ^4
|
zod: ^3.25.76 || ^4
|
||||||
checksum: 10c0/bde4c95a2a167355cda18de9d5b273d562d2a724f650ca69016daa8df2766280487e143cf0cdd96f6654c255d587a680c6a937b280eb734ca2c35d6f9b9e943c
|
checksum: 10c0/a05210de11d7ab41d49bcd0330c37f4116441b149d8ccc9b6bc5eaa12ea42bae82364dc2cd09502734b15115071f07395525806ea4998930b285b1ce74102186
|
||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
@ -227,15 +227,15 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
"@ai-sdk/openai@npm:^2.0.19":
|
"@ai-sdk/openai@npm:^2.0.26":
|
||||||
version: 2.0.19
|
version: 2.0.26
|
||||||
resolution: "@ai-sdk/openai@npm:2.0.19"
|
resolution: "@ai-sdk/openai@npm:2.0.26"
|
||||||
dependencies:
|
dependencies:
|
||||||
"@ai-sdk/provider": "npm:2.0.0"
|
"@ai-sdk/provider": "npm:2.0.0"
|
||||||
"@ai-sdk/provider-utils": "npm:3.0.5"
|
"@ai-sdk/provider-utils": "npm:3.0.8"
|
||||||
peerDependencies:
|
peerDependencies:
|
||||||
zod: ^3.25.76 || ^4
|
zod: ^3.25.76 || ^4
|
||||||
checksum: 10c0/04db695669d783a810b80283e0cd48f6e7654667fd76ca2d35c7cffae6fdd68fb0473118e4e097ef1352f4432dd7c15c07f873d712b940c72495e5839b0ede98
|
checksum: 10c0/b8cb01c0c38525c38901f41f1693cd15589932a2aceddea14bed30f44719532a5e74615fb0e974eff1a0513048ac204c27456ff8829a9c811d1461cc635c9cc5
|
||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
@ -267,20 +267,6 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
"@ai-sdk/provider-utils@npm:3.0.5":
|
|
||||||
version: 3.0.5
|
|
||||||
resolution: "@ai-sdk/provider-utils@npm:3.0.5"
|
|
||||||
dependencies:
|
|
||||||
"@ai-sdk/provider": "npm:2.0.0"
|
|
||||||
"@standard-schema/spec": "npm:^1.0.0"
|
|
||||||
eventsource-parser: "npm:^3.0.3"
|
|
||||||
zod-to-json-schema: "npm:^3.24.1"
|
|
||||||
peerDependencies:
|
|
||||||
zod: ^3.25.76 || ^4
|
|
||||||
checksum: 10c0/4057810b320bda149a178dc1bfc9cdd592ca88b736c3c22bd0c1f8111c75ef69beec4a523f363e5d0d120348b876942fd66c0bb4965864da4c12c5cfddee15a3
|
|
||||||
languageName: node
|
|
||||||
linkType: hard
|
|
||||||
|
|
||||||
"@ai-sdk/provider-utils@npm:3.0.7":
|
"@ai-sdk/provider-utils@npm:3.0.7":
|
||||||
version: 3.0.7
|
version: 3.0.7
|
||||||
resolution: "@ai-sdk/provider-utils@npm:3.0.7"
|
resolution: "@ai-sdk/provider-utils@npm:3.0.7"
|
||||||
@ -294,6 +280,19 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
|
"@ai-sdk/provider-utils@npm:3.0.8":
|
||||||
|
version: 3.0.8
|
||||||
|
resolution: "@ai-sdk/provider-utils@npm:3.0.8"
|
||||||
|
dependencies:
|
||||||
|
"@ai-sdk/provider": "npm:2.0.0"
|
||||||
|
"@standard-schema/spec": "npm:^1.0.0"
|
||||||
|
eventsource-parser: "npm:^3.0.5"
|
||||||
|
peerDependencies:
|
||||||
|
zod: ^3.25.76 || ^4
|
||||||
|
checksum: 10c0/f466657c886cbb9f7ecbcd2dd1abc51a88af9d3f1cff030f7e97e70a4790a99f3338ad886e9c0dccf04dacdcc84522c7d57119b9a4e8e1d84f2dae9c893c397e
|
||||||
|
languageName: node
|
||||||
|
linkType: hard
|
||||||
|
|
||||||
"@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0":
|
"@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0":
|
||||||
version: 2.0.0
|
version: 2.0.0
|
||||||
resolution: "@ai-sdk/provider@npm:2.0.0"
|
resolution: "@ai-sdk/provider@npm:2.0.0"
|
||||||
@ -2249,8 +2248,8 @@ __metadata:
|
|||||||
"@ai-sdk/anthropic": "npm:^2.0.5"
|
"@ai-sdk/anthropic": "npm:^2.0.5"
|
||||||
"@ai-sdk/azure": "npm:^2.0.16"
|
"@ai-sdk/azure": "npm:^2.0.16"
|
||||||
"@ai-sdk/deepseek": "npm:^1.0.9"
|
"@ai-sdk/deepseek": "npm:^1.0.9"
|
||||||
"@ai-sdk/google": "npm:^2.0.7"
|
"@ai-sdk/google": "npm:^2.0.13"
|
||||||
"@ai-sdk/openai": "npm:^2.0.19"
|
"@ai-sdk/openai": "npm:^2.0.26"
|
||||||
"@ai-sdk/openai-compatible": "npm:^1.0.9"
|
"@ai-sdk/openai-compatible": "npm:^1.0.9"
|
||||||
"@ai-sdk/provider": "npm:^2.0.0"
|
"@ai-sdk/provider": "npm:^2.0.0"
|
||||||
"@ai-sdk/provider-utils": "npm:^3.0.4"
|
"@ai-sdk/provider-utils": "npm:^3.0.4"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user