mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-03 11:19:10 +08:00
refactor: enhance search orchestration and web search tool integration
- Updated `searchOrchestrationPlugin` to improve handling of assistant configurations and prevent concurrent analysis. - Refactored `webSearchTool` to utilize pre-extracted keywords for more efficient web searches. - Introduced a new `MessageKnowledgeSearch` component for displaying knowledge search results. - Cleaned up commented-out code and improved type safety across various components. - Enhanced the integration of web search results in the UI for better user experience.
This commit is contained in:
parent
0310648445
commit
a05d7cbe2d
@ -97,7 +97,6 @@ export class AiSdkToChunkAdapter {
|
|||||||
type: ChunkType.TEXT_DELTA,
|
type: ChunkType.TEXT_DELTA,
|
||||||
text: final.text || ''
|
text: final.text || ''
|
||||||
})
|
})
|
||||||
console.log('final.text', final.text)
|
|
||||||
break
|
break
|
||||||
case 'text-end':
|
case 'text-end':
|
||||||
this.onChunk({
|
this.onChunk({
|
||||||
|
|||||||
@ -21,7 +21,7 @@ import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-cor
|
|||||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||||
import { formatApiHost } from '@renderer/utils/api'
|
import { formatApiHost } from '@renderer/utils/api'
|
||||||
import { cloneDeep } from 'lodash'
|
import { cloneDeep } from 'lodash'
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ export default class ModernAiProvider {
|
|||||||
// 内置了默认搜索参数,如果改的话可以传config进去
|
// 内置了默认搜索参数,如果改的话可以传config进去
|
||||||
plugins.push(webSearchPlugin())
|
plugins.push(webSearchPlugin())
|
||||||
}
|
}
|
||||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant as Assistant))
|
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
|
||||||
|
|
||||||
// 2. 推理模型时添加推理插件
|
// 2. 推理模型时添加推理插件
|
||||||
if (middlewareConfig.enableReasoning) {
|
if (middlewareConfig.enableReasoning) {
|
||||||
|
|||||||
@ -20,7 +20,7 @@ export interface AiSdkMiddlewareConfig {
|
|||||||
enableWebSearch?: boolean
|
enableWebSearch?: boolean
|
||||||
mcpTools?: BaseTool[]
|
mcpTools?: BaseTool[]
|
||||||
// TODO assistant
|
// TODO assistant
|
||||||
assistant?: Assistant
|
assistant: Assistant
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -8,7 +8,6 @@
|
|||||||
*/
|
*/
|
||||||
import type { AiRequestContext, ModelMessage } from '@cherrystudio/ai-core'
|
import type { AiRequestContext, ModelMessage } from '@cherrystudio/ai-core'
|
||||||
import { definePlugin } from '@cherrystudio/ai-core'
|
import { definePlugin } from '@cherrystudio/ai-core'
|
||||||
import { RuntimeExecutor } from '@cherrystudio/ai-core/core/runtime/executor'
|
|
||||||
// import { generateObject } from '@cherrystudio/ai-core'
|
// import { generateObject } from '@cherrystudio/ai-core'
|
||||||
import {
|
import {
|
||||||
SEARCH_SUMMARY_PROMPT,
|
SEARCH_SUMMARY_PROMPT,
|
||||||
@ -19,13 +18,13 @@ import { getDefaultModel, getProviderByModel } from '@renderer/services/Assistan
|
|||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||||
import type { Assistant } from '@renderer/types'
|
import type { Assistant } from '@renderer/types'
|
||||||
|
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
import { z } from 'zod'
|
|
||||||
|
|
||||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||||
import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool'
|
import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool'
|
||||||
import { memorySearchTool } from '../tools/MemorySearchTool'
|
import { memorySearchTool } from '../tools/MemorySearchTool'
|
||||||
import { webSearchTool } from '../tools/WebSearchTool'
|
import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
|
||||||
|
|
||||||
const getMessageContent = (message: ModelMessage) => {
|
const getMessageContent = (message: ModelMessage) => {
|
||||||
if (typeof message.content === 'string') return message.content
|
if (typeof message.content === 'string') return message.content
|
||||||
@ -39,32 +38,32 @@ const getMessageContent = (message: ModelMessage) => {
|
|||||||
|
|
||||||
// === Schema Definitions ===
|
// === Schema Definitions ===
|
||||||
|
|
||||||
const WebSearchSchema = z.object({
|
// const WebSearchSchema = z.object({
|
||||||
question: z
|
// question: z
|
||||||
.array(z.string())
|
// .array(z.string())
|
||||||
.describe('Search queries for web search. Use "not_needed" if no web search is required.'),
|
// .describe('Search queries for web search. Use "not_needed" if no web search is required.'),
|
||||||
links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.')
|
// links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.')
|
||||||
})
|
// })
|
||||||
|
|
||||||
const KnowledgeSearchSchema = z.object({
|
// const KnowledgeSearchSchema = z.object({
|
||||||
question: z
|
// question: z
|
||||||
.array(z.string())
|
// .array(z.string())
|
||||||
.describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'),
|
// .describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'),
|
||||||
rewrite: z
|
// rewrite: z
|
||||||
.string()
|
// .string()
|
||||||
.describe('Rewritten query with alternative phrasing while preserving original intent and meaning.')
|
// .describe('Rewritten query with alternative phrasing while preserving original intent and meaning.')
|
||||||
})
|
// })
|
||||||
|
|
||||||
const SearchIntentAnalysisSchema = z.object({
|
// const SearchIntentAnalysisSchema = z.object({
|
||||||
websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'),
|
// websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'),
|
||||||
knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.')
|
// knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.')
|
||||||
})
|
// })
|
||||||
|
|
||||||
type SearchIntentResult = z.infer<typeof SearchIntentAnalysisSchema>
|
// type SearchIntentResult = z.infer<typeof SearchIntentAnalysisSchema>
|
||||||
|
|
||||||
let isAnalyzing = false
|
// let isAnalyzing = false
|
||||||
/**
|
/**
|
||||||
* 🧠 意图分析函数 - 使用结构化输出重构
|
* 🧠 意图分析函数 - 使用 XML 解析
|
||||||
*/
|
*/
|
||||||
async function analyzeSearchIntent(
|
async function analyzeSearchIntent(
|
||||||
lastUserMessage: ModelMessage,
|
lastUserMessage: ModelMessage,
|
||||||
@ -74,13 +73,11 @@ async function analyzeSearchIntent(
|
|||||||
shouldKnowledgeSearch?: boolean
|
shouldKnowledgeSearch?: boolean
|
||||||
shouldMemorySearch?: boolean
|
shouldMemorySearch?: boolean
|
||||||
lastAnswer?: ModelMessage
|
lastAnswer?: ModelMessage
|
||||||
context?:
|
context: AiRequestContext & {
|
||||||
| AiRequestContext
|
isAnalyzing?: boolean
|
||||||
| {
|
}
|
||||||
executor: RuntimeExecutor
|
}
|
||||||
}
|
): Promise<ExtractResults | undefined> {
|
||||||
} = {}
|
|
||||||
): Promise<SearchIntentResult | undefined> {
|
|
||||||
const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer, context } = options
|
const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer, context } = options
|
||||||
|
|
||||||
if (!lastUserMessage) return undefined
|
if (!lastUserMessage) return undefined
|
||||||
@ -91,19 +88,19 @@ async function analyzeSearchIntent(
|
|||||||
|
|
||||||
if (!needWebExtract && !needKnowledgeExtract) return undefined
|
if (!needWebExtract && !needKnowledgeExtract) return undefined
|
||||||
|
|
||||||
// 选择合适的提示词和schema
|
// 选择合适的提示词
|
||||||
let prompt: string
|
let prompt: string
|
||||||
let schema: z.Schema
|
// let schema: z.Schema
|
||||||
|
|
||||||
if (needWebExtract && !needKnowledgeExtract) {
|
if (needWebExtract && !needKnowledgeExtract) {
|
||||||
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||||
schema = z.object({ websearch: WebSearchSchema })
|
// schema = z.object({ websearch: WebSearchSchema })
|
||||||
} else if (!needWebExtract && needKnowledgeExtract) {
|
} else if (!needWebExtract && needKnowledgeExtract) {
|
||||||
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
||||||
schema = z.object({ knowledge: KnowledgeSearchSchema })
|
// schema = z.object({ knowledge: KnowledgeSearchSchema })
|
||||||
} else {
|
} else {
|
||||||
prompt = SEARCH_SUMMARY_PROMPT
|
prompt = SEARCH_SUMMARY_PROMPT
|
||||||
schema = SearchIntentAnalysisSchema
|
// schema = SearchIntentAnalysisSchema
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建消息上下文 - 简化逻辑
|
// 构建消息上下文 - 简化逻辑
|
||||||
@ -121,16 +118,15 @@ async function analyzeSearchIntent(
|
|||||||
console.error('Provider not found or missing API key')
|
console.error('Provider not found or missing API key')
|
||||||
return getFallbackResult()
|
return getFallbackResult()
|
||||||
}
|
}
|
||||||
|
// console.log('formattedPrompt', schema)
|
||||||
try {
|
try {
|
||||||
isAnalyzing = true
|
context.isAnalyzing = true
|
||||||
const result = await context?.executor?.generateObject(model.id, {
|
const { text: result } = await context.executor.generateText(model.id, {
|
||||||
schema,
|
|
||||||
prompt: formattedPrompt
|
prompt: formattedPrompt
|
||||||
})
|
})
|
||||||
isAnalyzing = false
|
context.isAnalyzing = false
|
||||||
console.log('result', context)
|
const parsedResult = extractInfoFromXML(result)
|
||||||
const parsedResult = result?.object as SearchIntentResult
|
console.log('parsedResult', parsedResult)
|
||||||
|
|
||||||
// 根据需求过滤结果
|
// 根据需求过滤结果
|
||||||
return {
|
return {
|
||||||
@ -142,7 +138,7 @@ async function analyzeSearchIntent(
|
|||||||
return getFallbackResult()
|
return getFallbackResult()
|
||||||
}
|
}
|
||||||
|
|
||||||
function getFallbackResult(): SearchIntentResult {
|
function getFallbackResult(): ExtractResults {
|
||||||
const fallbackContent = getMessageContent(lastUserMessage)
|
const fallbackContent = getMessageContent(lastUserMessage)
|
||||||
return {
|
return {
|
||||||
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
||||||
@ -159,7 +155,11 @@ async function analyzeSearchIntent(
|
|||||||
/**
|
/**
|
||||||
* 🧠 记忆存储函数 - 基于注释代码中的 processConversationMemory
|
* 🧠 记忆存储函数 - 基于注释代码中的 processConversationMemory
|
||||||
*/
|
*/
|
||||||
async function storeConversationMemory(messages: ModelMessage[], assistant: Assistant): Promise<void> {
|
async function storeConversationMemory(
|
||||||
|
messages: ModelMessage[],
|
||||||
|
assistant: Assistant,
|
||||||
|
context: AiRequestContext
|
||||||
|
): Promise<void> {
|
||||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||||
|
|
||||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||||
@ -185,14 +185,13 @@ async function storeConversationMemory(messages: ModelMessage[], assistant: Assi
|
|||||||
}
|
}
|
||||||
|
|
||||||
const currentUserId = selectCurrentUserId(store.getState())
|
const currentUserId = selectCurrentUserId(store.getState())
|
||||||
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
// const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||||
|
|
||||||
const processorConfig = MemoryProcessor.getProcessorConfig(
|
const processorConfig = MemoryProcessor.getProcessorConfig(
|
||||||
memoryConfig,
|
memoryConfig,
|
||||||
assistant.id,
|
assistant.id,
|
||||||
currentUserId,
|
currentUserId,
|
||||||
// TODO
|
context.requestId
|
||||||
lastUserMessage?.id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
console.log('Processing conversation memory...', { messageCount: conversationMessages.length })
|
console.log('Processing conversation memory...', { messageCount: conversationMessages.length })
|
||||||
@ -224,9 +223,10 @@ async function storeConversationMemory(messages: ModelMessage[], assistant: Assi
|
|||||||
*/
|
*/
|
||||||
export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||||
// 存储意图分析结果
|
// 存储意图分析结果
|
||||||
const intentAnalysisResults: { [requestId: string]: SearchIntentResult } = {}
|
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
|
||||||
const userMessages: { [requestId: string]: ModelMessage } = {}
|
const userMessages: { [requestId: string]: ModelMessage } = {}
|
||||||
console.log('searchOrchestrationPlugin', assistant)
|
console.log('searchOrchestrationPlugin', assistant)
|
||||||
|
|
||||||
return definePlugin({
|
return definePlugin({
|
||||||
name: 'search-orchestration',
|
name: 'search-orchestration',
|
||||||
enforce: 'pre', // 确保在其他插件之前执行
|
enforce: 'pre', // 确保在其他插件之前执行
|
||||||
@ -235,7 +235,8 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
|||||||
* 🔍 Step 1: 意图识别阶段
|
* 🔍 Step 1: 意图识别阶段
|
||||||
*/
|
*/
|
||||||
onRequestStart: async (context: AiRequestContext) => {
|
onRequestStart: async (context: AiRequestContext) => {
|
||||||
if (isAnalyzing) return
|
console.log('onRequestStart', context.isAnalyzing)
|
||||||
|
if (context.isAnalyzing) return
|
||||||
console.log('🧠 [SearchOrchestration] Starting intent analysis...', context.requestId)
|
console.log('🧠 [SearchOrchestration] Starting intent analysis...', context.requestId)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -294,7 +295,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
|||||||
* 🔧 Step 2: 工具配置阶段
|
* 🔧 Step 2: 工具配置阶段
|
||||||
*/
|
*/
|
||||||
transformParams: async (params: any, context: AiRequestContext) => {
|
transformParams: async (params: any, context: AiRequestContext) => {
|
||||||
if (isAnalyzing) return
|
if (context.isAnalyzing) return params
|
||||||
console.log('🔧 [SearchOrchestration] Configuring tools based on intent...', context.requestId)
|
console.log('🔧 [SearchOrchestration] Configuring tools based on intent...', context.requestId)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -314,8 +315,13 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
|||||||
const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed'
|
const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed'
|
||||||
|
|
||||||
if (needsSearch) {
|
if (needsSearch) {
|
||||||
console.log('🌐 [SearchOrchestration] Adding web search tool')
|
// onChunk({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
|
||||||
params.tools['builtin_web_search'] = webSearchTool(assistant.webSearchProviderId)
|
console.log('🌐 [SearchOrchestration] Adding web search tool with pre-extracted keywords')
|
||||||
|
params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords(
|
||||||
|
assistant.webSearchProviderId,
|
||||||
|
analysisResult.websearch,
|
||||||
|
context.requestId
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -370,7 +376,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
|||||||
const messages = context.originalParams.messages
|
const messages = context.originalParams.messages
|
||||||
|
|
||||||
if (messages && assistant) {
|
if (messages && assistant) {
|
||||||
await storeConversationMemory(messages, assistant)
|
await storeConversationMemory(messages, assistant, context)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清理缓存
|
// 清理缓存
|
||||||
|
|||||||
@ -1,131 +1,262 @@
|
|||||||
import { extractSearchKeywords } from '@renderer/aiCore/transformParameters'
|
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||||
import WebSearchService from '@renderer/services/WebSearchService'
|
import WebSearchService from '@renderer/services/WebSearchService'
|
||||||
import { Assistant, Message, WebSearchProvider } from '@renderer/types'
|
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
|
||||||
import { UserMessageStatus } from '@renderer/types/newMessage'
|
|
||||||
import { ExtractResults } from '@renderer/utils/extract'
|
import { ExtractResults } from '@renderer/utils/extract'
|
||||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
import { InferToolInput, InferToolOutput, tool } from 'ai'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
// import { AiSdkTool, ToolCallResult } from './types'
|
// import { AiSdkTool, ToolCallResult } from './types'
|
||||||
|
|
||||||
const WebSearchProviderResult = z.object({
|
// const WebSearchResult = z.array(
|
||||||
query: z.string().optional(),
|
// z.object({
|
||||||
results: z.array(
|
// query: z.string().optional(),
|
||||||
z.object({
|
// results: z.array(
|
||||||
title: z.string(),
|
// z.object({
|
||||||
content: z.string(),
|
// title: z.string(),
|
||||||
url: z.string()
|
// content: z.string(),
|
||||||
})
|
// url: z.string()
|
||||||
)
|
// })
|
||||||
})
|
// )
|
||||||
const webSearchToolInputSchema = z.object({
|
// })
|
||||||
query: z.string().describe('The query to search for')
|
// )
|
||||||
})
|
// const webSearchToolInputSchema = z.object({
|
||||||
|
// query: z.string().describe('The query to search for')
|
||||||
|
// })
|
||||||
|
|
||||||
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
|
// export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
|
||||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||||
return tool({
|
// return tool({
|
||||||
name: 'builtin_web_search',
|
// name: 'builtin_web_search',
|
||||||
description: 'Search the web for information',
|
// description: 'Search the web for information',
|
||||||
inputSchema: webSearchToolInputSchema,
|
// inputSchema: webSearchToolInputSchema,
|
||||||
outputSchema: WebSearchProviderResult,
|
// outputSchema: WebSearchProviderResult,
|
||||||
execute: async ({ query }) => {
|
// execute: async ({ query }) => {
|
||||||
console.log('webSearchTool', query)
|
// console.log('webSearchTool', query)
|
||||||
const response = await webSearchService.search(query)
|
// const response = await webSearchService.search(query)
|
||||||
console.log('webSearchTool response', response)
|
// console.log('webSearchTool response', response)
|
||||||
return response
|
// return response
|
||||||
}
|
// }
|
||||||
})
|
// })
|
||||||
}
|
// }
|
||||||
export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchTool>>
|
// export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchTool>>
|
||||||
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
|
// export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
|
||||||
|
|
||||||
export const webSearchToolWithExtraction = (
|
/**
|
||||||
|
* 使用预提取关键词的网络搜索工具
|
||||||
|
* 这个工具直接使用插件阶段分析的搜索意图,避免重复分析
|
||||||
|
*/
|
||||||
|
export const webSearchToolWithPreExtractedKeywords = (
|
||||||
webSearchProviderId: WebSearchProvider['id'],
|
webSearchProviderId: WebSearchProvider['id'],
|
||||||
requestId: string,
|
extractedKeywords: {
|
||||||
assistant: Assistant
|
question: string[]
|
||||||
|
links?: string[]
|
||||||
|
},
|
||||||
|
requestId: string
|
||||||
) => {
|
) => {
|
||||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||||
|
|
||||||
return tool({
|
return tool({
|
||||||
name: 'web_search_with_extraction',
|
name: 'builtin_web_search',
|
||||||
description: 'Search the web for information with automatic keyword extraction from user messages',
|
description: `Search the web and return citable sources using pre-analyzed search intent.
|
||||||
|
|
||||||
|
Pre-extracted search keywords: "${extractedKeywords.question.join(', ')}"${
|
||||||
|
extractedKeywords.links
|
||||||
|
? `
|
||||||
|
Relevant links: ${extractedKeywords.links.join(', ')}`
|
||||||
|
: ''
|
||||||
|
}
|
||||||
|
|
||||||
|
This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.
|
||||||
|
|
||||||
|
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
|
||||||
|
|
||||||
inputSchema: z.object({
|
inputSchema: z.object({
|
||||||
userMessage: z.object({
|
additionalContext: z
|
||||||
content: z.string().describe('The main content of the message'),
|
.string()
|
||||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
.optional()
|
||||||
}),
|
.describe('Optional additional context, keywords, or specific focus to enhance the search')
|
||||||
lastAnswer: z.object({
|
|
||||||
content: z.string().describe('The main content of the message'),
|
|
||||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
|
||||||
})
|
|
||||||
}),
|
}),
|
||||||
outputSchema: z.object({
|
|
||||||
extractedKeywords: z.object({
|
execute: async ({ additionalContext }) => {
|
||||||
question: z.array(z.string()),
|
let finalQueries = [...extractedKeywords.question]
|
||||||
links: z.array(z.string()).optional()
|
|
||||||
}),
|
if (additionalContext?.trim()) {
|
||||||
searchResults: z.array(
|
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||||
z.object({
|
console.log(`🔍 AI enhanced search with: ${additionalContext}`)
|
||||||
query: z.string(),
|
const cleanContext = additionalContext.trim()
|
||||||
results: WebSearchProviderResult
|
if (cleanContext) {
|
||||||
})
|
finalQueries = [cleanContext]
|
||||||
)
|
console.log(`➕ Added additional context: ${cleanContext}`)
|
||||||
}),
|
}
|
||||||
execute: async ({ userMessage, lastAnswer }) => {
|
|
||||||
const lastUserMessage: Message = {
|
|
||||||
id: requestId,
|
|
||||||
role: userMessage.role,
|
|
||||||
assistantId: assistant.id,
|
|
||||||
topicId: 'temp',
|
|
||||||
createdAt: new Date().toISOString(),
|
|
||||||
status: UserMessageStatus.SUCCESS,
|
|
||||||
blocks: []
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const lastAnswerMessage: Message | undefined = lastAnswer
|
const searchResults: WebSearchProviderResponse[] = []
|
||||||
? {
|
|
||||||
id: requestId + '_answer',
|
|
||||||
role: lastAnswer.role,
|
|
||||||
assistantId: assistant.id,
|
|
||||||
topicId: 'temp',
|
|
||||||
createdAt: new Date().toISOString(),
|
|
||||||
status: UserMessageStatus.SUCCESS,
|
|
||||||
blocks: []
|
|
||||||
}
|
|
||||||
: undefined
|
|
||||||
|
|
||||||
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
// 检查是否需要搜索
|
||||||
shouldWebSearch: true,
|
if (finalQueries[0] === 'not_needed') {
|
||||||
shouldKnowledgeSearch: false,
|
return {
|
||||||
lastAnswer: lastAnswerMessage
|
summary: 'No search needed based on the query analysis.',
|
||||||
})
|
searchResults: [],
|
||||||
|
sources: '',
|
||||||
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
instructions: '',
|
||||||
return 'No search needed or extraction failed'
|
rawResults: []
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const searchQueries = extractResults.websearch.question
|
try {
|
||||||
const searchResults: Array<{ query: string; results: any }> = []
|
// 构建 ExtractResults 结构用于 processWebsearch
|
||||||
|
const extractResults: ExtractResults = {
|
||||||
for (const query of searchQueries) {
|
|
||||||
// 构建单个查询的ExtractResults结构
|
|
||||||
const queryExtractResults: ExtractResults = {
|
|
||||||
websearch: {
|
websearch: {
|
||||||
question: [query],
|
question: finalQueries,
|
||||||
links: extractResults.websearch.links
|
links: extractedKeywords.links
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
console.log('extractResults', extractResults)
|
||||||
searchResults.push({
|
const response = await webSearchService.processWebsearch(extractResults, requestId)
|
||||||
query,
|
searchResults.push(response)
|
||||||
results: response
|
} catch (error) {
|
||||||
})
|
console.error(`Web search failed for query "${finalQueries}":`, error)
|
||||||
|
return {
|
||||||
|
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||||
|
searchResults: [],
|
||||||
|
sources: '',
|
||||||
|
instructions: '',
|
||||||
|
rawResults: []
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return { extractedKeywords: extractResults.websearch, searchResults }
|
if (searchResults.length === 0 || !searchResults[0].results) {
|
||||||
|
return {
|
||||||
|
summary: 'No search results found for the given query.',
|
||||||
|
searchResults: [],
|
||||||
|
sources: '',
|
||||||
|
instructions: '',
|
||||||
|
rawResults: []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const results = searchResults[0].results
|
||||||
|
const citationData = results.map((result, index) => ({
|
||||||
|
number: index + 1,
|
||||||
|
title: result.title,
|
||||||
|
content: result.content,
|
||||||
|
url: result.url
|
||||||
|
}))
|
||||||
|
|
||||||
|
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
|
||||||
|
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||||
|
|
||||||
|
// 构建完整的引用指导文本
|
||||||
|
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||||
|
'{question}',
|
||||||
|
"Based on the search results, please answer the user's question with proper citations."
|
||||||
|
).replace('{references}', referenceContent)
|
||||||
|
|
||||||
|
return {
|
||||||
|
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
|
||||||
|
searchResults,
|
||||||
|
sources: citationData
|
||||||
|
.map((source) => `[${source.number}] ${source.title}\n${source.content}\nURL: ${source.url}`)
|
||||||
|
.join('\n\n'),
|
||||||
|
|
||||||
|
instructions: fullInstructions,
|
||||||
|
|
||||||
|
// 原始数据,便于后续处理
|
||||||
|
rawResults: citationData
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>
|
// export const webSearchToolWithExtraction = (
|
||||||
|
// webSearchProviderId: WebSearchProvider['id'],
|
||||||
|
// requestId: string,
|
||||||
|
// assistant: Assistant
|
||||||
|
// ) => {
|
||||||
|
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||||
|
|
||||||
|
// return tool({
|
||||||
|
// name: 'web_search_with_extraction',
|
||||||
|
// description: 'Search the web for information with automatic keyword extraction from user messages',
|
||||||
|
// inputSchema: z.object({
|
||||||
|
// userMessage: z.object({
|
||||||
|
// content: z.string().describe('The main content of the message'),
|
||||||
|
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||||
|
// }),
|
||||||
|
// lastAnswer: z.object({
|
||||||
|
// content: z.string().describe('The main content of the message'),
|
||||||
|
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||||
|
// })
|
||||||
|
// }),
|
||||||
|
// outputSchema: z.object({
|
||||||
|
// extractedKeywords: z.object({
|
||||||
|
// question: z.array(z.string()),
|
||||||
|
// links: z.array(z.string()).optional()
|
||||||
|
// }),
|
||||||
|
// searchResults: z.array(
|
||||||
|
// z.object({
|
||||||
|
// query: z.string(),
|
||||||
|
// results: WebSearchProviderResult
|
||||||
|
// })
|
||||||
|
// )
|
||||||
|
// }),
|
||||||
|
// execute: async ({ userMessage, lastAnswer }) => {
|
||||||
|
// const lastUserMessage: Message = {
|
||||||
|
// id: requestId,
|
||||||
|
// role: userMessage.role,
|
||||||
|
// assistantId: assistant.id,
|
||||||
|
// topicId: 'temp',
|
||||||
|
// createdAt: new Date().toISOString(),
|
||||||
|
// status: UserMessageStatus.SUCCESS,
|
||||||
|
// blocks: []
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const lastAnswerMessage: Message | undefined = lastAnswer
|
||||||
|
// ? {
|
||||||
|
// id: requestId + '_answer',
|
||||||
|
// role: lastAnswer.role,
|
||||||
|
// assistantId: assistant.id,
|
||||||
|
// topicId: 'temp',
|
||||||
|
// createdAt: new Date().toISOString(),
|
||||||
|
// status: UserMessageStatus.SUCCESS,
|
||||||
|
// blocks: []
|
||||||
|
// }
|
||||||
|
// : undefined
|
||||||
|
|
||||||
|
// const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
||||||
|
// shouldWebSearch: true,
|
||||||
|
// shouldKnowledgeSearch: false,
|
||||||
|
// lastAnswer: lastAnswerMessage
|
||||||
|
// })
|
||||||
|
|
||||||
|
// if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
||||||
|
// return 'No search needed or extraction failed'
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const searchQueries = extractResults.websearch.question
|
||||||
|
// const searchResults: Array<{ query: string; results: any }> = []
|
||||||
|
|
||||||
|
// for (const query of searchQueries) {
|
||||||
|
// // 构建单个查询的ExtractResults结构
|
||||||
|
// const queryExtractResults: ExtractResults = {
|
||||||
|
// websearch: {
|
||||||
|
// question: [query],
|
||||||
|
// links: extractResults.websearch.links
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
||||||
|
// searchResults.push({
|
||||||
|
// query,
|
||||||
|
// results: response
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return { extractedKeywords: extractResults.websearch, searchResults }
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
// export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>
|
||||||
|
|
||||||
|
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>
|
||||||
|
export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>
|
||||||
|
|||||||
@ -0,0 +1,72 @@
|
|||||||
|
import { KnowledgeSearchToolInput, KnowledgeSearchToolOutput } from '@renderer/aiCore/tools/KnowledgeSearchTool'
|
||||||
|
import Spinner from '@renderer/components/Spinner'
|
||||||
|
import i18n from '@renderer/i18n'
|
||||||
|
import { MCPToolResponse } from '@renderer/types'
|
||||||
|
import { Typography } from 'antd'
|
||||||
|
import { FileSearch } from 'lucide-react'
|
||||||
|
import styled from 'styled-components'
|
||||||
|
|
||||||
|
const { Text } = Typography
|
||||||
|
export function MessageKnowledgeSearchToolTitle({ toolResponse }: { toolResponse: MCPToolResponse }) {
|
||||||
|
const toolInput = toolResponse.arguments as KnowledgeSearchToolInput
|
||||||
|
const toolOutput = toolResponse.response as KnowledgeSearchToolOutput
|
||||||
|
|
||||||
|
return toolResponse.status !== 'done' ? (
|
||||||
|
<Spinner
|
||||||
|
text={
|
||||||
|
<PrepareToolWrapper>
|
||||||
|
{i18n.t('message.searching')}
|
||||||
|
<span>{toolInput?.rewrite ?? toolInput?.query ?? ''}</span>
|
||||||
|
</PrepareToolWrapper>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
||||||
|
<FileSearch size={16} style={{ color: 'unset' }} />
|
||||||
|
{i18n.t('message.websearch.fetch_complete', { count: toolOutput.length ?? 0 })}
|
||||||
|
</MessageWebSearchToolTitleTextWrapper>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function MessageKnowledgeSearchToolBody({ toolResponse }: { toolResponse: MCPToolResponse }) {
|
||||||
|
const toolOutput = toolResponse.response as KnowledgeSearchToolOutput
|
||||||
|
|
||||||
|
return toolResponse.status === 'done' ? (
|
||||||
|
<MessageWebSearchToolBodyUlWrapper>
|
||||||
|
{toolOutput.map((result) => (
|
||||||
|
<li key={result.id}>
|
||||||
|
<span>{result.id}</span>
|
||||||
|
<span>{result.content}</span>
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</MessageWebSearchToolBodyUlWrapper>
|
||||||
|
) : null
|
||||||
|
}
|
||||||
|
|
||||||
|
const PrepareToolWrapper = styled.span`
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 4px;
|
||||||
|
font-size: 14px;
|
||||||
|
padding-left: 0;
|
||||||
|
`
|
||||||
|
const MessageWebSearchToolTitleTextWrapper = styled(Text)`
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 4px;
|
||||||
|
`
|
||||||
|
|
||||||
|
const MessageWebSearchToolBodyUlWrapper = styled.ul`
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 4px;
|
||||||
|
padding: 0;
|
||||||
|
> li {
|
||||||
|
padding: 0;
|
||||||
|
margin: 0;
|
||||||
|
max-width: 70%;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
`
|
||||||
@ -1,11 +1,14 @@
|
|||||||
|
import { MCPToolResponse } from '@renderer/types'
|
||||||
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
||||||
import { Collapse } from 'antd'
|
import { Collapse } from 'antd'
|
||||||
|
|
||||||
|
import { MessageKnowledgeSearchToolBody, MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
|
||||||
import { MessageWebSearchToolBody, MessageWebSearchToolTitle } from './MessageWebSearchTool'
|
import { MessageWebSearchToolBody, MessageWebSearchToolTitle } from './MessageWebSearchTool'
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
block: ToolMessageBlock
|
block: ToolMessageBlock
|
||||||
}
|
}
|
||||||
|
const prefix = 'builtin_'
|
||||||
|
|
||||||
// const toolNameMapText = {
|
// const toolNameMapText = {
|
||||||
// web_search: i18n.t('message.searching')
|
// web_search: i18n.t('message.searching')
|
||||||
@ -41,18 +44,62 @@ interface Props {
|
|||||||
// return <p>{toolDoneNameText}</p>
|
// return <p>{toolDoneNameText}</p>
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
// const ToolLabelComponents = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
|
||||||
|
// if (webSearchToolNames.includes(toolResponse.tool.name)) {
|
||||||
|
// return <MessageWebSearchToolTitle toolResponse={toolResponse} />
|
||||||
|
// }
|
||||||
|
// return <MessageWebSearchToolTitle toolResponse={toolResponse} />
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const ToolBodyComponents = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
|
||||||
|
// if (webSearchToolNames.includes(toolResponse.tool.name)) {
|
||||||
|
// return <MessageWebSearchToolBody toolResponse={toolResponse} />
|
||||||
|
// }
|
||||||
|
// return <MessageWebSearchToolBody toolResponse={toolResponse} />
|
||||||
|
// }
|
||||||
|
|
||||||
|
const ChooseTool = (
|
||||||
|
toolResponse: MCPToolResponse
|
||||||
|
): {
|
||||||
|
label: React.ReactNode
|
||||||
|
body: React.ReactNode
|
||||||
|
} => {
|
||||||
|
let toolName = toolResponse.tool.name
|
||||||
|
if (toolName.startsWith(prefix)) {
|
||||||
|
toolName = toolName.slice(prefix.length)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (toolName) {
|
||||||
|
case 'web_search':
|
||||||
|
case 'web_search_preview':
|
||||||
|
return {
|
||||||
|
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
|
||||||
|
body: <MessageWebSearchToolBody toolResponse={toolResponse} />
|
||||||
|
}
|
||||||
|
case 'knowledge_search':
|
||||||
|
return {
|
||||||
|
label: <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />,
|
||||||
|
body: <MessageKnowledgeSearchToolBody toolResponse={toolResponse} />
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return {
|
||||||
|
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
|
||||||
|
body: <MessageWebSearchToolBody toolResponse={toolResponse} />
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export default function MessageTool({ block }: Props) {
|
export default function MessageTool({ block }: Props) {
|
||||||
const toolResponse = block.metadata?.rawMcpToolResponse
|
const toolResponse = block.metadata?.rawMcpToolResponse
|
||||||
if (!toolResponse) return null
|
if (!toolResponse) return null
|
||||||
console.log('toolResponse', toolResponse)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Collapse
|
<Collapse
|
||||||
items={[
|
items={[
|
||||||
{
|
{
|
||||||
key: '1',
|
key: '1',
|
||||||
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
|
label: ChooseTool(toolResponse).label,
|
||||||
children: <MessageWebSearchToolBody toolResponse={toolResponse} />,
|
children: ChooseTool(toolResponse).body,
|
||||||
showArrow: false,
|
showArrow: false,
|
||||||
styles: {
|
styles: {
|
||||||
header: {
|
header: {
|
||||||
|
|||||||
@ -17,14 +17,16 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
|
|||||||
text={
|
text={
|
||||||
<PrepareToolWrapper>
|
<PrepareToolWrapper>
|
||||||
{i18n.t('message.searching')}
|
{i18n.t('message.searching')}
|
||||||
<span>{toolInput?.query ?? ''}</span>
|
<span>{toolInput?.additionalContext ?? ''}</span>
|
||||||
</PrepareToolWrapper>
|
</PrepareToolWrapper>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
||||||
<Search size={16} style={{ color: 'unset' }} />
|
<Search size={16} style={{ color: 'unset' }} />
|
||||||
{i18n.t('message.websearch.fetch_complete', { count: toolOutput.results.length ?? 0 })}
|
{i18n.t('message.websearch.fetch_complete', {
|
||||||
|
count: toolOutput?.searchResults?.reduce((acc, result) => acc + result.results.length, 0) ?? 0
|
||||||
|
})}
|
||||||
</MessageWebSearchToolTitleTextWrapper>
|
</MessageWebSearchToolTitleTextWrapper>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -32,15 +34,17 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
|
|||||||
export const MessageWebSearchToolBody = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
|
export const MessageWebSearchToolBody = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
|
||||||
const toolOutput = toolResponse.response as WebSearchToolOutput
|
const toolOutput = toolResponse.response as WebSearchToolOutput
|
||||||
|
|
||||||
return toolResponse.status === 'done' ? (
|
return toolResponse.status === 'done'
|
||||||
<MessageWebSearchToolBodyUlWrapper>
|
? toolOutput?.searchResults?.map((result, index) => (
|
||||||
{toolOutput.results.map((result) => (
|
<MessageWebSearchToolBodyUlWrapper key={result?.query ?? '' + index}>
|
||||||
<li key={result.url}>
|
{result.results.map((item, index) => (
|
||||||
<Link href={result.url}>{result.title}</Link>
|
<li key={item.url + index}>
|
||||||
</li>
|
<Link href={item.url}>{item.title}</Link>
|
||||||
))}
|
</li>
|
||||||
</MessageWebSearchToolBodyUlWrapper>
|
))}
|
||||||
) : null
|
</MessageWebSearchToolBodyUlWrapper>
|
||||||
|
))
|
||||||
|
: null
|
||||||
}
|
}
|
||||||
|
|
||||||
const PrepareToolWrapper = styled.span`
|
const PrepareToolWrapper = styled.span`
|
||||||
|
|||||||
@ -881,7 +881,6 @@ const fetchAndProcessAssistantResponseImpl = async (
|
|||||||
saveUpdatesToDB,
|
saveUpdatesToDB,
|
||||||
assistant
|
assistant
|
||||||
})
|
})
|
||||||
console.log('callbacks', callbacks)
|
|
||||||
const streamProcessorCallbacks = createStreamProcessor(callbacks)
|
const streamProcessorCallbacks = createStreamProcessor(callbacks)
|
||||||
|
|
||||||
const abortController = new AbortController()
|
const abortController = new AbortController()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user