feat: enhance web search tool functionality and type definitions

- Introduced new `WebSearchToolOutputSchema` type to standardize output from web search tools.
- Updated `webSearchTool` and `webSearchToolWithExtraction` to utilize Zod for input and output schema validation.
- Refactored tool execution logic to improve error handling and response formatting.
- Cleaned up unused type imports and comments for better code clarity.
This commit is contained in:
MyPrototypeWhat 2025-07-18 19:33:54 +08:00
parent c3a6456499
commit 786bc8dca9
6 changed files with 148 additions and 139 deletions

View File

@ -41,3 +41,27 @@ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
maxUses: 5 maxUses: 5
} }
} }
export type WebSearchToolOutputSchema = {
// Anthropic 工具 - 手动定义
anthropicWebSearch: Array<{
url: string
title: string
pageAge: string | null
encryptedContent: string
type: string
}>
// OpenAI 工具 - 基于实际输出
openaiWebSearch: {
status: 'completed' | 'failed'
}
// Google 工具
googleSearch: {
webSearchQueries?: string[]
groundingChunks?: Array<{
web?: { uri: string; title: string }
}>
}
}

View File

@ -64,7 +64,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
}) })
// 导出类型定义供开发者使用 // 导出类型定义供开发者使用
export type { WebSearchPluginConfig } from './helper' export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
// 默认导出 // 默认导出
export default webSearchPlugin export default webSearchPlugin

View File

@ -3,143 +3,127 @@ import WebSearchService from '@renderer/services/WebSearchService'
import { Assistant, Message, WebSearchProvider } from '@renderer/types' import { Assistant, Message, WebSearchProvider } from '@renderer/types'
import { UserMessageStatus } from '@renderer/types/newMessage' import { UserMessageStatus } from '@renderer/types/newMessage'
import { ExtractResults } from '@renderer/utils/extract' import { ExtractResults } from '@renderer/utils/extract'
import * as aiSdk from 'ai' import { type InferToolOutput, tool } from 'ai'
import { z } from 'zod'
import { AiSdkTool, ToolCallResult } from './types' // import { AiSdkTool, ToolCallResult } from './types'
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']): AiSdkTool => { const WebSearchProviderResult = z.object({
query: z.string().optional(),
results: z.array(
z.object({
title: z.string(),
content: z.string(),
url: z.string()
})
)
})
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
const webSearchService = WebSearchService.getInstance(webSearchProviderId) const webSearchService = WebSearchService.getInstance(webSearchProviderId)
return { return tool({
name: 'builtin_web_search', name: 'builtin_web_search',
description: 'Search the web for information', description: 'Search the web for information',
inputSchema: aiSdk.jsonSchema({ inputSchema: z.object({
type: 'object', query: z.string().describe('The query to search for')
properties: {
query: { type: 'string', description: 'The query to search for' }
},
required: ['query']
}), }),
execute: async ({ query }): Promise<ToolCallResult> => { outputSchema: WebSearchProviderResult,
try { execute: async ({ query }) => {
console.log('webSearchTool', query) console.log('webSearchTool', query)
const response = await webSearchService.search(query) const response = await webSearchService.search(query)
console.log('webSearchTool response', response) console.log('webSearchTool response', response)
return { return response
success: true,
data: response
}
} catch (error) {
return {
success: false,
data: error
}
}
} }
} })
} }
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
export const webSearchToolWithExtraction = ( export const webSearchToolWithExtraction = (
webSearchProviderId: WebSearchProvider['id'], webSearchProviderId: WebSearchProvider['id'],
requestId: string, requestId: string,
assistant: Assistant assistant: Assistant
): AiSdkTool => { ) => {
const webSearchService = WebSearchService.getInstance(webSearchProviderId) const webSearchService = WebSearchService.getInstance(webSearchProviderId)
return { return tool({
name: 'web_search_with_extraction', name: 'web_search_with_extraction',
description: 'Search the web for information with automatic keyword extraction from user messages', description: 'Search the web for information with automatic keyword extraction from user messages',
inputSchema: aiSdk.jsonSchema({ inputSchema: z.object({
type: 'object', userMessage: z.object({
properties: { content: z.string().describe('The main content of the message'),
userMessage: { role: z.enum(['user', 'assistant', 'system']).describe('Message role')
type: 'object', }),
description: 'The user message to extract keywords from', lastAnswer: z.object({
properties: { content: z.string().describe('The main content of the message'),
content: { type: 'string', description: 'The main content of the message' }, role: z.enum(['user', 'assistant', 'system']).describe('Message role')
role: { type: 'string', description: 'Message role (user/assistant/system)' } })
},
required: ['content', 'role']
},
lastAnswer: {
type: 'object',
description: 'Optional last assistant response for context',
properties: {
content: { type: 'string', description: 'The main content of the message' },
role: { type: 'string', description: 'Message role (user/assistant/system)' }
},
required: ['content', 'role']
}
},
required: ['userMessage']
}), }),
execute: async ({ userMessage, lastAnswer }): Promise<ToolCallResult> => { outputSchema: z.object({
try { extractedKeywords: z.object({
const lastUserMessage: Message = { question: z.array(z.string()),
id: requestId, links: z.array(z.string()).optional()
role: userMessage.role as 'user' | 'assistant' | 'system', }),
assistantId: assistant.id, searchResults: z.array(
topicId: 'temp', z.object({
createdAt: new Date().toISOString(), query: z.string(),
status: UserMessageStatus.SUCCESS, results: WebSearchProviderResult
blocks: []
}
const lastAnswerMessage: Message | undefined = lastAnswer
? {
id: requestId + '_answer',
role: lastAnswer.role as 'user' | 'assistant' | 'system',
assistantId: assistant.id,
topicId: 'temp',
createdAt: new Date().toISOString(),
status: UserMessageStatus.SUCCESS,
blocks: []
}
: undefined
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
shouldWebSearch: true,
shouldKnowledgeSearch: false,
lastAnswer: lastAnswerMessage
}) })
)
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') { }),
return { execute: async ({ userMessage, lastAnswer }) => {
success: false, const lastUserMessage: Message = {
data: 'No search needed or extraction failed' id: requestId,
} role: userMessage.role,
} assistantId: assistant.id,
topicId: 'temp',
const searchQueries = extractResults.websearch.question createdAt: new Date().toISOString(),
const searchResults: Array<{ query: string; results: any }> = [] status: UserMessageStatus.SUCCESS,
blocks: []
for (const query of searchQueries) {
// 构建单个查询的ExtractResults结构
const queryExtractResults: ExtractResults = {
websearch: {
question: [query],
links: extractResults.websearch.links
}
}
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
searchResults.push({
query,
results: response
})
}
return {
success: true,
data: {
extractedKeywords: extractResults.websearch,
searchResults
}
}
} catch (error) {
return {
success: false,
data: error
}
} }
const lastAnswerMessage: Message | undefined = lastAnswer
? {
id: requestId + '_answer',
role: lastAnswer.role,
assistantId: assistant.id,
topicId: 'temp',
createdAt: new Date().toISOString(),
status: UserMessageStatus.SUCCESS,
blocks: []
}
: undefined
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
shouldWebSearch: true,
shouldKnowledgeSearch: false,
lastAnswer: lastAnswerMessage
})
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
return 'No search needed or extraction failed'
}
const searchQueries = extractResults.websearch.question
const searchResults: Array<{ query: string; results: any }> = []
for (const query of searchQueries) {
// 构建单个查询的ExtractResults结构
const queryExtractResults: ExtractResults = {
websearch: {
question: [query],
links: extractResults.websearch.links
}
}
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
searchResults.push({
query,
results: response
})
}
return { extractedKeywords: extractResults.websearch, searchResults }
} }
} })
} }
export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>

View File

@ -1,8 +1,8 @@
import { Tool } from '@cherrystudio/ai-core' // import { Tool } from '@cherrystudio/ai-core'
export type ToolCallResult = { // export type ToolCallResult = {
success: boolean // success: boolean
data: any // data: any
} // }
export type AiSdkTool = Tool<any, ToolCallResult> // export type AiSdkTool = Tool<any, ToolCallResult>

View File

@ -1,19 +1,20 @@
import { aiSdk, Tool } from '@cherrystudio/ai-core' import { aiSdk, Tool } from '@cherrystudio/ai-core'
import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types' // import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant' import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
import { isFunctionCallingModel } from '@renderer/config/models' import { isFunctionCallingModel } from '@renderer/config/models'
import { MCPTool, MCPToolResponse, Model } from '@renderer/types' import { MCPTool, MCPToolResponse, Model } from '@renderer/types'
import { callMCPTool } from '@renderer/utils/mcp-tools' import { callMCPTool } from '@renderer/utils/mcp-tools'
import { tool } from 'ai'
import { JSONSchema7 } from 'json-schema' import { JSONSchema7 } from 'json-schema'
// Setup tools configuration based on provided parameters // Setup tools configuration based on provided parameters
export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): { export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
tools: Record<string, AiSdkTool> tools: Record<string, Tool>
useSystemPromptForTools?: boolean useSystemPromptForTools?: boolean
} { } {
const { mcpTools, model, enableToolUse } = params const { mcpTools, model, enableToolUse } = params
let tools: Record<string, AiSdkTool> = {} let tools: Record<string, Tool> = {}
if (!mcpTools?.length) { if (!mcpTools?.length) {
return { tools } return { tools }
@ -35,15 +36,15 @@ export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; e
/** /**
* MCPTool AI SDK * MCPTool AI SDK
*/ */
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, Tool<any, ToolCallResult>> { export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, Tool> {
const tools: Record<string, Tool<any, ToolCallResult>> = {} const tools: Record<string, Tool> = {}
for (const mcpTool of mcpTools) { for (const mcpTool of mcpTools) {
console.log('mcpTool', mcpTool.inputSchema) console.log('mcpTool', mcpTool.inputSchema)
tools[mcpTool.name] = aiSdk.tool<any, ToolCallResult>({ tools[mcpTool.name] = tool({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`, description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7), inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params): Promise<ToolCallResult> => { execute: async (params) => {
console.log('execute_params', params) console.log('execute_params', params)
// 创建适配的 MCPToolResponse 对象 // 创建适配的 MCPToolResponse 对象
const toolResponse: MCPToolResponse = { const toolResponse: MCPToolResponse = {
@ -64,15 +65,10 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
} }
console.log('result', result) console.log('result', result)
// 返回工具执行结果 // 返回工具执行结果
return { return result
success: true,
data: result
}
} catch (error) { } catch (error) {
console.error(`MCP Tool execution failed: ${mcpTool.name}`, error) console.error(`MCP Tool execution failed: ${mcpTool.name}`, error)
throw new Error( throw error
`Tool ${mcpTool.name} execution failed: ${error instanceof Error ? error.message : String(error)}`
)
} }
} }
}) })

View File

@ -1,3 +1,6 @@
import { WebSearchToolOutputSchema } from '@cherrystudio/ai-core/built-in/plugins'
import type { WebSearchToolOutput } from '@renderer/aiCore/tools/WebSearchTool'
import type { MCPToolInputSchema } from './index' import type { MCPToolInputSchema } from './index'
export type ToolType = 'builtin' | 'provider' | 'mcp' export type ToolType = 'builtin' | 'provider' | 'mcp'
@ -30,3 +33,5 @@ export interface MCPTool extends BaseTool {
inputSchema: MCPToolInputSchema inputSchema: MCPToolInputSchema
type: 'mcp' type: 'mcp'
} }
export type WebSearchToolOutputSchema = WebSearchToolOutput | WebSearchToolOutputSchema