mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-08 14:29:15 +08:00
feat: enhance web search tool functionality and type definitions
- Introduced new `WebSearchToolOutputSchema` type to standardize output from web search tools. - Updated `webSearchTool` and `webSearchToolWithExtraction` to utilize Zod for input and output schema validation. - Refactored tool execution logic to improve error handling and response formatting. - Cleaned up unused type imports and comments for better code clarity.
This commit is contained in:
parent
c3a6456499
commit
786bc8dca9
@ -41,3 +41,27 @@ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
|||||||
maxUses: 5
|
maxUses: 5
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type WebSearchToolOutputSchema = {
|
||||||
|
// Anthropic 工具 - 手动定义
|
||||||
|
anthropicWebSearch: Array<{
|
||||||
|
url: string
|
||||||
|
title: string
|
||||||
|
pageAge: string | null
|
||||||
|
encryptedContent: string
|
||||||
|
type: string
|
||||||
|
}>
|
||||||
|
|
||||||
|
// OpenAI 工具 - 基于实际输出
|
||||||
|
openaiWebSearch: {
|
||||||
|
status: 'completed' | 'failed'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Google 工具
|
||||||
|
googleSearch: {
|
||||||
|
webSearchQueries?: string[]
|
||||||
|
groundingChunks?: Array<{
|
||||||
|
web?: { uri: string; title: string }
|
||||||
|
}>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -64,7 +64,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
|||||||
})
|
})
|
||||||
|
|
||||||
// 导出类型定义供开发者使用
|
// 导出类型定义供开发者使用
|
||||||
export type { WebSearchPluginConfig } from './helper'
|
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
|
||||||
|
|
||||||
// 默认导出
|
// 默认导出
|
||||||
export default webSearchPlugin
|
export default webSearchPlugin
|
||||||
|
|||||||
@ -3,143 +3,127 @@ import WebSearchService from '@renderer/services/WebSearchService'
|
|||||||
import { Assistant, Message, WebSearchProvider } from '@renderer/types'
|
import { Assistant, Message, WebSearchProvider } from '@renderer/types'
|
||||||
import { UserMessageStatus } from '@renderer/types/newMessage'
|
import { UserMessageStatus } from '@renderer/types/newMessage'
|
||||||
import { ExtractResults } from '@renderer/utils/extract'
|
import { ExtractResults } from '@renderer/utils/extract'
|
||||||
import * as aiSdk from 'ai'
|
import { type InferToolOutput, tool } from 'ai'
|
||||||
|
import { z } from 'zod'
|
||||||
|
|
||||||
import { AiSdkTool, ToolCallResult } from './types'
|
// import { AiSdkTool, ToolCallResult } from './types'
|
||||||
|
|
||||||
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']): AiSdkTool => {
|
const WebSearchProviderResult = z.object({
|
||||||
|
query: z.string().optional(),
|
||||||
|
results: z.array(
|
||||||
|
z.object({
|
||||||
|
title: z.string(),
|
||||||
|
content: z.string(),
|
||||||
|
url: z.string()
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
|
||||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||||
return {
|
return tool({
|
||||||
name: 'builtin_web_search',
|
name: 'builtin_web_search',
|
||||||
description: 'Search the web for information',
|
description: 'Search the web for information',
|
||||||
inputSchema: aiSdk.jsonSchema({
|
inputSchema: z.object({
|
||||||
type: 'object',
|
query: z.string().describe('The query to search for')
|
||||||
properties: {
|
|
||||||
query: { type: 'string', description: 'The query to search for' }
|
|
||||||
},
|
|
||||||
required: ['query']
|
|
||||||
}),
|
}),
|
||||||
execute: async ({ query }): Promise<ToolCallResult> => {
|
outputSchema: WebSearchProviderResult,
|
||||||
try {
|
execute: async ({ query }) => {
|
||||||
console.log('webSearchTool', query)
|
console.log('webSearchTool', query)
|
||||||
const response = await webSearchService.search(query)
|
const response = await webSearchService.search(query)
|
||||||
console.log('webSearchTool response', response)
|
console.log('webSearchTool response', response)
|
||||||
return {
|
return response
|
||||||
success: true,
|
|
||||||
data: response
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
return {
|
|
||||||
success: false,
|
|
||||||
data: error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
|
||||||
|
|
||||||
export const webSearchToolWithExtraction = (
|
export const webSearchToolWithExtraction = (
|
||||||
webSearchProviderId: WebSearchProvider['id'],
|
webSearchProviderId: WebSearchProvider['id'],
|
||||||
requestId: string,
|
requestId: string,
|
||||||
assistant: Assistant
|
assistant: Assistant
|
||||||
): AiSdkTool => {
|
) => {
|
||||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||||
|
|
||||||
return {
|
return tool({
|
||||||
name: 'web_search_with_extraction',
|
name: 'web_search_with_extraction',
|
||||||
description: 'Search the web for information with automatic keyword extraction from user messages',
|
description: 'Search the web for information with automatic keyword extraction from user messages',
|
||||||
inputSchema: aiSdk.jsonSchema({
|
inputSchema: z.object({
|
||||||
type: 'object',
|
userMessage: z.object({
|
||||||
properties: {
|
content: z.string().describe('The main content of the message'),
|
||||||
userMessage: {
|
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||||
type: 'object',
|
}),
|
||||||
description: 'The user message to extract keywords from',
|
lastAnswer: z.object({
|
||||||
properties: {
|
content: z.string().describe('The main content of the message'),
|
||||||
content: { type: 'string', description: 'The main content of the message' },
|
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||||
role: { type: 'string', description: 'Message role (user/assistant/system)' }
|
})
|
||||||
},
|
|
||||||
required: ['content', 'role']
|
|
||||||
},
|
|
||||||
lastAnswer: {
|
|
||||||
type: 'object',
|
|
||||||
description: 'Optional last assistant response for context',
|
|
||||||
properties: {
|
|
||||||
content: { type: 'string', description: 'The main content of the message' },
|
|
||||||
role: { type: 'string', description: 'Message role (user/assistant/system)' }
|
|
||||||
},
|
|
||||||
required: ['content', 'role']
|
|
||||||
}
|
|
||||||
},
|
|
||||||
required: ['userMessage']
|
|
||||||
}),
|
}),
|
||||||
execute: async ({ userMessage, lastAnswer }): Promise<ToolCallResult> => {
|
outputSchema: z.object({
|
||||||
try {
|
extractedKeywords: z.object({
|
||||||
const lastUserMessage: Message = {
|
question: z.array(z.string()),
|
||||||
id: requestId,
|
links: z.array(z.string()).optional()
|
||||||
role: userMessage.role as 'user' | 'assistant' | 'system',
|
}),
|
||||||
assistantId: assistant.id,
|
searchResults: z.array(
|
||||||
topicId: 'temp',
|
z.object({
|
||||||
createdAt: new Date().toISOString(),
|
query: z.string(),
|
||||||
status: UserMessageStatus.SUCCESS,
|
results: WebSearchProviderResult
|
||||||
blocks: []
|
|
||||||
}
|
|
||||||
|
|
||||||
const lastAnswerMessage: Message | undefined = lastAnswer
|
|
||||||
? {
|
|
||||||
id: requestId + '_answer',
|
|
||||||
role: lastAnswer.role as 'user' | 'assistant' | 'system',
|
|
||||||
assistantId: assistant.id,
|
|
||||||
topicId: 'temp',
|
|
||||||
createdAt: new Date().toISOString(),
|
|
||||||
status: UserMessageStatus.SUCCESS,
|
|
||||||
blocks: []
|
|
||||||
}
|
|
||||||
: undefined
|
|
||||||
|
|
||||||
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
|
||||||
shouldWebSearch: true,
|
|
||||||
shouldKnowledgeSearch: false,
|
|
||||||
lastAnswer: lastAnswerMessage
|
|
||||||
})
|
})
|
||||||
|
)
|
||||||
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
}),
|
||||||
return {
|
execute: async ({ userMessage, lastAnswer }) => {
|
||||||
success: false,
|
const lastUserMessage: Message = {
|
||||||
data: 'No search needed or extraction failed'
|
id: requestId,
|
||||||
}
|
role: userMessage.role,
|
||||||
}
|
assistantId: assistant.id,
|
||||||
|
topicId: 'temp',
|
||||||
const searchQueries = extractResults.websearch.question
|
createdAt: new Date().toISOString(),
|
||||||
const searchResults: Array<{ query: string; results: any }> = []
|
status: UserMessageStatus.SUCCESS,
|
||||||
|
blocks: []
|
||||||
for (const query of searchQueries) {
|
|
||||||
// 构建单个查询的ExtractResults结构
|
|
||||||
const queryExtractResults: ExtractResults = {
|
|
||||||
websearch: {
|
|
||||||
question: [query],
|
|
||||||
links: extractResults.websearch.links
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
|
||||||
searchResults.push({
|
|
||||||
query,
|
|
||||||
results: response
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
success: true,
|
|
||||||
data: {
|
|
||||||
extractedKeywords: extractResults.websearch,
|
|
||||||
searchResults
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
return {
|
|
||||||
success: false,
|
|
||||||
data: error
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const lastAnswerMessage: Message | undefined = lastAnswer
|
||||||
|
? {
|
||||||
|
id: requestId + '_answer',
|
||||||
|
role: lastAnswer.role,
|
||||||
|
assistantId: assistant.id,
|
||||||
|
topicId: 'temp',
|
||||||
|
createdAt: new Date().toISOString(),
|
||||||
|
status: UserMessageStatus.SUCCESS,
|
||||||
|
blocks: []
|
||||||
|
}
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
||||||
|
shouldWebSearch: true,
|
||||||
|
shouldKnowledgeSearch: false,
|
||||||
|
lastAnswer: lastAnswerMessage
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
||||||
|
return 'No search needed or extraction failed'
|
||||||
|
}
|
||||||
|
|
||||||
|
const searchQueries = extractResults.websearch.question
|
||||||
|
const searchResults: Array<{ query: string; results: any }> = []
|
||||||
|
|
||||||
|
for (const query of searchQueries) {
|
||||||
|
// 构建单个查询的ExtractResults结构
|
||||||
|
const queryExtractResults: ExtractResults = {
|
||||||
|
websearch: {
|
||||||
|
question: [query],
|
||||||
|
links: extractResults.websearch.links
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
||||||
|
searchResults.push({
|
||||||
|
query,
|
||||||
|
results: response
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return { extractedKeywords: extractResults.websearch, searchResults }
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import { Tool } from '@cherrystudio/ai-core'
|
// import { Tool } from '@cherrystudio/ai-core'
|
||||||
|
|
||||||
export type ToolCallResult = {
|
// export type ToolCallResult = {
|
||||||
success: boolean
|
// success: boolean
|
||||||
data: any
|
// data: any
|
||||||
}
|
// }
|
||||||
|
|
||||||
export type AiSdkTool = Tool<any, ToolCallResult>
|
// export type AiSdkTool = Tool<any, ToolCallResult>
|
||||||
|
|||||||
@ -1,19 +1,20 @@
|
|||||||
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
||||||
import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||||
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
|
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
|
||||||
import { isFunctionCallingModel } from '@renderer/config/models'
|
import { isFunctionCallingModel } from '@renderer/config/models'
|
||||||
import { MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
import { MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
||||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||||
|
import { tool } from 'ai'
|
||||||
import { JSONSchema7 } from 'json-schema'
|
import { JSONSchema7 } from 'json-schema'
|
||||||
|
|
||||||
// Setup tools configuration based on provided parameters
|
// Setup tools configuration based on provided parameters
|
||||||
export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||||
tools: Record<string, AiSdkTool>
|
tools: Record<string, Tool>
|
||||||
useSystemPromptForTools?: boolean
|
useSystemPromptForTools?: boolean
|
||||||
} {
|
} {
|
||||||
const { mcpTools, model, enableToolUse } = params
|
const { mcpTools, model, enableToolUse } = params
|
||||||
|
|
||||||
let tools: Record<string, AiSdkTool> = {}
|
let tools: Record<string, Tool> = {}
|
||||||
|
|
||||||
if (!mcpTools?.length) {
|
if (!mcpTools?.length) {
|
||||||
return { tools }
|
return { tools }
|
||||||
@ -35,15 +36,15 @@ export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; e
|
|||||||
/**
|
/**
|
||||||
* 将 MCPTool 转换为 AI SDK 工具格式
|
* 将 MCPTool 转换为 AI SDK 工具格式
|
||||||
*/
|
*/
|
||||||
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, Tool<any, ToolCallResult>> {
|
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, Tool> {
|
||||||
const tools: Record<string, Tool<any, ToolCallResult>> = {}
|
const tools: Record<string, Tool> = {}
|
||||||
|
|
||||||
for (const mcpTool of mcpTools) {
|
for (const mcpTool of mcpTools) {
|
||||||
console.log('mcpTool', mcpTool.inputSchema)
|
console.log('mcpTool', mcpTool.inputSchema)
|
||||||
tools[mcpTool.name] = aiSdk.tool<any, ToolCallResult>({
|
tools[mcpTool.name] = tool({
|
||||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||||
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||||
execute: async (params): Promise<ToolCallResult> => {
|
execute: async (params) => {
|
||||||
console.log('execute_params', params)
|
console.log('execute_params', params)
|
||||||
// 创建适配的 MCPToolResponse 对象
|
// 创建适配的 MCPToolResponse 对象
|
||||||
const toolResponse: MCPToolResponse = {
|
const toolResponse: MCPToolResponse = {
|
||||||
@ -64,15 +65,10 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
|
|||||||
}
|
}
|
||||||
console.log('result', result)
|
console.log('result', result)
|
||||||
// 返回工具执行结果
|
// 返回工具执行结果
|
||||||
return {
|
return result
|
||||||
success: true,
|
|
||||||
data: result
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`MCP Tool execution failed: ${mcpTool.name}`, error)
|
console.error(`MCP Tool execution failed: ${mcpTool.name}`, error)
|
||||||
throw new Error(
|
throw error
|
||||||
`Tool ${mcpTool.name} execution failed: ${error instanceof Error ? error.message : String(error)}`
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,3 +1,6 @@
|
|||||||
|
import { WebSearchToolOutputSchema } from '@cherrystudio/ai-core/built-in/plugins'
|
||||||
|
import type { WebSearchToolOutput } from '@renderer/aiCore/tools/WebSearchTool'
|
||||||
|
|
||||||
import type { MCPToolInputSchema } from './index'
|
import type { MCPToolInputSchema } from './index'
|
||||||
|
|
||||||
export type ToolType = 'builtin' | 'provider' | 'mcp'
|
export type ToolType = 'builtin' | 'provider' | 'mcp'
|
||||||
@ -30,3 +33,5 @@ export interface MCPTool extends BaseTool {
|
|||||||
inputSchema: MCPToolInputSchema
|
inputSchema: MCPToolInputSchema
|
||||||
type: 'mcp'
|
type: 'mcp'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type WebSearchToolOutputSchema = WebSearchToolOutput | WebSearchToolOutputSchema
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user