diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts index 8d7cba4a5f..a7c7187bca 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts @@ -41,3 +41,27 @@ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = { maxUses: 5 } } + +export type WebSearchToolOutputSchema = { + // Anthropic 工具 - 手动定义 + anthropicWebSearch: Array<{ + url: string + title: string + pageAge: string | null + encryptedContent: string + type: string + }> + + // OpenAI 工具 - 基于实际输出 + openaiWebSearch: { + status: 'completed' | 'failed' + } + + // Google 工具 + googleSearch: { + webSearchQueries?: string[] + groundingChunks?: Array<{ + web?: { uri: string; title: string } + }> + } +} diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts index 058ff5411a..2d8d5ea1fb 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts @@ -64,7 +64,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR }) // 导出类型定义供开发者使用 -export type { WebSearchPluginConfig } from './helper' +export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper' // 默认导出 export default webSearchPlugin diff --git a/src/renderer/src/aiCore/tools/WebSearchTool.ts b/src/renderer/src/aiCore/tools/WebSearchTool.ts index dc8ff8cf52..c6fbf618bb 100644 --- a/src/renderer/src/aiCore/tools/WebSearchTool.ts +++ b/src/renderer/src/aiCore/tools/WebSearchTool.ts @@ -3,143 +3,127 @@ import WebSearchService from '@renderer/services/WebSearchService' import { Assistant, Message, WebSearchProvider } from '@renderer/types' import { UserMessageStatus } from '@renderer/types/newMessage' import { ExtractResults } from '@renderer/utils/extract' -import * as aiSdk from 'ai' +import { type InferToolOutput, tool } from 'ai' +import { z } from 'zod' -import { AiSdkTool, ToolCallResult } from './types' +// import { AiSdkTool, ToolCallResult } from './types' -export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']): AiSdkTool => { +const WebSearchProviderResult = z.object({ + query: z.string().optional(), + results: z.array( + z.object({ + title: z.string(), + content: z.string(), + url: z.string() + }) + ) +}) + +export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => { const webSearchService = WebSearchService.getInstance(webSearchProviderId) - return { + return tool({ name: 'builtin_web_search', description: 'Search the web for information', - inputSchema: aiSdk.jsonSchema({ - type: 'object', - properties: { - query: { type: 'string', description: 'The query to search for' } - }, - required: ['query'] + inputSchema: z.object({ + query: z.string().describe('The query to search for') }), - execute: async ({ query }): Promise => { - try { - console.log('webSearchTool', query) - const response = await webSearchService.search(query) - console.log('webSearchTool response', response) - return { - success: true, - data: response - } - } catch (error) { - return { - success: false, - data: error - } - } + outputSchema: WebSearchProviderResult, + execute: async ({ query }) => { + console.log('webSearchTool', query) + const response = await webSearchService.search(query) + console.log('webSearchTool response', response) + return response } - } + }) } +export type WebSearchToolOutput = InferToolOutput> export const webSearchToolWithExtraction = ( webSearchProviderId: WebSearchProvider['id'], requestId: string, assistant: Assistant -): AiSdkTool => { +) => { const webSearchService = WebSearchService.getInstance(webSearchProviderId) - return { + return tool({ name: 'web_search_with_extraction', description: 'Search the web for information with automatic keyword extraction from user messages', - inputSchema: aiSdk.jsonSchema({ - type: 'object', - properties: { - userMessage: { - type: 'object', - description: 'The user message to extract keywords from', - properties: { - content: { type: 'string', description: 'The main content of the message' }, - role: { type: 'string', description: 'Message role (user/assistant/system)' } - }, - required: ['content', 'role'] - }, - lastAnswer: { - type: 'object', - description: 'Optional last assistant response for context', - properties: { - content: { type: 'string', description: 'The main content of the message' }, - role: { type: 'string', description: 'Message role (user/assistant/system)' } - }, - required: ['content', 'role'] - } - }, - required: ['userMessage'] + inputSchema: z.object({ + userMessage: z.object({ + content: z.string().describe('The main content of the message'), + role: z.enum(['user', 'assistant', 'system']).describe('Message role') + }), + lastAnswer: z.object({ + content: z.string().describe('The main content of the message'), + role: z.enum(['user', 'assistant', 'system']).describe('Message role') + }) }), - execute: async ({ userMessage, lastAnswer }): Promise => { - try { - const lastUserMessage: Message = { - id: requestId, - role: userMessage.role as 'user' | 'assistant' | 'system', - assistantId: assistant.id, - topicId: 'temp', - createdAt: new Date().toISOString(), - status: UserMessageStatus.SUCCESS, - blocks: [] - } - - const lastAnswerMessage: Message | undefined = lastAnswer - ? { - id: requestId + '_answer', - role: lastAnswer.role as 'user' | 'assistant' | 'system', - assistantId: assistant.id, - topicId: 'temp', - createdAt: new Date().toISOString(), - status: UserMessageStatus.SUCCESS, - blocks: [] - } - : undefined - - const extractResults = await extractSearchKeywords(lastUserMessage, assistant, { - shouldWebSearch: true, - shouldKnowledgeSearch: false, - lastAnswer: lastAnswerMessage + outputSchema: z.object({ + extractedKeywords: z.object({ + question: z.array(z.string()), + links: z.array(z.string()).optional() + }), + searchResults: z.array( + z.object({ + query: z.string(), + results: WebSearchProviderResult }) - - if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') { - return { - success: false, - data: 'No search needed or extraction failed' - } - } - - const searchQueries = extractResults.websearch.question - const searchResults: Array<{ query: string; results: any }> = [] - - for (const query of searchQueries) { - // 构建单个查询的ExtractResults结构 - const queryExtractResults: ExtractResults = { - websearch: { - question: [query], - links: extractResults.websearch.links - } - } - const response = await webSearchService.processWebsearch(queryExtractResults, requestId) - searchResults.push({ - query, - results: response - }) - } - - return { - success: true, - data: { - extractedKeywords: extractResults.websearch, - searchResults - } - } - } catch (error) { - return { - success: false, - data: error - } + ) + }), + execute: async ({ userMessage, lastAnswer }) => { + const lastUserMessage: Message = { + id: requestId, + role: userMessage.role, + assistantId: assistant.id, + topicId: 'temp', + createdAt: new Date().toISOString(), + status: UserMessageStatus.SUCCESS, + blocks: [] } + + const lastAnswerMessage: Message | undefined = lastAnswer + ? { + id: requestId + '_answer', + role: lastAnswer.role, + assistantId: assistant.id, + topicId: 'temp', + createdAt: new Date().toISOString(), + status: UserMessageStatus.SUCCESS, + blocks: [] + } + : undefined + + const extractResults = await extractSearchKeywords(lastUserMessage, assistant, { + shouldWebSearch: true, + shouldKnowledgeSearch: false, + lastAnswer: lastAnswerMessage + }) + + if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') { + return 'No search needed or extraction failed' + } + + const searchQueries = extractResults.websearch.question + const searchResults: Array<{ query: string; results: any }> = [] + + for (const query of searchQueries) { + // 构建单个查询的ExtractResults结构 + const queryExtractResults: ExtractResults = { + websearch: { + question: [query], + links: extractResults.websearch.links + } + } + const response = await webSearchService.processWebsearch(queryExtractResults, requestId) + searchResults.push({ + query, + results: response + }) + } + + return { extractedKeywords: extractResults.websearch, searchResults } } - } + }) } + +export type WebSearchToolWithExtractionOutput = InferToolOutput> diff --git a/src/renderer/src/aiCore/tools/types.ts b/src/renderer/src/aiCore/tools/types.ts index 23591c7f00..5a1f1675be 100644 --- a/src/renderer/src/aiCore/tools/types.ts +++ b/src/renderer/src/aiCore/tools/types.ts @@ -1,8 +1,8 @@ -import { Tool } from '@cherrystudio/ai-core' +// import { Tool } from '@cherrystudio/ai-core' -export type ToolCallResult = { - success: boolean - data: any -} +// export type ToolCallResult = { +// success: boolean +// data: any +// } -export type AiSdkTool = Tool +// export type AiSdkTool = Tool diff --git a/src/renderer/src/aiCore/utils/mcp.ts b/src/renderer/src/aiCore/utils/mcp.ts index 26a77c4f93..de635fbf8c 100644 --- a/src/renderer/src/aiCore/utils/mcp.ts +++ b/src/renderer/src/aiCore/utils/mcp.ts @@ -1,19 +1,20 @@ import { aiSdk, Tool } from '@cherrystudio/ai-core' -import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types' +// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types' import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant' import { isFunctionCallingModel } from '@renderer/config/models' import { MCPTool, MCPToolResponse, Model } from '@renderer/types' import { callMCPTool } from '@renderer/utils/mcp-tools' +import { tool } from 'ai' import { JSONSchema7 } from 'json-schema' // Setup tools configuration based on provided parameters export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): { - tools: Record + tools: Record useSystemPromptForTools?: boolean } { const { mcpTools, model, enableToolUse } = params - let tools: Record = {} + let tools: Record = {} if (!mcpTools?.length) { return { tools } @@ -35,15 +36,15 @@ export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; e /** * 将 MCPTool 转换为 AI SDK 工具格式 */ -export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record> { - const tools: Record> = {} +export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record { + const tools: Record = {} for (const mcpTool of mcpTools) { console.log('mcpTool', mcpTool.inputSchema) - tools[mcpTool.name] = aiSdk.tool({ + tools[mcpTool.name] = tool({ description: mcpTool.description || `Tool from ${mcpTool.serverName}`, inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7), - execute: async (params): Promise => { + execute: async (params) => { console.log('execute_params', params) // 创建适配的 MCPToolResponse 对象 const toolResponse: MCPToolResponse = { @@ -64,15 +65,10 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record