diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts index 7e662ddee6..4c2d0b3449 100644 --- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts +++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts @@ -24,8 +24,10 @@ import { generateText } from 'ai' import { isEmpty } from 'lodash' import { MemoryProcessor } from '../../services/MemoryProcessor' +import { exaSearchTool } from '../tools/ExaSearchTool' import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool' import { memorySearchTool } from '../tools/MemorySearchTool' +import { tavilySearchTool } from '../tools/TavilySearchTool' import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool' const logger = loggerService.withContext('SearchOrchestrationPlugin') @@ -316,13 +318,28 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string) const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed' if (needsSearch) { - // onChunk({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS }) - // logger.info('🌐 Adding web search tool with pre-extracted keywords') - params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords( - assistant.webSearchProviderId, - analysisResult.websearch, - context.requestId - ) + // 根据 Provider ID 动态选择工具 + switch (assistant.webSearchProviderId) { + case 'exa': + logger.info('🌐 Adding Exa search tool (provider-specific)') + // Exa 工具直接接受单个查询字符串,使用第一个问题或合并所有问题 + params.tools['builtin_exa_search'] = exaSearchTool(context.requestId) + break + case 'tavily': + logger.info('🌐 Adding Tavily search tool (provider-specific)') + // Tavily 工具直接接受单个查询字符串 + params.tools['builtin_tavily_search'] = tavilySearchTool(context.requestId) + break + default: + logger.info('🌐 Adding web search tool with pre-extracted keywords') + // 其他 Provider 使用通用的 WebSearchTool + params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords( + assistant.webSearchProviderId, + analysisResult.websearch, + context.requestId + ) + break + } } } diff --git a/src/renderer/src/aiCore/tools/ExaSearchTool.ts b/src/renderer/src/aiCore/tools/ExaSearchTool.ts new file mode 100644 index 0000000000..8a710899fd --- /dev/null +++ b/src/renderer/src/aiCore/tools/ExaSearchTool.ts @@ -0,0 +1,166 @@ +import { loggerService } from '@logger' +import { REFERENCE_PROMPT } from '@renderer/config/prompts' +import WebSearchService from '@renderer/services/WebSearchService' +import { ProviderSpecificParams, WebSearchProviderResponse } from '@renderer/types' +import { ExtractResults } from '@renderer/utils/extract' +import { type InferToolInput, type InferToolOutput, tool } from 'ai' +import { z } from 'zod' + +const logger = loggerService.withContext('ExaSearchTool') + +/** + * Exa 专用搜索工具 - 暴露 Exa 的高级搜索能力给 LLM + * 支持 Neural Search、Category Filtering、Date Range 等功能 + */ +export const exaSearchTool = (requestId: string) => { + const webSearchProvider = WebSearchService.getWebSearchProvider('exa') + + if (!webSearchProvider) { + throw new Error('Exa provider not found or not configured') + } + + return tool({ + name: 'builtin_exa_search', + description: `Advanced AI-powered search using Exa.ai with neural understanding and filtering capabilities. + +Key Features: +- Neural Search: AI-powered semantic search that understands intent +- Search Type: Choose between neural (AI), keyword (traditional), or auto mode +- Category Filter: Focus on specific content types (company, research paper, news, etc.) +- Date Range: Filter by publication date +- Auto-prompt: Let Exa optimize your query automatically + +Best for: Research, finding specific types of content, semantic search, and understanding complex queries.`, + + inputSchema: z.object({ + query: z.string().describe('The search query - be specific and clear'), + numResults: z.number().min(1).max(20).optional().describe('Number of results to return (1-20, default: 5)'), + type: z + .enum(['neural', 'keyword', 'auto', 'fast']) + .optional() + .describe( + 'Search type: neural (embeddings-based), keyword (Google-like SERP), auto (default, intelligently combines both), or fast (streamlined versions)' + ), + category: z + .string() + .optional() + .describe( + 'Filter by content category: company, research paper, news, github, tweet, movie, song, personal site, pdf, etc.' + ), + startPublishedDate: z + .string() + .optional() + .describe('Start date filter based on published date in ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SSZ)'), + endPublishedDate: z + .string() + .optional() + .describe('End date filter based on published date in ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SSZ)'), + startCrawlDate: z + .string() + .optional() + .describe('Start date filter based on crawl date in ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SSZ)'), + endCrawlDate: z + .string() + .optional() + .describe('End date filter based on crawl date in ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SSZ)'), + useAutoprompt: z.boolean().optional().describe('Let Exa optimize your query automatically (recommended: true)') + }), + + execute: async (params, { abortSignal }) => { + // 构建 provider 特定参数(排除 query 和 numResults,这些由系统控制) + const providerParams: ProviderSpecificParams = { + exa: { + type: params.type, + category: params.category, + startPublishedDate: params.startPublishedDate, + endPublishedDate: params.endPublishedDate, + startCrawlDate: params.startCrawlDate, + endCrawlDate: params.endCrawlDate, + useAutoprompt: params.useAutoprompt + } + } + // 构建 ExtractResults 结构 + const extractResults: ExtractResults = { + websearch: { + question: [params.query] + } + } + + // 统一调用 processWebsearch - 保留所有中间件(时间戳、黑名单、tracing、压缩) + const finalResults: WebSearchProviderResponse = await WebSearchService.processWebsearch( + webSearchProvider, + extractResults, + requestId, + abortSignal, + providerParams + ) + + logger.info(`Exa search completed: ${finalResults.results.length} results for "${params.query}"`) + + return finalResults + }, + + toModelOutput: (results) => { + let summary = 'No search results found.' + if (results.query && results.results.length > 0) { + summary = `Found ${results.results.length} relevant sources using Exa AI search. Use [number] format to cite specific information.` + } + + const citationData = results.results.map((result, index) => { + const citation: any = { + number: index + 1, + title: result.title, + content: result.content, + url: result.url + } + + // 添加 Exa 特有的元数据 + if ('favicon' in result && result.favicon) { + citation.favicon = result.favicon + } + if ('author' in result && result.author) { + citation.author = result.author + } + if ('publishedDate' in result && result.publishedDate) { + citation.publishedDate = result.publishedDate + } + if ('score' in result && result.score !== undefined) { + citation.score = result.score + } + if ('highlights' in result && result.highlights) { + citation.highlights = result.highlights + } + + return citation + }) + + // 使用 REFERENCE_PROMPT 格式化引用 + const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\`` + const fullInstructions = REFERENCE_PROMPT.replace( + '{question}', + "Based on the Exa search results, please answer the user's question with proper citations." + ).replace('{references}', referenceContent) + + return { + type: 'content', + value: [ + { + type: 'text', + text: 'Exa AI Search: Neural search with semantic understanding and rich metadata (author, publish date, highlights).' + }, + { + type: 'text', + text: summary + }, + { + type: 'text', + text: fullInstructions + } + ] + } + } + }) +} + +export type ExaSearchToolOutput = InferToolOutput> +export type ExaSearchToolInput = InferToolInput> diff --git a/src/renderer/src/aiCore/tools/TavilySearchTool.ts b/src/renderer/src/aiCore/tools/TavilySearchTool.ts new file mode 100644 index 0000000000..41e6bcf3b6 --- /dev/null +++ b/src/renderer/src/aiCore/tools/TavilySearchTool.ts @@ -0,0 +1,161 @@ +import { loggerService } from '@logger' +import { REFERENCE_PROMPT } from '@renderer/config/prompts' +import WebSearchService from '@renderer/services/WebSearchService' +import { ProviderSpecificParams, WebSearchProviderResponse } from '@renderer/types' +import { ExtractResults } from '@renderer/utils/extract' +import { type InferToolInput, type InferToolOutput, tool } from 'ai' +import { z } from 'zod' + +const logger = loggerService.withContext('TavilySearchTool') + +/** + * Tavily 专用搜索工具 - 暴露 Tavily 的高级搜索能力给 LLM + * 支持 AI-powered answers、Search depth control、Topic filtering 等功能 + */ +export const tavilySearchTool = (requestId: string) => { + const webSearchProvider = WebSearchService.getWebSearchProvider('tavily') + + if (!webSearchProvider) { + throw new Error('Tavily provider not found or not configured') + } + + return tool({ + name: 'builtin_tavily_search', + description: `AI-powered search using Tavily with direct answers and comprehensive content extraction. + +Key Features: +- Direct AI Answer: Get a concise, factual answer extracted from search results +- Search Depth: Choose between basic (fast) or advanced (comprehensive) search +- Topic Focus: Filter by general, news, or finance topics +- Full Content: Access complete webpage content, not just snippets +- Rich Media: Optionally include relevant images from search results + +Best for: Quick factual answers, news monitoring, financial research, and comprehensive content analysis.`, + + inputSchema: z.object({ + query: z.string().describe('The search query - be specific and clear'), + maxResults: z + .number() + .min(1) + .max(20) + .optional() + .describe('Maximum number of results to return (1-20, default: 5)'), + topic: z + .enum(['general', 'news', 'finance']) + .optional() + .describe('Topic filter: general (default), news (latest news), or finance (financial/market data)'), + searchDepth: z + .enum(['basic', 'advanced']) + .optional() + .describe('Search depth: basic (faster, top results) or advanced (slower, more comprehensive)'), + includeAnswer: z + .boolean() + .optional() + .describe('Include AI-generated direct answer extracted from results (default: true)'), + includeRawContent: z + .boolean() + .optional() + .describe('Include full webpage content instead of just snippets (default: true)'), + includeImages: z.boolean().optional().describe('Include relevant images from search results (default: false)') + }), + + execute: async (params, { abortSignal }) => { + try { + // 构建 provider 特定参数 + const providerParams: ProviderSpecificParams = { + tavily: { + topic: params.topic, + searchDepth: params.searchDepth, + includeAnswer: params.includeAnswer, + includeRawContent: params.includeRawContent, + includeImages: params.includeImages + } + } + + // 构建 ExtractResults 结构 + const extractResults: ExtractResults = { + websearch: { + question: [params.query] + } + } + + // 统一调用 processWebsearch - 保留所有中间件(时间戳、黑名单、tracing、压缩) + const finalResults: WebSearchProviderResponse = await WebSearchService.processWebsearch( + webSearchProvider, + extractResults, + requestId, + abortSignal, + providerParams + ) + + logger.info(`Tavily search completed: ${finalResults.results.length} results for "${params.query}"`) + + return finalResults + } catch (error) { + if (error instanceof DOMException && error.name === 'AbortError') { + logger.info('Tavily search aborted') + throw error + } + logger.error('Tavily search failed:', error as Error) + throw new Error(`Tavily search failed: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }, + + toModelOutput: (results) => { + let summary = 'No search results found.' + if (results.query && results.results.length > 0) { + summary = `Found ${results.results.length} relevant sources using Tavily AI search. Use [number] format to cite specific information.` + } + + const citationData = results.results.map((result, index) => { + const citation: any = { + number: index + 1, + title: result.title, + content: result.content, + url: result.url + } + + // 添加 Tavily 特有的元数据 + if ('answer' in result && result.answer) { + citation.answer = result.answer // Tavily 的直接答案 + } + if ('images' in result && result.images && result.images.length > 0) { + citation.images = result.images // Tavily 的图片 + } + if ('score' in result && result.score !== undefined) { + citation.score = result.score + } + + return citation + }) + + // 使用 REFERENCE_PROMPT 格式化引用 + const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\`` + const fullInstructions = REFERENCE_PROMPT.replace( + '{question}', + "Based on the Tavily search results, please answer the user's question with proper citations." + ).replace('{references}', referenceContent) + + return { + type: 'content', + value: [ + { + type: 'text', + text: 'Tavily AI Search: AI-powered with direct answers, full content extraction, and optional image results.' + }, + { + type: 'text', + text: summary + }, + { + type: 'text', + text: fullInstructions + } + ] + } + } + }) +} + +export type TavilySearchToolOutput = InferToolOutput> +export type TavilySearchToolInput = InferToolInput> diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx index b25e7fe01e..54c0797482 100644 --- a/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx +++ b/src/renderer/src/pages/home/Messages/Tools/MessageTool.tsx @@ -58,6 +58,8 @@ const ChooseTool = ( switch (toolName) { case 'web_search': case 'web_search_preview': + case 'exa_search': + case 'tavily_search': return toolType === 'provider' ? null : case 'knowledge_search': return diff --git a/src/renderer/src/pages/home/Messages/Tools/MessageWebSearch.tsx b/src/renderer/src/pages/home/Messages/Tools/MessageWebSearch.tsx index 5fe71bbae8..19b8737b61 100644 --- a/src/renderer/src/pages/home/Messages/Tools/MessageWebSearch.tsx +++ b/src/renderer/src/pages/home/Messages/Tools/MessageWebSearch.tsx @@ -1,3 +1,5 @@ +import { ExaSearchToolInput, ExaSearchToolOutput } from '@renderer/aiCore/tools/ExaSearchTool' +import { TavilySearchToolInput, TavilySearchToolOutput } from '@renderer/aiCore/tools/TavilySearchTool' import { WebSearchToolInput, WebSearchToolOutput } from '@renderer/aiCore/tools/WebSearchTool' import Spinner from '@renderer/components/Spinner' import { NormalToolResponse } from '@renderer/types' @@ -8,17 +10,31 @@ import styled from 'styled-components' const { Text } = Typography +// 联合类型 - 支持多种搜索工具 +type SearchToolInput = WebSearchToolInput | ExaSearchToolInput | TavilySearchToolInput +type SearchToolOutput = WebSearchToolOutput | ExaSearchToolOutput | TavilySearchToolOutput + export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: NormalToolResponse }) => { const { t } = useTranslation() - const toolInput = toolResponse.arguments as WebSearchToolInput - const toolOutput = toolResponse.response as WebSearchToolOutput + const toolInput = toolResponse.arguments as SearchToolInput + const toolOutput = toolResponse.response as SearchToolOutput + // 根据不同的工具类型获取查询内容 + const getQueryText = () => { + if ('additionalContext' in toolInput) { + return toolInput.additionalContext ?? '' + } + if ('query' in toolInput) { + return toolInput.query ?? '' + } + return '' + } return toolResponse.status !== 'done' ? ( {t('message.searching')} - {toolInput?.additionalContext ?? ''} + {getQueryText()} } /> diff --git a/src/renderer/src/providers/WebSearchProvider/BaseWebSearchProvider.ts b/src/renderer/src/providers/WebSearchProvider/BaseWebSearchProvider.ts index 558a328b5e..110e2f7c9e 100644 --- a/src/renderer/src/providers/WebSearchProvider/BaseWebSearchProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/BaseWebSearchProvider.ts @@ -1,5 +1,5 @@ import { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' +import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' export default abstract class BaseWebSearchProvider { // @ts-ignore this @@ -16,7 +16,8 @@ export default abstract class BaseWebSearchProvider { abstract search( query: string, websearch: WebSearchState, - httpOptions?: RequestInit + httpOptions?: RequestInit, + providerParams?: ProviderSpecificParams ): Promise public getApiHost() { diff --git a/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts b/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts index cf2be01e0c..aad91ff432 100644 --- a/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts @@ -1,6 +1,6 @@ import { loggerService } from '@logger' import { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' +import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' import { BochaSearchParams, BochaSearchResponse } from '@renderer/utils/bocha' import BaseWebSearchProvider from './BaseWebSearchProvider' @@ -21,7 +21,8 @@ export default class BochaProvider extends BaseWebSearchProvider { public async search( query: string, websearch: WebSearchState, - httpOptions?: RequestInit + httpOptions?: RequestInit, + _providerParams?: ProviderSpecificParams ): Promise { try { if (!query.trim()) { diff --git a/src/renderer/src/providers/WebSearchProvider/DefaultProvider.ts b/src/renderer/src/providers/WebSearchProvider/DefaultProvider.ts index 9b00f52ea9..0416d495b7 100644 --- a/src/renderer/src/providers/WebSearchProvider/DefaultProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/DefaultProvider.ts @@ -1,10 +1,15 @@ import { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProviderResponse } from '@renderer/types' +import { ProviderSpecificParams, WebSearchProviderResponse } from '@renderer/types' import BaseWebSearchProvider from './BaseWebSearchProvider' export default class DefaultProvider extends BaseWebSearchProvider { - search(_query: string, _websearch: WebSearchState, _httpOptions?: RequestInit): Promise { + search( + _query: string, + _websearch: WebSearchState, + _httpOptions?: RequestInit, + _providerParams?: ProviderSpecificParams + ): Promise { throw new Error('Method not implemented.') } } diff --git a/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts b/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts index 7cabebe95f..947a3407c6 100644 --- a/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts @@ -1,14 +1,53 @@ -import { ExaClient } from '@agentic/exa' import { loggerService } from '@logger' import { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' +import { + ExaSearchResult as ExaSearchResultType, + ProviderSpecificParams, + WebSearchProvider, + WebSearchProviderResponse +} from '@renderer/types' import BaseWebSearchProvider from './BaseWebSearchProvider' const logger = loggerService.withContext('ExaProvider') -export default class ExaProvider extends BaseWebSearchProvider { - private exa: ExaClient +interface ExaSearchRequest { + query: string + numResults: number + contents?: { + text?: boolean + highlights?: boolean + summary?: boolean + } + useAutoprompt?: boolean + category?: string + type?: 'keyword' | 'neural' | 'auto' | 'fast' + startPublishedDate?: string + endPublishedDate?: string + startCrawlDate?: string + endCrawlDate?: string + includeDomains?: string[] + excludeDomains?: string[] +} + +interface ExaSearchResult { + title: string | null + url: string | null + text?: string | null + author?: string | null + score?: number + publishedDate?: string | null + favicon?: string | null + highlights?: string[] +} + +interface ExaSearchResponse { + autopromptString?: string + results: ExaSearchResult[] + resolvedSearchType?: string +} + +export default class ExaProvider extends BaseWebSearchProvider { constructor(provider: WebSearchProvider) { super(provider) if (!this.apiKey) { @@ -17,53 +56,138 @@ export default class ExaProvider extends BaseWebSearchProvider { if (!this.apiHost) { throw new Error('API host is required for Exa provider') } - this.exa = new ExaClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost }) } + /** + * 统一的搜索方法 - 根据 providerParams 决定是否使用高级参数 + */ public async search( query: string, websearch: WebSearchState, - httpOptions?: RequestInit + httpOptions?: RequestInit, + providerParams?: ProviderSpecificParams ): Promise { + // 如果提供了 Exa 特定参数,使用高级搜索 + if (providerParams?.exa) { + return this.searchWithParams({ + query, + numResults: websearch.maxResults, + ...providerParams.exa, // 展开高级参数 + signal: httpOptions?.signal ?? undefined + }) + } + + // 否则使用默认参数 + return this.searchWithParams({ + query, + numResults: websearch.maxResults, + useAutoprompt: true, + signal: httpOptions?.signal ?? undefined + }) + } + + /** + * 使用完整参数进行搜索(支持 Exa 的所有高级功能) + */ + public async searchWithParams(params: { + query: string + numResults?: number + type?: 'keyword' | 'neural' | 'auto' | 'fast' + category?: string + startPublishedDate?: string + endPublishedDate?: string + startCrawlDate?: string + endCrawlDate?: string + useAutoprompt?: boolean + includeDomains?: string[] + excludeDomains?: string[] + signal?: AbortSignal + }): Promise { try { - if (!query.trim()) { + if (!params.query.trim()) { throw new Error('Search query cannot be empty') } - // 使用 Promise.race 来支持 abort signal - const searchPromise = this.exa.search({ - query, - numResults: Math.max(1, websearch.maxResults), + const requestBody: ExaSearchRequest = { + query: params.query, + numResults: Math.max(1, params.numResults || 5), contents: { - text: true - } + text: true, + highlights: true // 获取高亮片段 + }, + useAutoprompt: params.useAutoprompt ?? true + } + + // 添加可选参数 + if (params.type) { + requestBody.type = params.type + } + + if (params.category) { + requestBody.category = params.category + } + + if (params.startPublishedDate) { + requestBody.startPublishedDate = params.startPublishedDate + } + + if (params.endPublishedDate) { + requestBody.endPublishedDate = params.endPublishedDate + } + + if (params.startCrawlDate) { + requestBody.startCrawlDate = params.startCrawlDate + } + + if (params.endCrawlDate) { + requestBody.endCrawlDate = params.endCrawlDate + } + + if (params.includeDomains && params.includeDomains.length > 0) { + requestBody.includeDomains = params.includeDomains + } + + if (params.excludeDomains && params.excludeDomains.length > 0) { + requestBody.excludeDomains = params.excludeDomains + } + + const response = await fetch(`${this.apiHost}/search`, { + method: 'POST', + headers: { + 'x-api-key': this.apiKey!, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(requestBody), + signal: params.signal }) - let response: Awaited - if (httpOptions?.signal) { - response = await Promise.race([ - searchPromise, - new Promise((_, reject) => { - httpOptions.signal?.addEventListener('abort', () => { - reject(new DOMException('The operation was aborted.', 'AbortError')) - }) - }) - ]) - } else { - response = await searchPromise + if (!response.ok) { + const errorText = await response.text() + throw new Error(`Exa API error (${response.status}): ${errorText}`) } + const data: ExaSearchResponse = await response.json() + + // 返回完整的 Exa 结果(包含 favicon、author、score 等字段) return { - query: response.autopromptString, - results: response.results.slice(0, websearch.maxResults).map((result) => { - return { + query: data.autopromptString || params.query, + results: data.results.slice(0, params.numResults || 5).map( + (result): ExaSearchResultType => ({ title: result.title || 'No title', content: result.text || '', - url: result.url || '' - } - }) + url: result.url || '', + favicon: result.favicon || undefined, + publishedDate: result.publishedDate || undefined, + author: result.author || undefined, + score: result.score, + highlights: result.highlights + }) + ) } } catch (error) { + if (error instanceof DOMException && error.name === 'AbortError') { + throw error + } logger.error('Exa search failed:', error as Error) throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`) } diff --git a/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts b/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts index 2a17dc3a26..8a95fcdc36 100644 --- a/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts @@ -2,7 +2,12 @@ import { loggerService } from '@logger' import { nanoid } from '@reduxjs/toolkit' import store from '@renderer/store' import { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProvider, WebSearchProviderResponse, WebSearchProviderResult } from '@renderer/types' +import { + ProviderSpecificParams, + WebSearchProvider, + WebSearchProviderResponse, + WebSearchProviderResult +} from '@renderer/types' import { createAbortPromise } from '@renderer/utils/abortController' import { isAbortError } from '@renderer/utils/error' import { fetchWebContent, noContent } from '@renderer/utils/fetch' @@ -27,7 +32,8 @@ export default class LocalSearchProvider extends BaseWebSearchProvider { public async search( query: string, websearch: WebSearchState, - httpOptions?: RequestInit + httpOptions?: RequestInit, + _providerParams?: ProviderSpecificParams ): Promise { const uid = nanoid() const language = store.getState().settings.language diff --git a/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts b/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts index 15c500694e..93324a23e1 100644 --- a/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts @@ -1,7 +1,7 @@ import { SearxngClient } from '@agentic/searxng' import { loggerService } from '@logger' import { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' +import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' import { fetchWebContent, noContent } from '@renderer/utils/fetch' import axios from 'axios' import ky from 'ky' @@ -98,7 +98,8 @@ export default class SearxngProvider extends BaseWebSearchProvider { public async search( query: string, websearch: WebSearchState, - httpOptions?: RequestInit + httpOptions?: RequestInit, + _providerParams?: ProviderSpecificParams ): Promise { try { if (!query) { diff --git a/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts b/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts index ac89523889..84d570f63e 100644 --- a/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts @@ -1,14 +1,45 @@ -import { TavilyClient } from '@agentic/tavily' import { loggerService } from '@logger' import { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' +import { + ProviderSpecificParams, + TavilySearchResult as TavilySearchResultType, + WebSearchProvider, + WebSearchProviderResponse +} from '@renderer/types' import BaseWebSearchProvider from './BaseWebSearchProvider' const logger = loggerService.withContext('TavilyProvider') -export default class TavilyProvider extends BaseWebSearchProvider { - private tvly: TavilyClient +interface TavilySearchRequest { + query: string + max_results?: number + topic?: 'general' | 'news' | 'finance' + search_depth?: 'basic' | 'advanced' + include_answer?: boolean + include_raw_content?: boolean + include_images?: boolean + include_domains?: string[] + exclude_domains?: string[] +} + +interface TavilySearchResult { + title: string + url: string + content: string + raw_content?: string + score?: number +} + +interface TavilySearchResponse { + query: string + results: TavilySearchResult[] + answer?: string + images?: string[] + response_time?: number +} + +export default class TavilyProvider extends BaseWebSearchProvider { constructor(provider: WebSearchProvider) { super(provider) if (!this.apiKey) { @@ -17,49 +48,119 @@ export default class TavilyProvider extends BaseWebSearchProvider { if (!this.apiHost) { throw new Error('API host is required for Tavily provider') } - this.tvly = new TavilyClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost }) } + /** + * 统一的搜索方法 - 根据 providerParams 决定是否使用高级参数 + */ public async search( query: string, websearch: WebSearchState, - httpOptions?: RequestInit + httpOptions?: RequestInit, + providerParams?: ProviderSpecificParams ): Promise { + // 如果提供了 Tavily 特定参数,使用高级搜索 + if (providerParams?.tavily) { + return this.searchWithParams({ + query, + maxResults: websearch.maxResults, + ...providerParams.tavily, // 展开高级参数 + signal: httpOptions?.signal ?? undefined + }) + } + + // 否则使用默认参数 + return this.searchWithParams({ + query, + maxResults: websearch.maxResults, + includeRawContent: true, + signal: httpOptions?.signal ?? undefined + }) + } + + /** + * 使用完整参数进行搜索(支持 Tavily 的所有高级功能) + */ + public async searchWithParams(params: { + query: string + maxResults?: number + topic?: 'general' | 'news' | 'finance' + searchDepth?: 'basic' | 'advanced' + includeAnswer?: boolean + includeRawContent?: boolean + includeImages?: boolean + includeDomains?: string[] + excludeDomains?: string[] + signal?: AbortSignal + }): Promise { try { - if (!query.trim()) { + if (!params.query.trim()) { throw new Error('Search query cannot be empty') } - // 使用 Promise.race 来支持 abort signal - const searchPromise = this.tvly.search({ - query, - max_results: Math.max(1, websearch.maxResults) + const requestBody: TavilySearchRequest = { + query: params.query, + max_results: Math.max(1, params.maxResults || 5), + include_raw_content: params.includeRawContent ?? true, + include_answer: params.includeAnswer ?? true, + include_images: params.includeImages ?? false + } + + // 添加可选参数 + if (params.topic) { + requestBody.topic = params.topic + } + + if (params.searchDepth) { + requestBody.search_depth = params.searchDepth + } + + if (params.includeDomains && params.includeDomains.length > 0) { + requestBody.include_domains = params.includeDomains + } + + if (params.excludeDomains && params.excludeDomains.length > 0) { + requestBody.exclude_domains = params.excludeDomains + } + + const response = await fetch(`${this.apiHost}/search`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + ...requestBody, + api_key: this.apiKey + }), + signal: params.signal }) - let result: Awaited - if (httpOptions?.signal) { - result = await Promise.race([ - searchPromise, - new Promise((_, reject) => { - httpOptions.signal?.addEventListener('abort', () => { - reject(new DOMException('The operation was aborted.', 'AbortError')) - }) - }) - ]) - } else { - result = await searchPromise + if (!response.ok) { + const errorText = await response.text() + throw new Error(`Tavily API error (${response.status}): ${errorText}`) } + + const data: TavilySearchResponse = await response.json() + + // 返回完整的 Tavily 结果(包含 answer、images 等字段) return { - query: result.query, - results: result.results.slice(0, websearch.maxResults).map((item) => { - return { + query: data.query, + results: data.results.slice(0, params.maxResults || 5).map( + (item): TavilySearchResultType => ({ title: item.title || 'No title', - content: item.content || '', - url: item.url || '' - } - }) + content: item.raw_content || item.content || '', + url: item.url || '', + rawContent: item.raw_content, + score: item.score, + answer: data.answer, // Tavily 的直接答案 + images: data.images // Tavily 的图片 + }) + ) } } catch (error) { + if (error instanceof DOMException && error.name === 'AbortError') { + throw error + } logger.error('Tavily search failed:', error as Error) throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`) } diff --git a/src/renderer/src/providers/WebSearchProvider/ZhipuProvider.ts b/src/renderer/src/providers/WebSearchProvider/ZhipuProvider.ts index b98c0cc170..905cf59cb6 100644 --- a/src/renderer/src/providers/WebSearchProvider/ZhipuProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/ZhipuProvider.ts @@ -1,6 +1,6 @@ import { loggerService } from '@logger' import { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' +import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' import BaseWebSearchProvider from './BaseWebSearchProvider' @@ -46,7 +46,8 @@ export default class ZhipuProvider extends BaseWebSearchProvider { public async search( query: string, websearch: WebSearchState, - httpOptions?: RequestInit + httpOptions?: RequestInit, + _providerParams?: ProviderSpecificParams ): Promise { try { if (!query.trim()) { diff --git a/src/renderer/src/providers/WebSearchProvider/index.ts b/src/renderer/src/providers/WebSearchProvider/index.ts index e1fe8f185c..976fbd77b4 100644 --- a/src/renderer/src/providers/WebSearchProvider/index.ts +++ b/src/renderer/src/providers/WebSearchProvider/index.ts @@ -1,6 +1,6 @@ import { withSpanResult } from '@renderer/services/SpanManagerService' import type { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' +import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' import { filterResultWithBlacklist } from '@renderer/utils/blacklistMatchPattern' import BaseWebSearchProvider from './BaseWebSearchProvider' @@ -24,10 +24,11 @@ export default class WebSearchEngineProvider { public async search( query: string, websearch: WebSearchState, - httpOptions?: RequestInit + httpOptions?: RequestInit, + providerParams?: ProviderSpecificParams ): Promise { - const callSearch = async ({ query, websearch }) => { - return await this.sdk.search(query, websearch, httpOptions) + const callSearch = async ({ query, websearch, providerParams }) => { + return await this.sdk.search(query, websearch, httpOptions, providerParams) } const traceParams = { @@ -38,7 +39,7 @@ export default class WebSearchEngineProvider { modelName: this.modelName } - const result = await withSpanResult(callSearch, traceParams, { query, websearch }) + const result = await withSpanResult(callSearch, traceParams, { query, websearch, providerParams }) return await filterResultWithBlacklist(result, websearch) } diff --git a/src/renderer/src/services/WebSearchService.ts b/src/renderer/src/services/WebSearchService.ts index ed30e618f0..7261c131d6 100644 --- a/src/renderer/src/services/WebSearchService.ts +++ b/src/renderer/src/services/WebSearchService.ts @@ -10,6 +10,7 @@ import { KnowledgeBase, KnowledgeItem, KnowledgeReference, + ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse, WebSearchProviderResult, @@ -161,13 +162,17 @@ class WebSearchService { * @public * @param provider 搜索提供商 * @param query 搜索查询 + * @param httpOptions HTTP选项(包含signal等) + * @param spanId Span ID用于追踪 + * @param providerParams Provider特定参数(如Exa的category、Tavily的searchDepth等) * @returns 搜索响应 */ public async search( provider: WebSearchProvider, query: string, httpOptions?: RequestInit, - spanId?: string + spanId?: string, + providerParams?: ProviderSpecificParams ): Promise { const websearch = this.getWebSearchState() const webSearchEngine = new WebSearchEngineProvider(provider, spanId) @@ -178,7 +183,7 @@ class WebSearchService { formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}` } - return await webSearchEngine.search(formattedQuery, websearch, httpOptions) + return await webSearchEngine.search(formattedQuery, websearch, httpOptions, providerParams) } /** @@ -424,6 +429,8 @@ class WebSearchService { * @param webSearchProvider - 要使用的网络搜索提供商 * @param extractResults - 包含搜索问题和链接的提取结果对象 * @param requestId - 唯一的请求标识符,用于状态跟踪和资源管理 + * @param externalSignal - 可选的 AbortSignal 用于取消请求 + * @param providerParams - 可选的 Provider 特定参数(如 Exa 的 category、Tavily 的 searchDepth 等) * * @returns 包含搜索结果的响应对象 */ @@ -431,7 +438,8 @@ class WebSearchService { webSearchProvider: WebSearchProvider, extractResults: ExtractResults, requestId: string, - externalSignal?: AbortSignal + externalSignal?: AbortSignal, + providerParams?: ProviderSpecificParams ): Promise { // 重置状态 await this.setWebSearchStatus(requestId, { phase: 'default' }) @@ -474,8 +482,9 @@ class WebSearchService { return { query: 'summaries', results: contents } } + // 执行搜索 const searchPromises = questions.map((q) => - this.search(webSearchProvider, q, { signal }, span?.spanContext().spanId) + this.search(webSearchProvider, q, { signal }, span?.spanContext().spanId, providerParams) ) const searchResults = await Promise.allSettled(searchPromises) diff --git a/src/renderer/src/services/messageStreaming/callbacks/toolCallbacks.ts b/src/renderer/src/services/messageStreaming/callbacks/toolCallbacks.ts index ef65a4962a..34ffa241b2 100644 --- a/src/renderer/src/services/messageStreaming/callbacks/toolCallbacks.ts +++ b/src/renderer/src/services/messageStreaming/callbacks/toolCallbacks.ts @@ -84,7 +84,8 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => { } blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true) // Handle citation block creation for web search results - if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response) { + const webSearchTools = ['builtin_web_search', 'builtin_exa_search', 'builtin_tavily_search'] + if (webSearchTools.includes(toolResponse.tool.name) && toolResponse.response) { const citationBlock = createCitationBlock( assistantMsgId, { diff --git a/src/renderer/src/store/messageBlock.ts b/src/renderer/src/store/messageBlock.ts index 524889bb32..86190964c5 100644 --- a/src/renderer/src/store/messageBlock.ts +++ b/src/renderer/src/store/messageBlock.ts @@ -4,6 +4,7 @@ import { createEntityAdapter, createSelector, createSlice, type PayloadAction } import { AISDKWebSearchResult, Citation, WebSearchProviderResponse, WebSearchSource } from '@renderer/types' import type { CitationMessageBlock, MessageBlock } from '@renderer/types/newMessage' import { MessageBlockType } from '@renderer/types/newMessage' +import { adaptSearchResultsToCitations } from '@renderer/utils/searchResultAdapters' import type OpenAI from 'openai' import type { RootState } from './index' // 确认 RootState 从 store/index.ts 导出 @@ -217,17 +218,12 @@ export const formatCitationsFromBlock = (block: CitationMessageBlock | undefined type: 'websearch' })) || [] break - case WebSearchSource.WEBSEARCH: - formattedCitations = - (block.response.results as WebSearchProviderResponse)?.results?.map((result, index) => ({ - number: index + 1, - url: result.url, - title: result.title, - content: result.content, - showFavicon: true, - type: 'websearch' - })) || [] + case WebSearchSource.WEBSEARCH: { + const results = (block.response.results as WebSearchProviderResponse)?.results || [] + // 使用适配器统一转换,自动处理 Provider 特定字段(如 Exa 的 favicon、Tavily 的 answer 等) + formattedCitations = adaptSearchResultsToCitations(results) break + } case WebSearchSource.AISDK: formattedCitations = (block.response?.results as AISDKWebSearchResult[])?.map((result, index) => ({ diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index d8845cd1cc..7a3b71e053 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -575,17 +575,63 @@ export type WebSearchProvider = { modelName?: string } -export type WebSearchProviderResult = { +// 基础搜索结果(所有 Provider 必须实现) +export interface BaseSearchResult { title: string content: string url: string } +// Exa Provider 特定扩展 +export interface ExaSearchResult extends BaseSearchResult { + favicon?: string + publishedDate?: string + author?: string + score?: number + highlights?: string[] +} + +// Tavily Provider 特定扩展 +export interface TavilySearchResult extends BaseSearchResult { + answer?: string // Tavily 的 AI 直接答案 + images?: string[] + rawContent?: string + score?: number +} + +// 联合类型 - 向后兼容 +export type WebSearchProviderResult = BaseSearchResult | ExaSearchResult | TavilySearchResult + export type WebSearchProviderResponse = { query?: string results: WebSearchProviderResult[] } +// Provider 特定参数类型 +export interface ExaSearchParams { + type?: 'neural' | 'keyword' | 'auto' | 'fast' + category?: string + startPublishedDate?: string + endPublishedDate?: string + startCrawlDate?: string + endCrawlDate?: string + useAutoprompt?: boolean +} + +export interface TavilySearchParams { + topic?: 'general' | 'news' | 'finance' + searchDepth?: 'basic' | 'advanced' + includeAnswer?: boolean + includeRawContent?: boolean + includeImages?: boolean +} + +// 联合类型 - 支持不同 Provider 的特定参数 +export interface ProviderSpecificParams { + exa?: ExaSearchParams + tavily?: TavilySearchParams +} + export type AISDKWebSearchResult = Omit, 'sourceType'> export type WebSearchResults = @@ -813,6 +859,7 @@ export interface Citation { hostname?: string content?: string showFavicon?: boolean + favicon?: string // 新增:直接的 favicon URL(来自 Provider) type?: string metadata?: Record } diff --git a/src/renderer/src/utils/searchResultAdapters.ts b/src/renderer/src/utils/searchResultAdapters.ts new file mode 100644 index 0000000000..7ae1b0c6da --- /dev/null +++ b/src/renderer/src/utils/searchResultAdapters.ts @@ -0,0 +1,77 @@ +/** + * 搜索结果适配器 + * 将不同 Provider 的搜索结果统一转换为 Citation 格式 + */ + +import type { Citation, WebSearchProviderResult } from '@renderer/types' + +/** + * 将 WebSearchProviderResult 转换为 Citation + * 自动识别并处理不同 Provider 的额外字段 + * + * @param result - 搜索结果(可能包含 Provider 特定字段) + * @param index - 结果序号(从0开始) + * @returns Citation 对象 + */ +export function adaptSearchResultToCitation(result: WebSearchProviderResult, index: number): Citation { + // 基础字段(所有 Provider 都有) + const citation: Citation = { + number: index + 1, + url: result.url, + title: result.title, + content: result.content, + showFavicon: true, + type: 'websearch' + } + + // Exa Provider 特定字段 + if ('favicon' in result && result.favicon) { + citation.favicon = result.favicon + } + + // 收集元数据 + const metadata: Record = {} + + // Exa 元数据 + if ('publishedDate' in result && result.publishedDate) { + metadata.publishedDate = result.publishedDate + } + + if ('author' in result && result.author) { + metadata.author = result.author + } + + if ('score' in result && result.score !== undefined) { + metadata.score = result.score + } + + if ('highlights' in result && result.highlights && result.highlights.length > 0) { + metadata.highlights = result.highlights + } + + // Tavily 元数据 + if ('answer' in result && result.answer) { + metadata.answer = result.answer + } + + if ('images' in result && result.images && result.images.length > 0) { + metadata.images = result.images + } + + // 只在有元数据时添加 + if (Object.keys(metadata).length > 0) { + citation.metadata = metadata + } + + return citation +} + +/** + * 批量转换搜索结果为 Citations + * + * @param results - 搜索结果数组 + * @returns Citation 数组 + */ +export function adaptSearchResultsToCitations(results: WebSearchProviderResult[]): Citation[] { + return results.map((result, index) => adaptSearchResultToCitation(result, index)) +}