diff --git a/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts b/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts index 2140950e19..f9a97b2cf3 100644 --- a/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts +++ b/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts @@ -67,6 +67,15 @@ export class ToolCallChunkHandler { description: toolName, type: 'provider' } + } else if (toolName.startsWith('builtin_')) { + // 如果是内置工具,沿用现有逻辑 + Logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`) + tool = { + id: toolCallId, + name: toolName, + description: toolName, + type: 'builtin' + } } else { // 如果是客户端执行的 MCP 工具,沿用现有逻辑 Logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`) diff --git a/src/renderer/src/aiCore/tools/WebSearchTool.ts b/src/renderer/src/aiCore/tools/WebSearchTool.ts index 52a6d743a3..dc8ff8cf52 100644 --- a/src/renderer/src/aiCore/tools/WebSearchTool.ts +++ b/src/renderer/src/aiCore/tools/WebSearchTool.ts @@ -1,13 +1,16 @@ +import { extractSearchKeywords } from '@renderer/aiCore/transformParameters' import WebSearchService from '@renderer/services/WebSearchService' -import { WebSearchProvider } from '@renderer/types' -import aiSdk from 'ai' +import { Assistant, Message, WebSearchProvider } from '@renderer/types' +import { UserMessageStatus } from '@renderer/types/newMessage' +import { ExtractResults } from '@renderer/utils/extract' +import * as aiSdk from 'ai' import { AiSdkTool, ToolCallResult } from './types' -export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requestId: string): AiSdkTool => { +export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']): AiSdkTool => { const webSearchService = WebSearchService.getInstance(webSearchProviderId) return { - name: 'web_search', + name: 'builtin_web_search', description: 'Search the web for information', inputSchema: aiSdk.jsonSchema({ type: 'object', @@ -18,7 +21,9 @@ export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requ }), execute: async ({ query }): Promise => { try { - const response = await webSearchService.processWebsearch(query, requestId) + console.log('webSearchTool', query) + const response = await webSearchService.search(query) + console.log('webSearchTool response', response) return { success: true, data: response @@ -32,3 +37,109 @@ export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requ } } } + +export const webSearchToolWithExtraction = ( + webSearchProviderId: WebSearchProvider['id'], + requestId: string, + assistant: Assistant +): AiSdkTool => { + const webSearchService = WebSearchService.getInstance(webSearchProviderId) + + return { + name: 'web_search_with_extraction', + description: 'Search the web for information with automatic keyword extraction from user messages', + inputSchema: aiSdk.jsonSchema({ + type: 'object', + properties: { + userMessage: { + type: 'object', + description: 'The user message to extract keywords from', + properties: { + content: { type: 'string', description: 'The main content of the message' }, + role: { type: 'string', description: 'Message role (user/assistant/system)' } + }, + required: ['content', 'role'] + }, + lastAnswer: { + type: 'object', + description: 'Optional last assistant response for context', + properties: { + content: { type: 'string', description: 'The main content of the message' }, + role: { type: 'string', description: 'Message role (user/assistant/system)' } + }, + required: ['content', 'role'] + } + }, + required: ['userMessage'] + }), + execute: async ({ userMessage, lastAnswer }): Promise => { + try { + const lastUserMessage: Message = { + id: requestId, + role: userMessage.role as 'user' | 'assistant' | 'system', + assistantId: assistant.id, + topicId: 'temp', + createdAt: new Date().toISOString(), + status: UserMessageStatus.SUCCESS, + blocks: [] + } + + const lastAnswerMessage: Message | undefined = lastAnswer + ? { + id: requestId + '_answer', + role: lastAnswer.role as 'user' | 'assistant' | 'system', + assistantId: assistant.id, + topicId: 'temp', + createdAt: new Date().toISOString(), + status: UserMessageStatus.SUCCESS, + blocks: [] + } + : undefined + + const extractResults = await extractSearchKeywords(lastUserMessage, assistant, { + shouldWebSearch: true, + shouldKnowledgeSearch: false, + lastAnswer: lastAnswerMessage + }) + + if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') { + return { + success: false, + data: 'No search needed or extraction failed' + } + } + + const searchQueries = extractResults.websearch.question + const searchResults: Array<{ query: string; results: any }> = [] + + for (const query of searchQueries) { + // 构建单个查询的ExtractResults结构 + const queryExtractResults: ExtractResults = { + websearch: { + question: [query], + links: extractResults.websearch.links + } + } + const response = await webSearchService.processWebsearch(queryExtractResults, requestId) + searchResults.push({ + query, + results: response + }) + } + + return { + success: true, + data: { + extractedKeywords: extractResults.websearch, + searchResults + } + } + } catch (error) { + return { + success: false, + data: error + } + } + } + } +} diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index 82e46c1cd9..0b577d9256 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -13,6 +13,8 @@ import { TextPart, UserModelMessage } from '@cherrystudio/ai-core' +import AiProvider from '@renderer/aiCore' +import { CompletionsParams } from '@renderer/aiCore/middleware/schemas' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { isGenerateImageModel, @@ -26,10 +28,18 @@ import { isVisionModel, isWebSearchModel } from '@renderer/config/models' -import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService' +import { + SEARCH_SUMMARY_PROMPT, + SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY, + SEARCH_SUMMARY_PROMPT_WEB_ONLY +} from '@renderer/config/prompts' +import { getAssistantSettings, getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService' +import { getDefaultAssistant } from '@renderer/services/AssistantService' import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types' import { FileTypes } from '@renderer/types' import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage' +// import { getWebSearchTools } from './utils/websearch' +import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract' import { findFileBlocks, findImageBlocks, @@ -38,12 +48,12 @@ import { } from '@renderer/utils/messageUtils/find' import { buildSystemPrompt } from '@renderer/utils/prompt' import { defaultTimeout } from '@shared/config/constant' +import { isEmpty } from 'lodash' import { webSearchTool } from './tools/WebSearchTool' // import { jsonSchemaToZod } from 'json-schema-to-zod' import { setupToolsConfig } from './utils/mcp' import { buildProviderOptions } from './utils/options' -// import { getWebSearchTools } from './utils/websearch' /** * 获取温度参数 @@ -289,10 +299,7 @@ export async function buildStreamTextParams( }) if (webSearchProviderId) { - // 生成requestId用于网络搜索工具 - const requestId = `request_${Date.now()}_${Math.random().toString(36).substring(2, 9)}` - - tools['builtin_web_search'] = webSearchTool(webSearchProviderId, requestId) + tools['builtin_web_search'] = webSearchTool(webSearchProviderId) } // 构建真正的 providerOptions @@ -336,3 +343,103 @@ export async function buildGenerateTextParams( // 复用流式参数的构建逻辑 return await buildStreamTextParams(messages, assistant, provider, options) } + +/** + * 提取外部工具搜索关键词和问题 + * 从用户消息中提取用于网络搜索和知识库搜索的关键词 + */ +export async function extractSearchKeywords( + lastUserMessage: Message, + assistant: Assistant, + options: { + shouldWebSearch?: boolean + shouldKnowledgeSearch?: boolean + lastAnswer?: Message + } = {} +): Promise { + const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer } = options + + if (!lastUserMessage) return undefined + + // 根据配置决定是否需要提取 + const needWebExtract = shouldWebSearch + const needKnowledgeExtract = shouldKnowledgeSearch + + if (!needWebExtract && !needKnowledgeExtract) return undefined + + // 选择合适的提示词 + let prompt: string + if (needWebExtract && !needKnowledgeExtract) { + prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY + } else if (!needWebExtract && needKnowledgeExtract) { + prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY + } else { + prompt = SEARCH_SUMMARY_PROMPT + } + + // 构建用于提取的助手配置 + const summaryAssistant = getDefaultAssistant() + summaryAssistant.model = assistant.model || getDefaultModel() + summaryAssistant.prompt = prompt + + try { + const result = await fetchSearchSummary({ + messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage], + assistant: summaryAssistant + }) + + if (!result) return getFallbackResult() + + const extracted = extractInfoFromXML(result.getText()) + // 根据需求过滤结果 + return { + websearch: needWebExtract ? extracted?.websearch : undefined, + knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined + } + } catch (e: any) { + console.error('extract error', e) + return getFallbackResult() + } + + function getFallbackResult(): ExtractResults { + const fallbackContent = getMainTextContent(lastUserMessage) + return { + websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined, + knowledge: shouldKnowledgeSearch + ? { + question: [fallbackContent || 'search'], + rewrite: fallbackContent || 'search' + } + : undefined + } + } +} + +/** + * 获取搜索摘要 - 内部辅助函数 + */ +async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) { + const model = assistant.model || getDefaultModel() + const provider = getProviderByModel(model) + + if (!hasApiKey(provider)) { + return null + } + + const AI = new AiProvider(provider) + + const params: CompletionsParams = { + callType: 'search', + messages: messages, + assistant, + streamOutput: false + } + + return await AI.completions(params) +} + +function hasApiKey(provider: Provider) { + if (!provider) return false + if (provider.id === 'ollama' || provider.id === 'lmstudio' || provider.type === 'vertexai') return true + return !isEmpty(provider.apiKey) +} diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 3849ee2c67..8307a3b16b 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -307,6 +307,7 @@ export async function fetchChatCompletion({ } = await buildStreamTextParams(messages, assistant, provider, { mcpTools: mcpTools, enableTools: isEnabledToolUse(assistant), + webSearchProviderId: assistant.webSearchProviderId, requestOptions: options }) diff --git a/src/renderer/src/services/WebSearchService.ts b/src/renderer/src/services/WebSearchService.ts index ea66fe7d54..9c964ff130 100644 --- a/src/renderer/src/services/WebSearchService.ts +++ b/src/renderer/src/services/WebSearchService.ts @@ -109,7 +109,7 @@ export default class WebSearchService { * @private * @returns 网络搜索状态 */ - private getWebSearchState(): WebSearchState { + private static getWebSearchState(): WebSearchState { return store.getState().websearch } @@ -118,8 +118,8 @@ export default class WebSearchService { * @public * @returns 如果默认搜索提供商已启用则返回true,否则返回false */ - public isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean { - const { providers } = this.getWebSearchState() + public static isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean { + const { providers } = WebSearchService.getWebSearchState() const provider = providers.find((provider) => provider.id === providerId) if (!provider) { @@ -149,7 +149,7 @@ export default class WebSearchService { * @returns 如果启用覆盖搜索则返回true,否则返回false */ public isOverwriteEnabled(): boolean { - const { overwrite } = this.getWebSearchState() + const { overwrite } = WebSearchService.getWebSearchState() return overwrite } @@ -159,7 +159,7 @@ export default class WebSearchService { * @returns 网络搜索提供商 */ public getWebSearchProvider(providerId?: string): WebSearchProvider | undefined { - const { providers } = this.getWebSearchState() + const { providers } = WebSearchService.getWebSearchState() const provider = providers.find((provider) => provider.id === providerId) return provider @@ -172,7 +172,7 @@ export default class WebSearchService { * @returns 搜索响应 */ public async search(query: string, httpOptions?: RequestInit): Promise { - const websearch = this.getWebSearchState() + const websearch = WebSearchService.getWebSearchState() const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId) if (!webSearchProvider) { throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`) @@ -495,7 +495,7 @@ export default class WebSearchService { } } - const { compressionConfig } = this.getWebSearchState() + const { compressionConfig } = WebSearchService.getWebSearchState() // RAG压缩处理 if (compressionConfig?.method === 'rag' && requestId) {