From a05d7cbe2dbbfa9bb215a1dea963ec58f4a4d3e4 Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Tue, 29 Jul 2025 12:16:06 +0800 Subject: [PATCH] refactor: enhance search orchestration and web search tool integration - Updated `searchOrchestrationPlugin` to improve handling of assistant configurations and prevent concurrent analysis. - Refactored `webSearchTool` to utilize pre-extracted keywords for more efficient web searches. - Introduced a new `MessageKnowledgeSearch` component for displaying knowledge search results. - Cleaned up commented-out code and improved type safety across various components. - Enhanced the integration of web search results in the UI for better user experience. --- .../src/aiCore/chunk/AiSdkToChunkAdapter.ts | 1 - src/renderer/src/aiCore/index_new.ts | 4 +- .../aisdk/AiSdkMiddlewareBuilder.ts | 2 +- .../plugins/searchOrchestrationPlugin.ts | 114 +++--- .../src/aiCore/tools/WebSearchTool.ts | 335 ++++++++++++------ .../Messages/Tools/MessageKnowledgeSearch.tsx | 72 ++++ .../pages/home/Messages/Tools/MessageTool.tsx | 53 ++- .../Messages/Tools/MessageWebSearchTool.tsx | 26 +- src/renderer/src/store/thunk/messageThunk.ts | 1 - 9 files changed, 433 insertions(+), 175 deletions(-) create mode 100644 src/renderer/src/pages/home/Messages/Tools/MessageKnowledgeSearch.tsx diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts index e7a76757da..e9b5430863 100644 --- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -97,7 +97,6 @@ export class AiSdkToChunkAdapter { type: ChunkType.TEXT_DELTA, text: final.text || '' }) - console.log('final.text', final.text) break case 'text-end': this.onChunk({ diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index cfba0db37a..b5cffe8b76 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -21,7 +21,7 @@ import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-cor import { isDedicatedImageGenerationModel } from '@renderer/config/models' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { getProviderByModel } from '@renderer/services/AssistantService' -import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types' +import type { GenerateImageParams, Model, Provider } from '@renderer/types' import { formatApiHost } from '@renderer/utils/api' import { cloneDeep } from 'lodash' @@ -163,7 +163,7 @@ export default class ModernAiProvider { // 内置了默认搜索参数,如果改的话可以传config进去 plugins.push(webSearchPlugin()) } - plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant as Assistant)) + plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant)) // 2. 推理模型时添加推理插件 if (middlewareConfig.enableReasoning) { diff --git a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts index b0e8e70554..e63af4bcad 100644 --- a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts @@ -20,7 +20,7 @@ export interface AiSdkMiddlewareConfig { enableWebSearch?: boolean mcpTools?: BaseTool[] // TODO assistant - assistant?: Assistant + assistant: Assistant } /** diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts index 801f2ccd9a..916ee074ec 100644 --- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -8,7 +8,6 @@ */ import type { AiRequestContext, ModelMessage } from '@cherrystudio/ai-core' import { definePlugin } from '@cherrystudio/ai-core' -import { RuntimeExecutor } from '@cherrystudio/ai-core/core/runtime/executor' // import { generateObject } from '@cherrystudio/ai-core' import { SEARCH_SUMMARY_PROMPT, @@ -19,13 +18,13 @@ import { getDefaultModel, getProviderByModel } from '@renderer/services/Assistan import store from '@renderer/store' import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory' import type { Assistant } from '@renderer/types' +import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract' import { isEmpty } from 'lodash' -import { z } from 'zod' import { MemoryProcessor } from '../../services/MemoryProcessor' import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool' import { memorySearchTool } from '../tools/MemorySearchTool' -import { webSearchTool } from '../tools/WebSearchTool' +import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool' const getMessageContent = (message: ModelMessage) => { if (typeof message.content === 'string') return message.content @@ -39,32 +38,32 @@ const getMessageContent = (message: ModelMessage) => { // === Schema Definitions === -const WebSearchSchema = z.object({ - question: z - .array(z.string()) - .describe('Search queries for web search. Use "not_needed" if no web search is required.'), - links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.') -}) +// const WebSearchSchema = z.object({ +// question: z +// .array(z.string()) +// .describe('Search queries for web search. Use "not_needed" if no web search is required.'), +// links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.') +// }) -const KnowledgeSearchSchema = z.object({ - question: z - .array(z.string()) - .describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'), - rewrite: z - .string() - .describe('Rewritten query with alternative phrasing while preserving original intent and meaning.') -}) +// const KnowledgeSearchSchema = z.object({ +// question: z +// .array(z.string()) +// .describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'), +// rewrite: z +// .string() +// .describe('Rewritten query with alternative phrasing while preserving original intent and meaning.') +// }) -const SearchIntentAnalysisSchema = z.object({ - websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'), - knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.') -}) +// const SearchIntentAnalysisSchema = z.object({ +// websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'), +// knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.') +// }) -type SearchIntentResult = z.infer +// type SearchIntentResult = z.infer -let isAnalyzing = false +// let isAnalyzing = false /** - * 🧠 意图分析函数 - 使用结构化输出重构 + * 🧠 意图分析函数 - 使用 XML 解析 */ async function analyzeSearchIntent( lastUserMessage: ModelMessage, @@ -74,13 +73,11 @@ async function analyzeSearchIntent( shouldKnowledgeSearch?: boolean shouldMemorySearch?: boolean lastAnswer?: ModelMessage - context?: - | AiRequestContext - | { - executor: RuntimeExecutor - } - } = {} -): Promise { + context: AiRequestContext & { + isAnalyzing?: boolean + } + } +): Promise { const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer, context } = options if (!lastUserMessage) return undefined @@ -91,19 +88,19 @@ async function analyzeSearchIntent( if (!needWebExtract && !needKnowledgeExtract) return undefined - // 选择合适的提示词和schema + // 选择合适的提示词 let prompt: string - let schema: z.Schema + // let schema: z.Schema if (needWebExtract && !needKnowledgeExtract) { prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY - schema = z.object({ websearch: WebSearchSchema }) + // schema = z.object({ websearch: WebSearchSchema }) } else if (!needWebExtract && needKnowledgeExtract) { prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY - schema = z.object({ knowledge: KnowledgeSearchSchema }) + // schema = z.object({ knowledge: KnowledgeSearchSchema }) } else { prompt = SEARCH_SUMMARY_PROMPT - schema = SearchIntentAnalysisSchema + // schema = SearchIntentAnalysisSchema } // 构建消息上下文 - 简化逻辑 @@ -121,16 +118,15 @@ async function analyzeSearchIntent( console.error('Provider not found or missing API key') return getFallbackResult() } - + // console.log('formattedPrompt', schema) try { - isAnalyzing = true - const result = await context?.executor?.generateObject(model.id, { - schema, + context.isAnalyzing = true + const { text: result } = await context.executor.generateText(model.id, { prompt: formattedPrompt }) - isAnalyzing = false - console.log('result', context) - const parsedResult = result?.object as SearchIntentResult + context.isAnalyzing = false + const parsedResult = extractInfoFromXML(result) + console.log('parsedResult', parsedResult) // 根据需求过滤结果 return { @@ -142,7 +138,7 @@ async function analyzeSearchIntent( return getFallbackResult() } - function getFallbackResult(): SearchIntentResult { + function getFallbackResult(): ExtractResults { const fallbackContent = getMessageContent(lastUserMessage) return { websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined, @@ -159,7 +155,11 @@ async function analyzeSearchIntent( /** * 🧠 记忆存储函数 - 基于注释代码中的 processConversationMemory */ -async function storeConversationMemory(messages: ModelMessage[], assistant: Assistant): Promise { +async function storeConversationMemory( + messages: ModelMessage[], + assistant: Assistant, + context: AiRequestContext +): Promise { const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState()) if (!globalMemoryEnabled || !assistant.enableMemory) { @@ -185,14 +185,13 @@ async function storeConversationMemory(messages: ModelMessage[], assistant: Assi } const currentUserId = selectCurrentUserId(store.getState()) - const lastUserMessage = messages.findLast((m) => m.role === 'user') + // const lastUserMessage = messages.findLast((m) => m.role === 'user') const processorConfig = MemoryProcessor.getProcessorConfig( memoryConfig, assistant.id, currentUserId, - // TODO - lastUserMessage?.id + context.requestId ) console.log('Processing conversation memory...', { messageCount: conversationMessages.length }) @@ -224,9 +223,10 @@ async function storeConversationMemory(messages: ModelMessage[], assistant: Assi */ export const searchOrchestrationPlugin = (assistant: Assistant) => { // 存储意图分析结果 - const intentAnalysisResults: { [requestId: string]: SearchIntentResult } = {} + const intentAnalysisResults: { [requestId: string]: ExtractResults } = {} const userMessages: { [requestId: string]: ModelMessage } = {} console.log('searchOrchestrationPlugin', assistant) + return definePlugin({ name: 'search-orchestration', enforce: 'pre', // 确保在其他插件之前执行 @@ -235,7 +235,8 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => { * 🔍 Step 1: 意图识别阶段 */ onRequestStart: async (context: AiRequestContext) => { - if (isAnalyzing) return + console.log('onRequestStart', context.isAnalyzing) + if (context.isAnalyzing) return console.log('🧠 [SearchOrchestration] Starting intent analysis...', context.requestId) try { @@ -294,7 +295,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => { * 🔧 Step 2: 工具配置阶段 */ transformParams: async (params: any, context: AiRequestContext) => { - if (isAnalyzing) return + if (context.isAnalyzing) return params console.log('🔧 [SearchOrchestration] Configuring tools based on intent...', context.requestId) try { @@ -314,8 +315,13 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => { const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed' if (needsSearch) { - console.log('🌐 [SearchOrchestration] Adding web search tool') - params.tools['builtin_web_search'] = webSearchTool(assistant.webSearchProviderId) + // onChunk({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS }) + console.log('🌐 [SearchOrchestration] Adding web search tool with pre-extracted keywords') + params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords( + assistant.webSearchProviderId, + analysisResult.websearch, + context.requestId + ) } } @@ -370,7 +376,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => { const messages = context.originalParams.messages if (messages && assistant) { - await storeConversationMemory(messages, assistant) + await storeConversationMemory(messages, assistant, context) } // 清理缓存 diff --git a/src/renderer/src/aiCore/tools/WebSearchTool.ts b/src/renderer/src/aiCore/tools/WebSearchTool.ts index cb1b118b11..8334089866 100644 --- a/src/renderer/src/aiCore/tools/WebSearchTool.ts +++ b/src/renderer/src/aiCore/tools/WebSearchTool.ts @@ -1,131 +1,262 @@ -import { extractSearchKeywords } from '@renderer/aiCore/transformParameters' +import { REFERENCE_PROMPT } from '@renderer/config/prompts' import WebSearchService from '@renderer/services/WebSearchService' -import { Assistant, Message, WebSearchProvider } from '@renderer/types' -import { UserMessageStatus } from '@renderer/types/newMessage' +import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' import { ExtractResults } from '@renderer/utils/extract' -import { type InferToolInput, type InferToolOutput, tool } from 'ai' +import { InferToolInput, InferToolOutput, tool } from 'ai' import { z } from 'zod' // import { AiSdkTool, ToolCallResult } from './types' -const WebSearchProviderResult = z.object({ - query: z.string().optional(), - results: z.array( - z.object({ - title: z.string(), - content: z.string(), - url: z.string() - }) - ) -}) -const webSearchToolInputSchema = z.object({ - query: z.string().describe('The query to search for') -}) +// const WebSearchResult = z.array( +// z.object({ +// query: z.string().optional(), +// results: z.array( +// z.object({ +// title: z.string(), +// content: z.string(), +// url: z.string() +// }) +// ) +// }) +// ) +// const webSearchToolInputSchema = z.object({ +// query: z.string().describe('The query to search for') +// }) -export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => { - const webSearchService = WebSearchService.getInstance(webSearchProviderId) - return tool({ - name: 'builtin_web_search', - description: 'Search the web for information', - inputSchema: webSearchToolInputSchema, - outputSchema: WebSearchProviderResult, - execute: async ({ query }) => { - console.log('webSearchTool', query) - const response = await webSearchService.search(query) - console.log('webSearchTool response', response) - return response - } - }) -} -export type WebSearchToolInput = InferToolInput> -export type WebSearchToolOutput = InferToolOutput> +// export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => { +// const webSearchService = WebSearchService.getInstance(webSearchProviderId) +// return tool({ +// name: 'builtin_web_search', +// description: 'Search the web for information', +// inputSchema: webSearchToolInputSchema, +// outputSchema: WebSearchProviderResult, +// execute: async ({ query }) => { +// console.log('webSearchTool', query) +// const response = await webSearchService.search(query) +// console.log('webSearchTool response', response) +// return response +// } +// }) +// } +// export type WebSearchToolInput = InferToolInput> +// export type WebSearchToolOutput = InferToolOutput> -export const webSearchToolWithExtraction = ( +/** + * 使用预提取关键词的网络搜索工具 + * 这个工具直接使用插件阶段分析的搜索意图,避免重复分析 + */ +export const webSearchToolWithPreExtractedKeywords = ( webSearchProviderId: WebSearchProvider['id'], - requestId: string, - assistant: Assistant + extractedKeywords: { + question: string[] + links?: string[] + }, + requestId: string ) => { const webSearchService = WebSearchService.getInstance(webSearchProviderId) return tool({ - name: 'web_search_with_extraction', - description: 'Search the web for information with automatic keyword extraction from user messages', + name: 'builtin_web_search', + description: `Search the web and return citable sources using pre-analyzed search intent. + +Pre-extracted search keywords: "${extractedKeywords.question.join(', ')}"${ + extractedKeywords.links + ? ` +Relevant links: ${extractedKeywords.links.join(', ')}` + : '' + } + +This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response. + +Call this tool to execute the search. You can optionally provide additional context to refine the search.`, + inputSchema: z.object({ - userMessage: z.object({ - content: z.string().describe('The main content of the message'), - role: z.enum(['user', 'assistant', 'system']).describe('Message role') - }), - lastAnswer: z.object({ - content: z.string().describe('The main content of the message'), - role: z.enum(['user', 'assistant', 'system']).describe('Message role') - }) + additionalContext: z + .string() + .optional() + .describe('Optional additional context, keywords, or specific focus to enhance the search') }), - outputSchema: z.object({ - extractedKeywords: z.object({ - question: z.array(z.string()), - links: z.array(z.string()).optional() - }), - searchResults: z.array( - z.object({ - query: z.string(), - results: WebSearchProviderResult - }) - ) - }), - execute: async ({ userMessage, lastAnswer }) => { - const lastUserMessage: Message = { - id: requestId, - role: userMessage.role, - assistantId: assistant.id, - topicId: 'temp', - createdAt: new Date().toISOString(), - status: UserMessageStatus.SUCCESS, - blocks: [] + + execute: async ({ additionalContext }) => { + let finalQueries = [...extractedKeywords.question] + + if (additionalContext?.trim()) { + // 如果大模型提供了额外上下文,使用更具体的描述 + console.log(`🔍 AI enhanced search with: ${additionalContext}`) + const cleanContext = additionalContext.trim() + if (cleanContext) { + finalQueries = [cleanContext] + console.log(`➕ Added additional context: ${cleanContext}`) + } } - const lastAnswerMessage: Message | undefined = lastAnswer - ? { - id: requestId + '_answer', - role: lastAnswer.role, - assistantId: assistant.id, - topicId: 'temp', - createdAt: new Date().toISOString(), - status: UserMessageStatus.SUCCESS, - blocks: [] - } - : undefined + const searchResults: WebSearchProviderResponse[] = [] - const extractResults = await extractSearchKeywords(lastUserMessage, assistant, { - shouldWebSearch: true, - shouldKnowledgeSearch: false, - lastAnswer: lastAnswerMessage - }) - - if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') { - return 'No search needed or extraction failed' + // 检查是否需要搜索 + if (finalQueries[0] === 'not_needed') { + return { + summary: 'No search needed based on the query analysis.', + searchResults: [], + sources: '', + instructions: '', + rawResults: [] + } } - const searchQueries = extractResults.websearch.question - const searchResults: Array<{ query: string; results: any }> = [] - - for (const query of searchQueries) { - // 构建单个查询的ExtractResults结构 - const queryExtractResults: ExtractResults = { + try { + // 构建 ExtractResults 结构用于 processWebsearch + const extractResults: ExtractResults = { websearch: { - question: [query], - links: extractResults.websearch.links + question: finalQueries, + links: extractedKeywords.links } } - const response = await webSearchService.processWebsearch(queryExtractResults, requestId) - searchResults.push({ - query, - results: response - }) + console.log('extractResults', extractResults) + const response = await webSearchService.processWebsearch(extractResults, requestId) + searchResults.push(response) + } catch (error) { + console.error(`Web search failed for query "${finalQueries}":`, error) + return { + summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`, + searchResults: [], + sources: '', + instructions: '', + rawResults: [] + } } - return { extractedKeywords: extractResults.websearch, searchResults } + if (searchResults.length === 0 || !searchResults[0].results) { + return { + summary: 'No search results found for the given query.', + searchResults: [], + sources: '', + instructions: '', + rawResults: [] + } + } + + const results = searchResults[0].results + const citationData = results.map((result, index) => ({ + number: index + 1, + title: result.title, + content: result.content, + url: result.url + })) + + // 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑 + const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\`` + + // 构建完整的引用指导文本 + const fullInstructions = REFERENCE_PROMPT.replace( + '{question}', + "Based on the search results, please answer the user's question with proper citations." + ).replace('{references}', referenceContent) + + return { + summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`, + searchResults, + sources: citationData + .map((source) => `[${source.number}] ${source.title}\n${source.content}\nURL: ${source.url}`) + .join('\n\n'), + + instructions: fullInstructions, + + // 原始数据,便于后续处理 + rawResults: citationData + } } }) } -export type WebSearchToolWithExtractionOutput = InferToolOutput> +// export const webSearchToolWithExtraction = ( +// webSearchProviderId: WebSearchProvider['id'], +// requestId: string, +// assistant: Assistant +// ) => { +// const webSearchService = WebSearchService.getInstance(webSearchProviderId) + +// return tool({ +// name: 'web_search_with_extraction', +// description: 'Search the web for information with automatic keyword extraction from user messages', +// inputSchema: z.object({ +// userMessage: z.object({ +// content: z.string().describe('The main content of the message'), +// role: z.enum(['user', 'assistant', 'system']).describe('Message role') +// }), +// lastAnswer: z.object({ +// content: z.string().describe('The main content of the message'), +// role: z.enum(['user', 'assistant', 'system']).describe('Message role') +// }) +// }), +// outputSchema: z.object({ +// extractedKeywords: z.object({ +// question: z.array(z.string()), +// links: z.array(z.string()).optional() +// }), +// searchResults: z.array( +// z.object({ +// query: z.string(), +// results: WebSearchProviderResult +// }) +// ) +// }), +// execute: async ({ userMessage, lastAnswer }) => { +// const lastUserMessage: Message = { +// id: requestId, +// role: userMessage.role, +// assistantId: assistant.id, +// topicId: 'temp', +// createdAt: new Date().toISOString(), +// status: UserMessageStatus.SUCCESS, +// blocks: [] +// } + +// const lastAnswerMessage: Message | undefined = lastAnswer +// ? { +// id: requestId + '_answer', +// role: lastAnswer.role, +// assistantId: assistant.id, +// topicId: 'temp', +// createdAt: new Date().toISOString(), +// status: UserMessageStatus.SUCCESS, +// blocks: [] +// } +// : undefined + +// const extractResults = await extractSearchKeywords(lastUserMessage, assistant, { +// shouldWebSearch: true, +// shouldKnowledgeSearch: false, +// lastAnswer: lastAnswerMessage +// }) + +// if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') { +// return 'No search needed or extraction failed' +// } + +// const searchQueries = extractResults.websearch.question +// const searchResults: Array<{ query: string; results: any }> = [] + +// for (const query of searchQueries) { +// // 构建单个查询的ExtractResults结构 +// const queryExtractResults: ExtractResults = { +// websearch: { +// question: [query], +// links: extractResults.websearch.links +// } +// } +// const response = await webSearchService.processWebsearch(queryExtractResults, requestId) +// searchResults.push({ +// query, +// results: response +// }) +// } + +// return { extractedKeywords: extractResults.websearch, searchResults } +// } +// }) +// } + +// export type WebSearchToolWithExtractionOutput = InferToolOutput> + +export type WebSearchToolOutput = InferToolOutput> +export type WebSearchToolInput = InferToolInput> diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageKnowledgeSearch.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageKnowledgeSearch.tsx new file mode 100644 index 0000000000..d44c59e895 --- /dev/null +++ b/src/renderer/src/pages/home/Messages/Tools/MessageKnowledgeSearch.tsx @@ -0,0 +1,72 @@ +import { KnowledgeSearchToolInput, KnowledgeSearchToolOutput } from '@renderer/aiCore/tools/KnowledgeSearchTool' +import Spinner from '@renderer/components/Spinner' +import i18n from '@renderer/i18n' +import { MCPToolResponse } from '@renderer/types' +import { Typography } from 'antd' +import { FileSearch } from 'lucide-react' +import styled from 'styled-components' + +const { Text } = Typography +export function MessageKnowledgeSearchToolTitle({ toolResponse }: { toolResponse: MCPToolResponse }) { + const toolInput = toolResponse.arguments as KnowledgeSearchToolInput + const toolOutput = toolResponse.response as KnowledgeSearchToolOutput + + return toolResponse.status !== 'done' ? ( + + {i18n.t('message.searching')} + {toolInput?.rewrite ?? toolInput?.query ?? ''} + + } + /> + ) : ( + + + {i18n.t('message.websearch.fetch_complete', { count: toolOutput.length ?? 0 })} + + ) +} + +export function MessageKnowledgeSearchToolBody({ toolResponse }: { toolResponse: MCPToolResponse }) { + const toolOutput = toolResponse.response as KnowledgeSearchToolOutput + + return toolResponse.status === 'done' ? ( + + {toolOutput.map((result) => ( +
  • + {result.id} + {result.content} +
  • + ))} +
    + ) : null +} + +const PrepareToolWrapper = styled.span` + display: flex; + align-items: center; + gap: 4px; + font-size: 14px; + padding-left: 0; +` +const MessageWebSearchToolTitleTextWrapper = styled(Text)` + display: flex; + align-items: center; + gap: 4px; +` + +const MessageWebSearchToolBodyUlWrapper = styled.ul` + display: flex; + flex-direction: column; + gap: 4px; + padding: 0; + > li { + padding: 0; + margin: 0; + max-width: 70%; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } +` diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx index 272e68a439..a0c10dfcf4 100644 --- a/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx +++ b/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx @@ -1,11 +1,14 @@ +import { MCPToolResponse } from '@renderer/types' import type { ToolMessageBlock } from '@renderer/types/newMessage' import { Collapse } from 'antd' +import { MessageKnowledgeSearchToolBody, MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch' import { MessageWebSearchToolBody, MessageWebSearchToolTitle } from './MessageWebSearchTool' interface Props { block: ToolMessageBlock } +const prefix = 'builtin_' // const toolNameMapText = { // web_search: i18n.t('message.searching') @@ -41,18 +44,62 @@ interface Props { // return

    {toolDoneNameText}

    // } +// const ToolLabelComponents = ({ toolResponse }: { toolResponse: MCPToolResponse }) => { +// if (webSearchToolNames.includes(toolResponse.tool.name)) { +// return +// } +// return +// } + +// const ToolBodyComponents = ({ toolResponse }: { toolResponse: MCPToolResponse }) => { +// if (webSearchToolNames.includes(toolResponse.tool.name)) { +// return +// } +// return +// } + +const ChooseTool = ( + toolResponse: MCPToolResponse +): { + label: React.ReactNode + body: React.ReactNode +} => { + let toolName = toolResponse.tool.name + if (toolName.startsWith(prefix)) { + toolName = toolName.slice(prefix.length) + } + + switch (toolName) { + case 'web_search': + case 'web_search_preview': + return { + label: , + body: + } + case 'knowledge_search': + return { + label: , + body: + } + default: + return { + label: , + body: + } + } +} + export default function MessageTool({ block }: Props) { const toolResponse = block.metadata?.rawMcpToolResponse if (!toolResponse) return null - console.log('toolResponse', toolResponse) return ( , - children: , + label: ChooseTool(toolResponse).label, + children: ChooseTool(toolResponse).body, showArrow: false, styles: { header: { diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageWebSearchTool.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageWebSearchTool.tsx index f361d8a927..6ccb3c4c2b 100644 --- a/src/renderer/src/pages/home/Messages/Tools/MessageWebSearchTool.tsx +++ b/src/renderer/src/pages/home/Messages/Tools/MessageWebSearchTool.tsx @@ -17,14 +17,16 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT text={ {i18n.t('message.searching')} - {toolInput?.query ?? ''} + {toolInput?.additionalContext ?? ''} } /> ) : ( - {i18n.t('message.websearch.fetch_complete', { count: toolOutput.results.length ?? 0 })} + {i18n.t('message.websearch.fetch_complete', { + count: toolOutput?.searchResults?.reduce((acc, result) => acc + result.results.length, 0) ?? 0 + })} ) } @@ -32,15 +34,17 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT export const MessageWebSearchToolBody = ({ toolResponse }: { toolResponse: MCPToolResponse }) => { const toolOutput = toolResponse.response as WebSearchToolOutput - return toolResponse.status === 'done' ? ( - - {toolOutput.results.map((result) => ( -
  • - {result.title} -
  • - ))} -
    - ) : null + return toolResponse.status === 'done' + ? toolOutput?.searchResults?.map((result, index) => ( + + {result.results.map((item, index) => ( +
  • + {item.title} +
  • + ))} +
    + )) + : null } const PrepareToolWrapper = styled.span` diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts index 141ea9b997..eed4bd57cf 100644 --- a/src/renderer/src/store/thunk/messageThunk.ts +++ b/src/renderer/src/store/thunk/messageThunk.ts @@ -881,7 +881,6 @@ const fetchAndProcessAssistantResponseImpl = async ( saveUpdatesToDB, assistant }) - console.log('callbacks', callbacks) const streamProcessorCallbacks = createStreamProcessor(callbacks) const abortController = new AbortController()