From ff378ca567384cbf4b4c6313d77dfcffd81a08af Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Thu, 9 Oct 2025 17:44:42 +0800 Subject: [PATCH] feat: enhance web search functionality with abort signal support - Updated WebSearchTool to accept an abort signal in the execute method. - Modified various WebSearchProvider classes to include httpOptions for search methods, allowing for abort signal handling. - Improved WebSearchService to prioritize external abort signals for better request management. - Enhanced MessageTool to reflect tool status with appropriate UI feedback. --- .../src/aiCore/tools/WebSearchTool.ts | 12 +++++-- .../pages/home/Messages/Tools/MessageTool.tsx | 22 +++++++++++-- .../WebSearchProvider/BochaProvider.ts | 9 ++++-- .../WebSearchProvider/DefaultProvider.ts | 3 +- .../WebSearchProvider/ExaProvider.ts | 23 ++++++++++++-- .../WebSearchProvider/SearxngProvider.ts | 8 +++-- .../WebSearchProvider/TavilyProvider.ts | 31 +++++++++++++++---- .../WebSearchProvider/ZhipuProvider.ts | 9 ++++-- src/renderer/src/services/WebSearchService.ts | 7 +++-- 9 files changed, 101 insertions(+), 23 deletions(-) diff --git a/src/renderer/src/aiCore/tools/WebSearchTool.ts b/src/renderer/src/aiCore/tools/WebSearchTool.ts index 2d6e318306..89859fc04e 100644 --- a/src/renderer/src/aiCore/tools/WebSearchTool.ts +++ b/src/renderer/src/aiCore/tools/WebSearchTool.ts @@ -40,7 +40,7 @@ You can use this tool as-is to search with the prepared queries, or provide addi .describe('Optional additional context, keywords, or specific focus to enhance the search') }), - execute: async ({ additionalContext }) => { + execute: async ({ additionalContext }, { abortSignal }) => { let finalQueries = [...extractedKeywords.question] if (additionalContext?.trim()) { @@ -67,7 +67,15 @@ You can use this tool as-is to search with the prepared queries, or provide addi links: extractedKeywords.links } } - searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId) + // abortSignal?.addEventListener('abort', () => { + // console.log('tool_call_abortSignal', abortSignal?.aborted) + // }) + searchResults = await WebSearchService.processWebsearch( + webSearchProvider!, + extractResults, + requestId, + abortSignal + ) return searchResults }, diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx index 88cf2dab56..b25e7fe01e 100644 --- a/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx +++ b/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx @@ -1,5 +1,8 @@ import { NormalToolResponse } from '@renderer/types' -import type { ToolMessageBlock } from '@renderer/types/newMessage' +import { MessageBlockStatus, ToolMessageBlock } from '@renderer/types/newMessage' +import { TFunction } from 'i18next' +import { Pause } from 'lucide-react' +import { useTranslation } from 'react-i18next' import { MessageAgentTools } from './MessageAgentTools' import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch' @@ -35,11 +38,23 @@ const isAgentTool = (toolName: string) => { return false } -const ChooseTool = (toolResponse: NormalToolResponse): React.ReactNode | null => { +const ChooseTool = ( + toolResponse: NormalToolResponse, + status: MessageBlockStatus, + t: TFunction +): React.ReactNode | null => { let toolName = toolResponse.tool.name const toolType = toolResponse.tool.type if (toolName.startsWith(prefix)) { toolName = toolName.slice(prefix.length) + if (status === MessageBlockStatus.PAUSED) { + return ( +
+ + {t('message.tools.aborted')} +
+ ) + } switch (toolName) { case 'web_search': case 'web_search_preview': @@ -58,12 +73,13 @@ const ChooseTool = (toolResponse: NormalToolResponse): React.ReactNode | null => } export default function MessageTool({ block }: Props) { + const { t } = useTranslation() // FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留 const toolResponse = block.metadata?.rawMcpToolResponse as NormalToolResponse if (!toolResponse) return null - const toolRenderer = ChooseTool(toolResponse as NormalToolResponse) + const toolRenderer = ChooseTool(toolResponse as NormalToolResponse, block.status, t) if (!toolRenderer) return null diff --git a/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts b/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts index 81d85df5d3..cf2be01e0c 100644 --- a/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts @@ -18,7 +18,11 @@ export default class BochaProvider extends BaseWebSearchProvider { } } - public async search(query: string, websearch: WebSearchState): Promise { + public async search( + query: string, + websearch: WebSearchState, + httpOptions?: RequestInit + ): Promise { try { if (!query.trim()) { throw new Error('Search query cannot be empty') @@ -44,7 +48,8 @@ export default class BochaProvider extends BaseWebSearchProvider { headers: { ...this.defaultHeaders(), ...headers - } + }, + signal: httpOptions?.signal }) if (!response.ok) { diff --git a/src/renderer/src/providers/WebSearchProvider/DefaultProvider.ts b/src/renderer/src/providers/WebSearchProvider/DefaultProvider.ts index 25d2d46a43..9b00f52ea9 100644 --- a/src/renderer/src/providers/WebSearchProvider/DefaultProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/DefaultProvider.ts @@ -1,9 +1,10 @@ +import { WebSearchState } from '@renderer/store/websearch' import { WebSearchProviderResponse } from '@renderer/types' import BaseWebSearchProvider from './BaseWebSearchProvider' export default class DefaultProvider extends BaseWebSearchProvider { - search(): Promise { + search(_query: string, _websearch: WebSearchState, _httpOptions?: RequestInit): Promise { throw new Error('Method not implemented.') } } diff --git a/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts b/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts index a01ba93543..7cabebe95f 100644 --- a/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts @@ -20,13 +20,18 @@ export default class ExaProvider extends BaseWebSearchProvider { this.exa = new ExaClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost }) } - public async search(query: string, websearch: WebSearchState): Promise { + public async search( + query: string, + websearch: WebSearchState, + httpOptions?: RequestInit + ): Promise { try { if (!query.trim()) { throw new Error('Search query cannot be empty') } - const response = await this.exa.search({ + // 使用 Promise.race 来支持 abort signal + const searchPromise = this.exa.search({ query, numResults: Math.max(1, websearch.maxResults), contents: { @@ -34,6 +39,20 @@ export default class ExaProvider extends BaseWebSearchProvider { } }) + let response: Awaited + if (httpOptions?.signal) { + response = await Promise.race([ + searchPromise, + new Promise((_, reject) => { + httpOptions.signal?.addEventListener('abort', () => { + reject(new DOMException('The operation was aborted.', 'AbortError')) + }) + }) + ]) + } else { + response = await searchPromise + } + return { query: response.autopromptString, results: response.results.slice(0, websearch.maxResults).map((result) => { diff --git a/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts b/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts index cf3a75f84b..15c500694e 100644 --- a/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts @@ -95,7 +95,11 @@ export default class SearxngProvider extends BaseWebSearchProvider { } } - public async search(query: string, websearch: WebSearchState): Promise { + public async search( + query: string, + websearch: WebSearchState, + httpOptions?: RequestInit + ): Promise { try { if (!query) { throw new Error('Search query cannot be empty') @@ -124,7 +128,7 @@ export default class SearxngProvider extends BaseWebSearchProvider { // Fetch content for each URL concurrently const fetchPromises = validItems.map(async (item) => { // Logger.log(`Fetching content for ${item.url}...`) - return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser) + return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser, httpOptions) }) // Wait for all fetches to complete diff --git a/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts b/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts index 3ca9b0676d..ac89523889 100644 --- a/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts @@ -20,23 +20,42 @@ export default class TavilyProvider extends BaseWebSearchProvider { this.tvly = new TavilyClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost }) } - public async search(query: string, websearch: WebSearchState): Promise { + public async search( + query: string, + websearch: WebSearchState, + httpOptions?: RequestInit + ): Promise { try { if (!query.trim()) { throw new Error('Search query cannot be empty') } - const result = await this.tvly.search({ + // 使用 Promise.race 来支持 abort signal + const searchPromise = this.tvly.search({ query, max_results: Math.max(1, websearch.maxResults) }) + + let result: Awaited + if (httpOptions?.signal) { + result = await Promise.race([ + searchPromise, + new Promise((_, reject) => { + httpOptions.signal?.addEventListener('abort', () => { + reject(new DOMException('The operation was aborted.', 'AbortError')) + }) + }) + ]) + } else { + result = await searchPromise + } return { query: result.query, - results: result.results.slice(0, websearch.maxResults).map((result) => { + results: result.results.slice(0, websearch.maxResults).map((item) => { return { - title: result.title || 'No title', - content: result.content || '', - url: result.url || '' + title: item.title || 'No title', + content: item.content || '', + url: item.url || '' } }) } diff --git a/src/renderer/src/providers/WebSearchProvider/ZhipuProvider.ts b/src/renderer/src/providers/WebSearchProvider/ZhipuProvider.ts index e7c95fb1ce..b98c0cc170 100644 --- a/src/renderer/src/providers/WebSearchProvider/ZhipuProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/ZhipuProvider.ts @@ -43,7 +43,11 @@ export default class ZhipuProvider extends BaseWebSearchProvider { } } - public async search(query: string, websearch: WebSearchState): Promise { + public async search( + query: string, + websearch: WebSearchState, + httpOptions?: RequestInit + ): Promise { try { if (!query.trim()) { throw new Error('Search query cannot be empty') @@ -62,7 +66,8 @@ export default class ZhipuProvider extends BaseWebSearchProvider { 'Content-Type': 'application/json', ...this.defaultHeaders() }, - body: JSON.stringify(requestBody) + body: JSON.stringify(requestBody), + signal: httpOptions?.signal }) if (!response.ok) { diff --git a/src/renderer/src/services/WebSearchService.ts b/src/renderer/src/services/WebSearchService.ts index d7d30c1dbc..ed30e618f0 100644 --- a/src/renderer/src/services/WebSearchService.ts +++ b/src/renderer/src/services/WebSearchService.ts @@ -430,7 +430,8 @@ class WebSearchService { public async processWebsearch( webSearchProvider: WebSearchProvider, extractResults: ExtractResults, - requestId: string + requestId: string, + externalSignal?: AbortSignal ): Promise { // 重置状态 await this.setWebSearchStatus(requestId, { phase: 'default' }) @@ -441,8 +442,8 @@ class WebSearchService { return { results: [] } } - // 使用请求特定的signal,如果没有则回退到全局signal - const signal = this.getRequestState(requestId).signal || this.signal + // 优先使用外部传入的signal,其次是请求特定的signal,最后回退到全局signal + const signal = externalSignal || this.getRequestState(requestId).signal || this.signal const span = webSearchProvider.topicId ? addSpan({