feat: enhance web search tool functionality and type definitions

- Introduced new `WebSearchToolOutputSchema` type to standardize output from web search tools.
- Updated `webSearchTool` and `webSearchToolWithExtraction` to utilize Zod for input and output schema validation.
- Refactored tool execution logic to improve error handling and response formatting.
- Cleaned up unused type imports and comments for better code clarity.
This commit is contained in:
MyPrototypeWhat 2025-07-18 19:33:54 +08:00
parent c3a6456499
commit 786bc8dca9
6 changed files with 148 additions and 139 deletions

View File

@ -41,3 +41,27 @@ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
maxUses: 5
}
}
export type WebSearchToolOutputSchema = {
// Anthropic 工具 - 手动定义
anthropicWebSearch: Array<{
url: string
title: string
pageAge: string | null
encryptedContent: string
type: string
}>
// OpenAI 工具 - 基于实际输出
openaiWebSearch: {
status: 'completed' | 'failed'
}
// Google 工具
googleSearch: {
webSearchQueries?: string[]
groundingChunks?: Array<{
web?: { uri: string; title: string }
}>
}
}

View File

@ -64,7 +64,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
})
// 导出类型定义供开发者使用
export type { WebSearchPluginConfig } from './helper'
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
// 默认导出
export default webSearchPlugin

View File

@ -3,143 +3,127 @@ import WebSearchService from '@renderer/services/WebSearchService'
import { Assistant, Message, WebSearchProvider } from '@renderer/types'
import { UserMessageStatus } from '@renderer/types/newMessage'
import { ExtractResults } from '@renderer/utils/extract'
import * as aiSdk from 'ai'
import { type InferToolOutput, tool } from 'ai'
import { z } from 'zod'
import { AiSdkTool, ToolCallResult } from './types'
// import { AiSdkTool, ToolCallResult } from './types'
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']): AiSdkTool => {
const WebSearchProviderResult = z.object({
query: z.string().optional(),
results: z.array(
z.object({
title: z.string(),
content: z.string(),
url: z.string()
})
)
})
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
return {
return tool({
name: 'builtin_web_search',
description: 'Search the web for information',
inputSchema: aiSdk.jsonSchema({
type: 'object',
properties: {
query: { type: 'string', description: 'The query to search for' }
},
required: ['query']
inputSchema: z.object({
query: z.string().describe('The query to search for')
}),
execute: async ({ query }): Promise<ToolCallResult> => {
try {
console.log('webSearchTool', query)
const response = await webSearchService.search(query)
console.log('webSearchTool response', response)
return {
success: true,
data: response
}
} catch (error) {
return {
success: false,
data: error
}
}
outputSchema: WebSearchProviderResult,
execute: async ({ query }) => {
console.log('webSearchTool', query)
const response = await webSearchService.search(query)
console.log('webSearchTool response', response)
return response
}
}
})
}
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
export const webSearchToolWithExtraction = (
webSearchProviderId: WebSearchProvider['id'],
requestId: string,
assistant: Assistant
): AiSdkTool => {
) => {
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
return {
return tool({
name: 'web_search_with_extraction',
description: 'Search the web for information with automatic keyword extraction from user messages',
inputSchema: aiSdk.jsonSchema({
type: 'object',
properties: {
userMessage: {
type: 'object',
description: 'The user message to extract keywords from',
properties: {
content: { type: 'string', description: 'The main content of the message' },
role: { type: 'string', description: 'Message role (user/assistant/system)' }
},
required: ['content', 'role']
},
lastAnswer: {
type: 'object',
description: 'Optional last assistant response for context',
properties: {
content: { type: 'string', description: 'The main content of the message' },
role: { type: 'string', description: 'Message role (user/assistant/system)' }
},
required: ['content', 'role']
}
},
required: ['userMessage']
inputSchema: z.object({
userMessage: z.object({
content: z.string().describe('The main content of the message'),
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
}),
lastAnswer: z.object({
content: z.string().describe('The main content of the message'),
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
})
}),
execute: async ({ userMessage, lastAnswer }): Promise<ToolCallResult> => {
try {
const lastUserMessage: Message = {
id: requestId,
role: userMessage.role as 'user' | 'assistant' | 'system',
assistantId: assistant.id,
topicId: 'temp',
createdAt: new Date().toISOString(),
status: UserMessageStatus.SUCCESS,
blocks: []
}
const lastAnswerMessage: Message | undefined = lastAnswer
? {
id: requestId + '_answer',
role: lastAnswer.role as 'user' | 'assistant' | 'system',
assistantId: assistant.id,
topicId: 'temp',
createdAt: new Date().toISOString(),
status: UserMessageStatus.SUCCESS,
blocks: []
}
: undefined
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
shouldWebSearch: true,
shouldKnowledgeSearch: false,
lastAnswer: lastAnswerMessage
outputSchema: z.object({
extractedKeywords: z.object({
question: z.array(z.string()),
links: z.array(z.string()).optional()
}),
searchResults: z.array(
z.object({
query: z.string(),
results: WebSearchProviderResult
})
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
return {
success: false,
data: 'No search needed or extraction failed'
}
}
const searchQueries = extractResults.websearch.question
const searchResults: Array<{ query: string; results: any }> = []
for (const query of searchQueries) {
// 构建单个查询的ExtractResults结构
const queryExtractResults: ExtractResults = {
websearch: {
question: [query],
links: extractResults.websearch.links
}
}
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
searchResults.push({
query,
results: response
})
}
return {
success: true,
data: {
extractedKeywords: extractResults.websearch,
searchResults
}
}
} catch (error) {
return {
success: false,
data: error
}
)
}),
execute: async ({ userMessage, lastAnswer }) => {
const lastUserMessage: Message = {
id: requestId,
role: userMessage.role,
assistantId: assistant.id,
topicId: 'temp',
createdAt: new Date().toISOString(),
status: UserMessageStatus.SUCCESS,
blocks: []
}
const lastAnswerMessage: Message | undefined = lastAnswer
? {
id: requestId + '_answer',
role: lastAnswer.role,
assistantId: assistant.id,
topicId: 'temp',
createdAt: new Date().toISOString(),
status: UserMessageStatus.SUCCESS,
blocks: []
}
: undefined
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
shouldWebSearch: true,
shouldKnowledgeSearch: false,
lastAnswer: lastAnswerMessage
})
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
return 'No search needed or extraction failed'
}
const searchQueries = extractResults.websearch.question
const searchResults: Array<{ query: string; results: any }> = []
for (const query of searchQueries) {
// 构建单个查询的ExtractResults结构
const queryExtractResults: ExtractResults = {
websearch: {
question: [query],
links: extractResults.websearch.links
}
}
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
searchResults.push({
query,
results: response
})
}
return { extractedKeywords: extractResults.websearch, searchResults }
}
}
})
}
export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>

View File

@ -1,8 +1,8 @@
import { Tool } from '@cherrystudio/ai-core'
// import { Tool } from '@cherrystudio/ai-core'
export type ToolCallResult = {
success: boolean
data: any
}
// export type ToolCallResult = {
// success: boolean
// data: any
// }
export type AiSdkTool = Tool<any, ToolCallResult>
// export type AiSdkTool = Tool<any, ToolCallResult>

View File

@ -1,19 +1,20 @@
import { aiSdk, Tool } from '@cherrystudio/ai-core'
import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
import { isFunctionCallingModel } from '@renderer/config/models'
import { MCPTool, MCPToolResponse, Model } from '@renderer/types'
import { callMCPTool } from '@renderer/utils/mcp-tools'
import { tool } from 'ai'
import { JSONSchema7 } from 'json-schema'
// Setup tools configuration based on provided parameters
export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
tools: Record<string, AiSdkTool>
tools: Record<string, Tool>
useSystemPromptForTools?: boolean
} {
const { mcpTools, model, enableToolUse } = params
let tools: Record<string, AiSdkTool> = {}
let tools: Record<string, Tool> = {}
if (!mcpTools?.length) {
return { tools }
@ -35,15 +36,15 @@ export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; e
/**
* MCPTool AI SDK
*/
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, Tool<any, ToolCallResult>> {
const tools: Record<string, Tool<any, ToolCallResult>> = {}
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, Tool> {
const tools: Record<string, Tool> = {}
for (const mcpTool of mcpTools) {
console.log('mcpTool', mcpTool.inputSchema)
tools[mcpTool.name] = aiSdk.tool<any, ToolCallResult>({
tools[mcpTool.name] = tool({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params): Promise<ToolCallResult> => {
execute: async (params) => {
console.log('execute_params', params)
// 创建适配的 MCPToolResponse 对象
const toolResponse: MCPToolResponse = {
@ -64,15 +65,10 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
}
console.log('result', result)
// 返回工具执行结果
return {
success: true,
data: result
}
return result
} catch (error) {
console.error(`MCP Tool execution failed: ${mcpTool.name}`, error)
throw new Error(
`Tool ${mcpTool.name} execution failed: ${error instanceof Error ? error.message : String(error)}`
)
throw error
}
}
})

View File

@ -1,3 +1,6 @@
import { WebSearchToolOutputSchema } from '@cherrystudio/ai-core/built-in/plugins'
import type { WebSearchToolOutput } from '@renderer/aiCore/tools/WebSearchTool'
import type { MCPToolInputSchema } from './index'
export type ToolType = 'builtin' | 'provider' | 'mcp'
@ -30,3 +33,5 @@ export interface MCPTool extends BaseTool {
inputSchema: MCPToolInputSchema
type: 'mcp'
}
export type WebSearchToolOutputSchema = WebSearchToolOutput | WebSearchToolOutputSchema