From ff7ad52ad5b3037e8b3ed01ef64e80dede16613d Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Fri, 8 Aug 2025 15:20:02 +0800 Subject: [PATCH] feat(tests): add unit tests for utility functions in utils.test.ts - Implemented tests for `createErrorChunk`, `capitalize`, and `isAsyncIterable` functions. - Ensured comprehensive coverage for various input scenarios, including error handling and edge cases. --- .../middleware/__tests__/utils.test.ts | 0 .../plugins/searchOrchestrationPlugin.ts | 22 ++++- .../src/aiCore/tools/KnowledgeSearchTool.ts | 82 ++++++++++++------- .../src/aiCore/tools/WebSearchTool.ts | 41 +--------- .../src/services/SpanManagerService.ts | 2 +- src/renderer/src/services/WebSearchService.ts | 61 +++++++------- 6 files changed, 105 insertions(+), 103 deletions(-) rename src/renderer/src/aiCore/{ => legacy}/middleware/__tests__/utils.test.ts (100%) diff --git a/src/renderer/src/aiCore/middleware/__tests__/utils.test.ts b/src/renderer/src/aiCore/legacy/middleware/__tests__/utils.test.ts similarity index 100% rename from src/renderer/src/aiCore/middleware/__tests__/utils.test.ts rename to src/renderer/src/aiCore/legacy/middleware/__tests__/utils.test.ts diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts index cbf1f0baeb..68f556f354 100644 --- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -332,9 +332,18 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => { if (hasKnowledgeBase) { if (knowledgeRecognition === 'off') { - // off 模式:直接添加知识库搜索工具,跳过意图识别 + // off 模式:直接添加知识库搜索工具,使用用户消息作为搜索关键词 + const userMessage = userMessages[context.requestId] + const fallbackKeywords = { + question: [getMessageContent(userMessage) || 'search'], + rewrite: getMessageContent(userMessage) || 'search' + } console.log('📚 [SearchOrchestration] Adding knowledge search tool (force mode)') - params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant) + params.tools['builtin_knowledge_search'] = knowledgeSearchTool( + assistant, + fallbackKeywords, + getMessageContent(userMessage) + ) params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' } } else { // on 模式:根据意图识别结果决定是否添加工具 @@ -343,9 +352,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => { analysisResult.knowledge.question && analysisResult.knowledge.question[0] !== 'not_needed' - if (needsKnowledgeSearch) { + if (needsKnowledgeSearch && analysisResult.knowledge) { console.log('📚 [SearchOrchestration] Adding knowledge search tool (intent-based)') - params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant) + const userMessage = userMessages[context.requestId] + params.tools['builtin_knowledge_search'] = knowledgeSearchTool( + assistant, + analysisResult.knowledge, + getMessageContent(userMessage) + ) } } } diff --git a/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts b/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts index 3626506707..7d4d85d377 100644 --- a/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts +++ b/src/renderer/src/aiCore/tools/KnowledgeSearchTool.ts @@ -1,30 +1,37 @@ import { processKnowledgeSearch } from '@renderer/services/KnowledgeService' import type { Assistant, KnowledgeReference } from '@renderer/types' -import { ExtractResults } from '@renderer/utils/extract' +import { ExtractResults, KnowledgeExtractResults } from '@renderer/utils/extract' import { type InferToolInput, type InferToolOutput, tool } from 'ai' import { isEmpty } from 'lodash' import { z } from 'zod' -// Schema definitions - 添加 userMessage 字段来获取用户消息 -const KnowledgeSearchInputSchema = z.object({ - query: z.string().describe('The search query for knowledge base'), - rewrite: z.string().optional().describe('Optional rewritten query with alternative phrasing'), - userMessage: z.string().describe('The original user message content for direct search mode') -}) - -export type KnowledgeSearchToolInput = InferToolInput> -export type KnowledgeSearchToolOutput = InferToolOutput> - /** * 知识库搜索工具 - * 基于 ApiService.ts 中的 searchKnowledgeBase 逻辑实现 + * 使用预提取关键词,直接使用插件阶段分析的搜索意图,避免重复分析 */ -export const knowledgeSearchTool = (assistant: Assistant) => { +export const knowledgeSearchTool = ( + assistant: Assistant, + extractedKeywords: KnowledgeExtractResults, + userMessage?: string +) => { return tool({ name: 'builtin_knowledge_search', - description: 'Search the knowledge base for relevant information', - inputSchema: KnowledgeSearchInputSchema, - execute: async ({ query, rewrite, userMessage }) => { + description: `Search the knowledge base for relevant information using pre-analyzed search intent. + +Pre-extracted search queries: "${extractedKeywords.question.join(', ')}" +Rewritten query: "${extractedKeywords.rewrite}" + +This tool searches your knowledge base for relevant documents and returns results for easy reference. +Call this tool to execute the search. You can optionally provide additional context to refine the search.`, + + inputSchema: z.object({ + additionalContext: z + .string() + .optional() + .describe('Optional additional context or specific focus to enhance the knowledge search') + }), + + execute: async ({ additionalContext }) => { try { // 获取助手的知识库配置 const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id) @@ -36,35 +43,51 @@ export const knowledgeSearchTool = (assistant: Assistant) => { return [] } - // 构建搜索条件 - 复制原逻辑 + let finalQueries = [...extractedKeywords.question] + let finalRewrite = extractedKeywords.rewrite + + if (additionalContext?.trim()) { + // 如果大模型提供了额外上下文,使用更具体的描述 + console.log(`🔍 AI enhanced knowledge search with: ${additionalContext}`) + const cleanContext = additionalContext.trim() + if (cleanContext) { + finalQueries = [cleanContext] + finalRewrite = cleanContext + console.log(`➕ Added additional context: ${cleanContext}`) + } + } + + // 检查是否需要搜索 + if (finalQueries[0] === 'not_needed') { + return [] + } + + // 构建搜索条件 let searchCriteria: { question: string[]; rewrite: string } if (knowledgeRecognition === 'off') { - // 直接模式:使用用户消息内容 (类似原逻辑的 getMainTextContent(lastUserMessage)) - const directContent = userMessage || query || 'search' + // 直接模式:使用用户消息内容 + const directContent = userMessage || finalQueries[0] || 'search' searchCriteria = { question: [directContent], rewrite: directContent } } else { - // 自动模式:使用意图识别的结果 (类似原逻辑的 extractResults.knowledge) + // 自动模式:使用意图识别的结果 searchCriteria = { - question: [query], - rewrite: rewrite || query + question: finalQueries, + rewrite: finalRewrite } } - // 检查是否需要搜索 - if (searchCriteria.question[0] === 'not_needed') { - return [] - } - - // 构建 ExtractResults 对象 - 与原逻辑一致 + // 构建 ExtractResults 对象 const extractResults: ExtractResults = { websearch: undefined, knowledge: searchCriteria } + console.log('Knowledge search extractResults:', extractResults) + // 执行知识库搜索 const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds) @@ -86,4 +109,7 @@ export const knowledgeSearchTool = (assistant: Assistant) => { }) } +export type KnowledgeSearchToolInput = InferToolInput> +export type KnowledgeSearchToolOutput = InferToolOutput> + export default knowledgeSearchTool diff --git a/src/renderer/src/aiCore/tools/WebSearchTool.ts b/src/renderer/src/aiCore/tools/WebSearchTool.ts index 8334089866..cbad5f2040 100644 --- a/src/renderer/src/aiCore/tools/WebSearchTool.ts +++ b/src/renderer/src/aiCore/tools/WebSearchTool.ts @@ -5,42 +5,6 @@ import { ExtractResults } from '@renderer/utils/extract' import { InferToolInput, InferToolOutput, tool } from 'ai' import { z } from 'zod' -// import { AiSdkTool, ToolCallResult } from './types' - -// const WebSearchResult = z.array( -// z.object({ -// query: z.string().optional(), -// results: z.array( -// z.object({ -// title: z.string(), -// content: z.string(), -// url: z.string() -// }) -// ) -// }) -// ) -// const webSearchToolInputSchema = z.object({ -// query: z.string().describe('The query to search for') -// }) - -// export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => { -// const webSearchService = WebSearchService.getInstance(webSearchProviderId) -// return tool({ -// name: 'builtin_web_search', -// description: 'Search the web for information', -// inputSchema: webSearchToolInputSchema, -// outputSchema: WebSearchProviderResult, -// execute: async ({ query }) => { -// console.log('webSearchTool', query) -// const response = await webSearchService.search(query) -// console.log('webSearchTool response', response) -// return response -// } -// }) -// } -// export type WebSearchToolInput = InferToolInput> -// export type WebSearchToolOutput = InferToolOutput> - /** * 使用预提取关键词的网络搜索工具 * 这个工具直接使用插件阶段分析的搜索意图,避免重复分析 @@ -53,7 +17,7 @@ export const webSearchToolWithPreExtractedKeywords = ( }, requestId: string ) => { - const webSearchService = WebSearchService.getInstance(webSearchProviderId) + const webSearchProvider = WebSearchService.getWebSearchProvider(webSearchProviderId) return tool({ name: 'builtin_web_search', @@ -112,7 +76,8 @@ Call this tool to execute the search. You can optionally provide additional cont } } console.log('extractResults', extractResults) - const response = await webSearchService.processWebsearch(extractResults, requestId) + console.log('webSearchProvider', webSearchProvider) + const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId) searchResults.push(response) } catch (error) { console.error(`Web search failed for query "${finalQueries}":`, error) diff --git a/src/renderer/src/services/SpanManagerService.ts b/src/renderer/src/services/SpanManagerService.ts index 13b056f7ea..6ab6951647 100644 --- a/src/renderer/src/services/SpanManagerService.ts +++ b/src/renderer/src/services/SpanManagerService.ts @@ -3,7 +3,7 @@ import { loggerService } from '@logger' import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core' import { cleanContext, endContext, getContext, startContext } from '@mcp-trace/trace-web' import { Context, context, Span, SpanStatusCode, trace } from '@opentelemetry/api' -import { isAsyncIterable } from '@renderer/aiCore/middleware/utils' +import { isAsyncIterable } from '@renderer/aiCore/legacy/middleware/utils' import { db } from '@renderer/databases' import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' diff --git a/src/renderer/src/services/WebSearchService.ts b/src/renderer/src/services/WebSearchService.ts index 4140bb0ae6..db7bb57055 100644 --- a/src/renderer/src/services/WebSearchService.ts +++ b/src/renderer/src/services/WebSearchService.ts @@ -40,21 +40,7 @@ interface RequestState { /** * 提供网络搜索相关功能的服务类 */ -export default class WebSearchService { - private static instance: WebSearchService - private webSearchProviderId: WebSearchProvider['id'] - - private constructor(webSearchProviderId: WebSearchProvider['id']) { - this.webSearchProviderId = webSearchProviderId - } - - public static getInstance(webSearchProviderId: WebSearchProvider['id']): WebSearchService { - if (!WebSearchService.instance) { - WebSearchService.instance = new WebSearchService(webSearchProviderId) - } - return WebSearchService.instance - } - +class WebSearchService { /** * 是否暂停 */ @@ -113,7 +99,7 @@ export default class WebSearchService { * @private * @returns 网络搜索状态 */ - private static getWebSearchState(): WebSearchState { + private getWebSearchState(): WebSearchState { return store.getState().websearch } @@ -122,8 +108,8 @@ export default class WebSearchService { * @public * @returns 如果默认搜索提供商已启用则返回true,否则返回false */ - public static isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean { - const { providers } = WebSearchService.getWebSearchState() + public isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean { + const { providers } = this.getWebSearchState() const provider = providers.find((provider) => provider.id === providerId) if (!provider) { @@ -153,7 +139,7 @@ export default class WebSearchService { * @returns 如果启用覆盖搜索则返回true,否则返回false */ public isOverwriteEnabled(): boolean { - const { overwrite } = WebSearchService.getWebSearchState() + const { overwrite } = this.getWebSearchState() return overwrite } @@ -163,7 +149,8 @@ export default class WebSearchService { * @returns 网络搜索提供商 */ public getWebSearchProvider(providerId?: string): WebSearchProvider | undefined { - const { providers } = WebSearchService.getWebSearchState() + const { providers } = this.getWebSearchState() + console.log('providers', providers) const provider = providers.find((provider) => provider.id === providerId) return provider @@ -172,16 +159,18 @@ export default class WebSearchService { /** * 使用指定的提供商执行网络搜索 * @public + * @param provider 搜索提供商 * @param query 搜索查询 * @returns 搜索响应 */ - public async search(query: string, httpOptions?: RequestInit, spanId?: string): Promise { - const websearch = WebSearchService.getWebSearchState() - const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId) - if (!webSearchProvider) { - throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`) - } - const webSearchEngine = new WebSearchEngineProvider(webSearchProvider, spanId) + public async search( + provider: WebSearchProvider, + query: string, + httpOptions?: RequestInit, + spanId?: string + ): Promise { + const websearch = this.getWebSearchState() + const webSearchEngine = new WebSearchEngineProvider(provider, spanId) let formattedQuery = query // FIXME: 有待商榷,效果一般 @@ -203,9 +192,9 @@ export default class WebSearchService { * @param provider 要检查的搜索提供商 * @returns 如果提供商可用返回true,否则返回false */ - public async checkSearch(): Promise<{ valid: boolean; error?: any }> { + public async checkSearch(provider: WebSearchProvider): Promise<{ valid: boolean; error?: any }> { try { - const response = await this.search('test query') + const response = await this.search(provider, 'test query') logger.debug('Search response:', response) // 优化的判断条件:检查结果是否有效且没有错误 return { valid: response.results !== undefined, error: undefined } @@ -437,7 +426,11 @@ export default class WebSearchService { * * @returns 包含搜索结果的响应对象 */ - public async processWebsearch(extractResults: ExtractResults, requestId: string): Promise { + public async processWebsearch( + webSearchProvider: WebSearchProvider, + extractResults: ExtractResults, + requestId: string + ): Promise { // 重置状态 await this.setWebSearchStatus(requestId, { phase: 'default' }) @@ -479,7 +472,9 @@ export default class WebSearchService { return { query: 'summaries', results: contents } } - const searchPromises = questions.map((q) => this.search(q, { signal }, span?.spanContext().spanId)) + const searchPromises = questions.map((q) => + this.search(webSearchProvider, q, { signal }, span?.spanContext().spanId) + ) const searchResults = await Promise.allSettled(searchPromises) // 统计成功完成的搜索数量 @@ -524,7 +519,7 @@ export default class WebSearchService { } } - const { compressionConfig } = WebSearchService.getWebSearchState() + const { compressionConfig } = this.getWebSearchState() // RAG压缩处理 if (compressionConfig?.method === 'rag' && requestId) { @@ -578,3 +573,5 @@ export default class WebSearchService { } } } + +export default new WebSearchService()