mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat: enhance web search tool functionality and type definitions
- Introduced new `WebSearchToolOutputSchema` type to standardize output from web search tools. - Updated `webSearchTool` and `webSearchToolWithExtraction` to utilize Zod for input and output schema validation. - Refactored tool execution logic to improve error handling and response formatting. - Cleaned up unused type imports and comments for better code clarity.
This commit is contained in:
parent
c3a6456499
commit
786bc8dca9
@ -41,3 +41,27 @@ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
||||
maxUses: 5
|
||||
}
|
||||
}
|
||||
|
||||
export type WebSearchToolOutputSchema = {
|
||||
// Anthropic 工具 - 手动定义
|
||||
anthropicWebSearch: Array<{
|
||||
url: string
|
||||
title: string
|
||||
pageAge: string | null
|
||||
encryptedContent: string
|
||||
type: string
|
||||
}>
|
||||
|
||||
// OpenAI 工具 - 基于实际输出
|
||||
openaiWebSearch: {
|
||||
status: 'completed' | 'failed'
|
||||
}
|
||||
|
||||
// Google 工具
|
||||
googleSearch: {
|
||||
webSearchQueries?: string[]
|
||||
groundingChunks?: Array<{
|
||||
web?: { uri: string; title: string }
|
||||
}>
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,7 +64,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
||||
})
|
||||
|
||||
// 导出类型定义供开发者使用
|
||||
export type { WebSearchPluginConfig } from './helper'
|
||||
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
|
||||
|
||||
// 默认导出
|
||||
export default webSearchPlugin
|
||||
|
||||
@ -3,143 +3,127 @@ import WebSearchService from '@renderer/services/WebSearchService'
|
||||
import { Assistant, Message, WebSearchProvider } from '@renderer/types'
|
||||
import { UserMessageStatus } from '@renderer/types/newMessage'
|
||||
import { ExtractResults } from '@renderer/utils/extract'
|
||||
import * as aiSdk from 'ai'
|
||||
import { type InferToolOutput, tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { AiSdkTool, ToolCallResult } from './types'
|
||||
// import { AiSdkTool, ToolCallResult } from './types'
|
||||
|
||||
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']): AiSdkTool => {
|
||||
const WebSearchProviderResult = z.object({
|
||||
query: z.string().optional(),
|
||||
results: z.array(
|
||||
z.object({
|
||||
title: z.string(),
|
||||
content: z.string(),
|
||||
url: z.string()
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
|
||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
return {
|
||||
return tool({
|
||||
name: 'builtin_web_search',
|
||||
description: 'Search the web for information',
|
||||
inputSchema: aiSdk.jsonSchema({
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: { type: 'string', description: 'The query to search for' }
|
||||
},
|
||||
required: ['query']
|
||||
inputSchema: z.object({
|
||||
query: z.string().describe('The query to search for')
|
||||
}),
|
||||
execute: async ({ query }): Promise<ToolCallResult> => {
|
||||
try {
|
||||
console.log('webSearchTool', query)
|
||||
const response = await webSearchService.search(query)
|
||||
console.log('webSearchTool response', response)
|
||||
return {
|
||||
success: true,
|
||||
data: response
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
data: error
|
||||
}
|
||||
}
|
||||
outputSchema: WebSearchProviderResult,
|
||||
execute: async ({ query }) => {
|
||||
console.log('webSearchTool', query)
|
||||
const response = await webSearchService.search(query)
|
||||
console.log('webSearchTool response', response)
|
||||
return response
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
|
||||
|
||||
export const webSearchToolWithExtraction = (
|
||||
webSearchProviderId: WebSearchProvider['id'],
|
||||
requestId: string,
|
||||
assistant: Assistant
|
||||
): AiSdkTool => {
|
||||
) => {
|
||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
|
||||
return {
|
||||
return tool({
|
||||
name: 'web_search_with_extraction',
|
||||
description: 'Search the web for information with automatic keyword extraction from user messages',
|
||||
inputSchema: aiSdk.jsonSchema({
|
||||
type: 'object',
|
||||
properties: {
|
||||
userMessage: {
|
||||
type: 'object',
|
||||
description: 'The user message to extract keywords from',
|
||||
properties: {
|
||||
content: { type: 'string', description: 'The main content of the message' },
|
||||
role: { type: 'string', description: 'Message role (user/assistant/system)' }
|
||||
},
|
||||
required: ['content', 'role']
|
||||
},
|
||||
lastAnswer: {
|
||||
type: 'object',
|
||||
description: 'Optional last assistant response for context',
|
||||
properties: {
|
||||
content: { type: 'string', description: 'The main content of the message' },
|
||||
role: { type: 'string', description: 'Message role (user/assistant/system)' }
|
||||
},
|
||||
required: ['content', 'role']
|
||||
}
|
||||
},
|
||||
required: ['userMessage']
|
||||
inputSchema: z.object({
|
||||
userMessage: z.object({
|
||||
content: z.string().describe('The main content of the message'),
|
||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
}),
|
||||
lastAnswer: z.object({
|
||||
content: z.string().describe('The main content of the message'),
|
||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
})
|
||||
}),
|
||||
execute: async ({ userMessage, lastAnswer }): Promise<ToolCallResult> => {
|
||||
try {
|
||||
const lastUserMessage: Message = {
|
||||
id: requestId,
|
||||
role: userMessage.role as 'user' | 'assistant' | 'system',
|
||||
assistantId: assistant.id,
|
||||
topicId: 'temp',
|
||||
createdAt: new Date().toISOString(),
|
||||
status: UserMessageStatus.SUCCESS,
|
||||
blocks: []
|
||||
}
|
||||
|
||||
const lastAnswerMessage: Message | undefined = lastAnswer
|
||||
? {
|
||||
id: requestId + '_answer',
|
||||
role: lastAnswer.role as 'user' | 'assistant' | 'system',
|
||||
assistantId: assistant.id,
|
||||
topicId: 'temp',
|
||||
createdAt: new Date().toISOString(),
|
||||
status: UserMessageStatus.SUCCESS,
|
||||
blocks: []
|
||||
}
|
||||
: undefined
|
||||
|
||||
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
||||
shouldWebSearch: true,
|
||||
shouldKnowledgeSearch: false,
|
||||
lastAnswer: lastAnswerMessage
|
||||
outputSchema: z.object({
|
||||
extractedKeywords: z.object({
|
||||
question: z.array(z.string()),
|
||||
links: z.array(z.string()).optional()
|
||||
}),
|
||||
searchResults: z.array(
|
||||
z.object({
|
||||
query: z.string(),
|
||||
results: WebSearchProviderResult
|
||||
})
|
||||
|
||||
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
||||
return {
|
||||
success: false,
|
||||
data: 'No search needed or extraction failed'
|
||||
}
|
||||
}
|
||||
|
||||
const searchQueries = extractResults.websearch.question
|
||||
const searchResults: Array<{ query: string; results: any }> = []
|
||||
|
||||
for (const query of searchQueries) {
|
||||
// 构建单个查询的ExtractResults结构
|
||||
const queryExtractResults: ExtractResults = {
|
||||
websearch: {
|
||||
question: [query],
|
||||
links: extractResults.websearch.links
|
||||
}
|
||||
}
|
||||
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
||||
searchResults.push({
|
||||
query,
|
||||
results: response
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
success: true,
|
||||
data: {
|
||||
extractedKeywords: extractResults.websearch,
|
||||
searchResults
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
data: error
|
||||
}
|
||||
)
|
||||
}),
|
||||
execute: async ({ userMessage, lastAnswer }) => {
|
||||
const lastUserMessage: Message = {
|
||||
id: requestId,
|
||||
role: userMessage.role,
|
||||
assistantId: assistant.id,
|
||||
topicId: 'temp',
|
||||
createdAt: new Date().toISOString(),
|
||||
status: UserMessageStatus.SUCCESS,
|
||||
blocks: []
|
||||
}
|
||||
|
||||
const lastAnswerMessage: Message | undefined = lastAnswer
|
||||
? {
|
||||
id: requestId + '_answer',
|
||||
role: lastAnswer.role,
|
||||
assistantId: assistant.id,
|
||||
topicId: 'temp',
|
||||
createdAt: new Date().toISOString(),
|
||||
status: UserMessageStatus.SUCCESS,
|
||||
blocks: []
|
||||
}
|
||||
: undefined
|
||||
|
||||
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
||||
shouldWebSearch: true,
|
||||
shouldKnowledgeSearch: false,
|
||||
lastAnswer: lastAnswerMessage
|
||||
})
|
||||
|
||||
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
||||
return 'No search needed or extraction failed'
|
||||
}
|
||||
|
||||
const searchQueries = extractResults.websearch.question
|
||||
const searchResults: Array<{ query: string; results: any }> = []
|
||||
|
||||
for (const query of searchQueries) {
|
||||
// 构建单个查询的ExtractResults结构
|
||||
const queryExtractResults: ExtractResults = {
|
||||
websearch: {
|
||||
question: [query],
|
||||
links: extractResults.websearch.links
|
||||
}
|
||||
}
|
||||
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
||||
searchResults.push({
|
||||
query,
|
||||
results: response
|
||||
})
|
||||
}
|
||||
|
||||
return { extractedKeywords: extractResults.websearch, searchResults }
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import { Tool } from '@cherrystudio/ai-core'
|
||||
// import { Tool } from '@cherrystudio/ai-core'
|
||||
|
||||
export type ToolCallResult = {
|
||||
success: boolean
|
||||
data: any
|
||||
}
|
||||
// export type ToolCallResult = {
|
||||
// success: boolean
|
||||
// data: any
|
||||
// }
|
||||
|
||||
export type AiSdkTool = Tool<any, ToolCallResult>
|
||||
// export type AiSdkTool = Tool<any, ToolCallResult>
|
||||
|
||||
@ -1,19 +1,20 @@
|
||||
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
||||
import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
|
||||
import { isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||
import { tool } from 'ai'
|
||||
import { JSONSchema7 } from 'json-schema'
|
||||
|
||||
// Setup tools configuration based on provided parameters
|
||||
export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||
tools: Record<string, AiSdkTool>
|
||||
tools: Record<string, Tool>
|
||||
useSystemPromptForTools?: boolean
|
||||
} {
|
||||
const { mcpTools, model, enableToolUse } = params
|
||||
|
||||
let tools: Record<string, AiSdkTool> = {}
|
||||
let tools: Record<string, Tool> = {}
|
||||
|
||||
if (!mcpTools?.length) {
|
||||
return { tools }
|
||||
@ -35,15 +36,15 @@ export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; e
|
||||
/**
|
||||
* 将 MCPTool 转换为 AI SDK 工具格式
|
||||
*/
|
||||
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, Tool<any, ToolCallResult>> {
|
||||
const tools: Record<string, Tool<any, ToolCallResult>> = {}
|
||||
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, Tool> {
|
||||
const tools: Record<string, Tool> = {}
|
||||
|
||||
for (const mcpTool of mcpTools) {
|
||||
console.log('mcpTool', mcpTool.inputSchema)
|
||||
tools[mcpTool.name] = aiSdk.tool<any, ToolCallResult>({
|
||||
tools[mcpTool.name] = tool({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
execute: async (params): Promise<ToolCallResult> => {
|
||||
execute: async (params) => {
|
||||
console.log('execute_params', params)
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
const toolResponse: MCPToolResponse = {
|
||||
@ -64,15 +65,10 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
|
||||
}
|
||||
console.log('result', result)
|
||||
// 返回工具执行结果
|
||||
return {
|
||||
success: true,
|
||||
data: result
|
||||
}
|
||||
return result
|
||||
} catch (error) {
|
||||
console.error(`MCP Tool execution failed: ${mcpTool.name}`, error)
|
||||
throw new Error(
|
||||
`Tool ${mcpTool.name} execution failed: ${error instanceof Error ? error.message : String(error)}`
|
||||
)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
import { WebSearchToolOutputSchema } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import type { WebSearchToolOutput } from '@renderer/aiCore/tools/WebSearchTool'
|
||||
|
||||
import type { MCPToolInputSchema } from './index'
|
||||
|
||||
export type ToolType = 'builtin' | 'provider' | 'mcp'
|
||||
@ -30,3 +33,5 @@ export interface MCPTool extends BaseTool {
|
||||
inputSchema: MCPToolInputSchema
|
||||
type: 'mcp'
|
||||
}
|
||||
|
||||
export type WebSearchToolOutputSchema = WebSearchToolOutput | WebSearchToolOutputSchema
|
||||
|
||||
Loading…
Reference in New Issue
Block a user