mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
refactor: enhance search orchestration and web search tool integration
- Updated `searchOrchestrationPlugin` to improve handling of assistant configurations and prevent concurrent analysis. - Refactored `webSearchTool` to utilize pre-extracted keywords for more efficient web searches. - Introduced a new `MessageKnowledgeSearch` component for displaying knowledge search results. - Cleaned up commented-out code and improved type safety across various components. - Enhanced the integration of web search results in the UI for better user experience.
This commit is contained in:
parent
0310648445
commit
a05d7cbe2d
@ -97,7 +97,6 @@ export class AiSdkToChunkAdapter {
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: final.text || ''
|
||||
})
|
||||
console.log('final.text', final.text)
|
||||
break
|
||||
case 'text-end':
|
||||
this.onChunk({
|
||||
|
||||
@ -21,7 +21,7 @@ import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-cor
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep } from 'lodash'
|
||||
|
||||
@ -163,7 +163,7 @@ export default class ModernAiProvider {
|
||||
// 内置了默认搜索参数,如果改的话可以传config进去
|
||||
plugins.push(webSearchPlugin())
|
||||
}
|
||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant as Assistant))
|
||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
|
||||
|
||||
// 2. 推理模型时添加推理插件
|
||||
if (middlewareConfig.enableReasoning) {
|
||||
|
||||
@ -20,7 +20,7 @@ export interface AiSdkMiddlewareConfig {
|
||||
enableWebSearch?: boolean
|
||||
mcpTools?: BaseTool[]
|
||||
// TODO assistant
|
||||
assistant?: Assistant
|
||||
assistant: Assistant
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -8,7 +8,6 @@
|
||||
*/
|
||||
import type { AiRequestContext, ModelMessage } from '@cherrystudio/ai-core'
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { RuntimeExecutor } from '@cherrystudio/ai-core/core/runtime/executor'
|
||||
// import { generateObject } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
SEARCH_SUMMARY_PROMPT,
|
||||
@ -19,13 +18,13 @@ import { getDefaultModel, getProviderByModel } from '@renderer/services/Assistan
|
||||
import store from '@renderer/store'
|
||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
|
||||
import { isEmpty } from 'lodash'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool'
|
||||
import { memorySearchTool } from '../tools/MemorySearchTool'
|
||||
import { webSearchTool } from '../tools/WebSearchTool'
|
||||
import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
|
||||
|
||||
const getMessageContent = (message: ModelMessage) => {
|
||||
if (typeof message.content === 'string') return message.content
|
||||
@ -39,32 +38,32 @@ const getMessageContent = (message: ModelMessage) => {
|
||||
|
||||
// === Schema Definitions ===
|
||||
|
||||
const WebSearchSchema = z.object({
|
||||
question: z
|
||||
.array(z.string())
|
||||
.describe('Search queries for web search. Use "not_needed" if no web search is required.'),
|
||||
links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.')
|
||||
})
|
||||
// const WebSearchSchema = z.object({
|
||||
// question: z
|
||||
// .array(z.string())
|
||||
// .describe('Search queries for web search. Use "not_needed" if no web search is required.'),
|
||||
// links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.')
|
||||
// })
|
||||
|
||||
const KnowledgeSearchSchema = z.object({
|
||||
question: z
|
||||
.array(z.string())
|
||||
.describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'),
|
||||
rewrite: z
|
||||
.string()
|
||||
.describe('Rewritten query with alternative phrasing while preserving original intent and meaning.')
|
||||
})
|
||||
// const KnowledgeSearchSchema = z.object({
|
||||
// question: z
|
||||
// .array(z.string())
|
||||
// .describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'),
|
||||
// rewrite: z
|
||||
// .string()
|
||||
// .describe('Rewritten query with alternative phrasing while preserving original intent and meaning.')
|
||||
// })
|
||||
|
||||
const SearchIntentAnalysisSchema = z.object({
|
||||
websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'),
|
||||
knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.')
|
||||
})
|
||||
// const SearchIntentAnalysisSchema = z.object({
|
||||
// websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'),
|
||||
// knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.')
|
||||
// })
|
||||
|
||||
type SearchIntentResult = z.infer<typeof SearchIntentAnalysisSchema>
|
||||
// type SearchIntentResult = z.infer<typeof SearchIntentAnalysisSchema>
|
||||
|
||||
let isAnalyzing = false
|
||||
// let isAnalyzing = false
|
||||
/**
|
||||
* 🧠 意图分析函数 - 使用结构化输出重构
|
||||
* 🧠 意图分析函数 - 使用 XML 解析
|
||||
*/
|
||||
async function analyzeSearchIntent(
|
||||
lastUserMessage: ModelMessage,
|
||||
@ -74,13 +73,11 @@ async function analyzeSearchIntent(
|
||||
shouldKnowledgeSearch?: boolean
|
||||
shouldMemorySearch?: boolean
|
||||
lastAnswer?: ModelMessage
|
||||
context?:
|
||||
| AiRequestContext
|
||||
| {
|
||||
executor: RuntimeExecutor
|
||||
}
|
||||
} = {}
|
||||
): Promise<SearchIntentResult | undefined> {
|
||||
context: AiRequestContext & {
|
||||
isAnalyzing?: boolean
|
||||
}
|
||||
}
|
||||
): Promise<ExtractResults | undefined> {
|
||||
const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer, context } = options
|
||||
|
||||
if (!lastUserMessage) return undefined
|
||||
@ -91,19 +88,19 @@ async function analyzeSearchIntent(
|
||||
|
||||
if (!needWebExtract && !needKnowledgeExtract) return undefined
|
||||
|
||||
// 选择合适的提示词和schema
|
||||
// 选择合适的提示词
|
||||
let prompt: string
|
||||
let schema: z.Schema
|
||||
// let schema: z.Schema
|
||||
|
||||
if (needWebExtract && !needKnowledgeExtract) {
|
||||
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
schema = z.object({ websearch: WebSearchSchema })
|
||||
// schema = z.object({ websearch: WebSearchSchema })
|
||||
} else if (!needWebExtract && needKnowledgeExtract) {
|
||||
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
||||
schema = z.object({ knowledge: KnowledgeSearchSchema })
|
||||
// schema = z.object({ knowledge: KnowledgeSearchSchema })
|
||||
} else {
|
||||
prompt = SEARCH_SUMMARY_PROMPT
|
||||
schema = SearchIntentAnalysisSchema
|
||||
// schema = SearchIntentAnalysisSchema
|
||||
}
|
||||
|
||||
// 构建消息上下文 - 简化逻辑
|
||||
@ -121,16 +118,15 @@ async function analyzeSearchIntent(
|
||||
console.error('Provider not found or missing API key')
|
||||
return getFallbackResult()
|
||||
}
|
||||
|
||||
// console.log('formattedPrompt', schema)
|
||||
try {
|
||||
isAnalyzing = true
|
||||
const result = await context?.executor?.generateObject(model.id, {
|
||||
schema,
|
||||
context.isAnalyzing = true
|
||||
const { text: result } = await context.executor.generateText(model.id, {
|
||||
prompt: formattedPrompt
|
||||
})
|
||||
isAnalyzing = false
|
||||
console.log('result', context)
|
||||
const parsedResult = result?.object as SearchIntentResult
|
||||
context.isAnalyzing = false
|
||||
const parsedResult = extractInfoFromXML(result)
|
||||
console.log('parsedResult', parsedResult)
|
||||
|
||||
// 根据需求过滤结果
|
||||
return {
|
||||
@ -142,7 +138,7 @@ async function analyzeSearchIntent(
|
||||
return getFallbackResult()
|
||||
}
|
||||
|
||||
function getFallbackResult(): SearchIntentResult {
|
||||
function getFallbackResult(): ExtractResults {
|
||||
const fallbackContent = getMessageContent(lastUserMessage)
|
||||
return {
|
||||
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
||||
@ -159,7 +155,11 @@ async function analyzeSearchIntent(
|
||||
/**
|
||||
* 🧠 记忆存储函数 - 基于注释代码中的 processConversationMemory
|
||||
*/
|
||||
async function storeConversationMemory(messages: ModelMessage[], assistant: Assistant): Promise<void> {
|
||||
async function storeConversationMemory(
|
||||
messages: ModelMessage[],
|
||||
assistant: Assistant,
|
||||
context: AiRequestContext
|
||||
): Promise<void> {
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
|
||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||
@ -185,14 +185,13 @@ async function storeConversationMemory(messages: ModelMessage[], assistant: Assi
|
||||
}
|
||||
|
||||
const currentUserId = selectCurrentUserId(store.getState())
|
||||
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||
// const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(
|
||||
memoryConfig,
|
||||
assistant.id,
|
||||
currentUserId,
|
||||
// TODO
|
||||
lastUserMessage?.id
|
||||
context.requestId
|
||||
)
|
||||
|
||||
console.log('Processing conversation memory...', { messageCount: conversationMessages.length })
|
||||
@ -224,9 +223,10 @@ async function storeConversationMemory(messages: ModelMessage[], assistant: Assi
|
||||
*/
|
||||
export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
// 存储意图分析结果
|
||||
const intentAnalysisResults: { [requestId: string]: SearchIntentResult } = {}
|
||||
const intentAnalysisResults: { [requestId: string]: ExtractResults } = {}
|
||||
const userMessages: { [requestId: string]: ModelMessage } = {}
|
||||
console.log('searchOrchestrationPlugin', assistant)
|
||||
|
||||
return definePlugin({
|
||||
name: 'search-orchestration',
|
||||
enforce: 'pre', // 确保在其他插件之前执行
|
||||
@ -235,7 +235,8 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
* 🔍 Step 1: 意图识别阶段
|
||||
*/
|
||||
onRequestStart: async (context: AiRequestContext) => {
|
||||
if (isAnalyzing) return
|
||||
console.log('onRequestStart', context.isAnalyzing)
|
||||
if (context.isAnalyzing) return
|
||||
console.log('🧠 [SearchOrchestration] Starting intent analysis...', context.requestId)
|
||||
|
||||
try {
|
||||
@ -294,7 +295,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
* 🔧 Step 2: 工具配置阶段
|
||||
*/
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
if (isAnalyzing) return
|
||||
if (context.isAnalyzing) return params
|
||||
console.log('🔧 [SearchOrchestration] Configuring tools based on intent...', context.requestId)
|
||||
|
||||
try {
|
||||
@ -314,8 +315,13 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed'
|
||||
|
||||
if (needsSearch) {
|
||||
console.log('🌐 [SearchOrchestration] Adding web search tool')
|
||||
params.tools['builtin_web_search'] = webSearchTool(assistant.webSearchProviderId)
|
||||
// onChunk({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
|
||||
console.log('🌐 [SearchOrchestration] Adding web search tool with pre-extracted keywords')
|
||||
params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords(
|
||||
assistant.webSearchProviderId,
|
||||
analysisResult.websearch,
|
||||
context.requestId
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -370,7 +376,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
const messages = context.originalParams.messages
|
||||
|
||||
if (messages && assistant) {
|
||||
await storeConversationMemory(messages, assistant)
|
||||
await storeConversationMemory(messages, assistant, context)
|
||||
}
|
||||
|
||||
// 清理缓存
|
||||
|
||||
@ -1,131 +1,262 @@
|
||||
import { extractSearchKeywords } from '@renderer/aiCore/transformParameters'
|
||||
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||
import WebSearchService from '@renderer/services/WebSearchService'
|
||||
import { Assistant, Message, WebSearchProvider } from '@renderer/types'
|
||||
import { UserMessageStatus } from '@renderer/types/newMessage'
|
||||
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
|
||||
import { ExtractResults } from '@renderer/utils/extract'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import { InferToolInput, InferToolOutput, tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
// import { AiSdkTool, ToolCallResult } from './types'
|
||||
|
||||
const WebSearchProviderResult = z.object({
|
||||
query: z.string().optional(),
|
||||
results: z.array(
|
||||
z.object({
|
||||
title: z.string(),
|
||||
content: z.string(),
|
||||
url: z.string()
|
||||
})
|
||||
)
|
||||
})
|
||||
const webSearchToolInputSchema = z.object({
|
||||
query: z.string().describe('The query to search for')
|
||||
})
|
||||
// const WebSearchResult = z.array(
|
||||
// z.object({
|
||||
// query: z.string().optional(),
|
||||
// results: z.array(
|
||||
// z.object({
|
||||
// title: z.string(),
|
||||
// content: z.string(),
|
||||
// url: z.string()
|
||||
// })
|
||||
// )
|
||||
// })
|
||||
// )
|
||||
// const webSearchToolInputSchema = z.object({
|
||||
// query: z.string().describe('The query to search for')
|
||||
// })
|
||||
|
||||
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
|
||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
return tool({
|
||||
name: 'builtin_web_search',
|
||||
description: 'Search the web for information',
|
||||
inputSchema: webSearchToolInputSchema,
|
||||
outputSchema: WebSearchProviderResult,
|
||||
execute: async ({ query }) => {
|
||||
console.log('webSearchTool', query)
|
||||
const response = await webSearchService.search(query)
|
||||
console.log('webSearchTool response', response)
|
||||
return response
|
||||
}
|
||||
})
|
||||
}
|
||||
export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchTool>>
|
||||
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
|
||||
// export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
|
||||
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
// return tool({
|
||||
// name: 'builtin_web_search',
|
||||
// description: 'Search the web for information',
|
||||
// inputSchema: webSearchToolInputSchema,
|
||||
// outputSchema: WebSearchProviderResult,
|
||||
// execute: async ({ query }) => {
|
||||
// console.log('webSearchTool', query)
|
||||
// const response = await webSearchService.search(query)
|
||||
// console.log('webSearchTool response', response)
|
||||
// return response
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchTool>>
|
||||
// export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
|
||||
|
||||
export const webSearchToolWithExtraction = (
|
||||
/**
|
||||
* 使用预提取关键词的网络搜索工具
|
||||
* 这个工具直接使用插件阶段分析的搜索意图,避免重复分析
|
||||
*/
|
||||
export const webSearchToolWithPreExtractedKeywords = (
|
||||
webSearchProviderId: WebSearchProvider['id'],
|
||||
requestId: string,
|
||||
assistant: Assistant
|
||||
extractedKeywords: {
|
||||
question: string[]
|
||||
links?: string[]
|
||||
},
|
||||
requestId: string
|
||||
) => {
|
||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
|
||||
return tool({
|
||||
name: 'web_search_with_extraction',
|
||||
description: 'Search the web for information with automatic keyword extraction from user messages',
|
||||
name: 'builtin_web_search',
|
||||
description: `Search the web and return citable sources using pre-analyzed search intent.
|
||||
|
||||
Pre-extracted search keywords: "${extractedKeywords.question.join(', ')}"${
|
||||
extractedKeywords.links
|
||||
? `
|
||||
Relevant links: ${extractedKeywords.links.join(', ')}`
|
||||
: ''
|
||||
}
|
||||
|
||||
This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.
|
||||
|
||||
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
|
||||
|
||||
inputSchema: z.object({
|
||||
userMessage: z.object({
|
||||
content: z.string().describe('The main content of the message'),
|
||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
}),
|
||||
lastAnswer: z.object({
|
||||
content: z.string().describe('The main content of the message'),
|
||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
})
|
||||
additionalContext: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Optional additional context, keywords, or specific focus to enhance the search')
|
||||
}),
|
||||
outputSchema: z.object({
|
||||
extractedKeywords: z.object({
|
||||
question: z.array(z.string()),
|
||||
links: z.array(z.string()).optional()
|
||||
}),
|
||||
searchResults: z.array(
|
||||
z.object({
|
||||
query: z.string(),
|
||||
results: WebSearchProviderResult
|
||||
})
|
||||
)
|
||||
}),
|
||||
execute: async ({ userMessage, lastAnswer }) => {
|
||||
const lastUserMessage: Message = {
|
||||
id: requestId,
|
||||
role: userMessage.role,
|
||||
assistantId: assistant.id,
|
||||
topicId: 'temp',
|
||||
createdAt: new Date().toISOString(),
|
||||
status: UserMessageStatus.SUCCESS,
|
||||
blocks: []
|
||||
|
||||
execute: async ({ additionalContext }) => {
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
console.log(`🔍 AI enhanced search with: ${additionalContext}`)
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
console.log(`➕ Added additional context: ${cleanContext}`)
|
||||
}
|
||||
}
|
||||
|
||||
const lastAnswerMessage: Message | undefined = lastAnswer
|
||||
? {
|
||||
id: requestId + '_answer',
|
||||
role: lastAnswer.role,
|
||||
assistantId: assistant.id,
|
||||
topicId: 'temp',
|
||||
createdAt: new Date().toISOString(),
|
||||
status: UserMessageStatus.SUCCESS,
|
||||
blocks: []
|
||||
}
|
||||
: undefined
|
||||
const searchResults: WebSearchProviderResponse[] = []
|
||||
|
||||
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
||||
shouldWebSearch: true,
|
||||
shouldKnowledgeSearch: false,
|
||||
lastAnswer: lastAnswerMessage
|
||||
})
|
||||
|
||||
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
||||
return 'No search needed or extraction failed'
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return {
|
||||
summary: 'No search needed based on the query analysis.',
|
||||
searchResults: [],
|
||||
sources: '',
|
||||
instructions: '',
|
||||
rawResults: []
|
||||
}
|
||||
}
|
||||
|
||||
const searchQueries = extractResults.websearch.question
|
||||
const searchResults: Array<{ query: string; results: any }> = []
|
||||
|
||||
for (const query of searchQueries) {
|
||||
// 构建单个查询的ExtractResults结构
|
||||
const queryExtractResults: ExtractResults = {
|
||||
try {
|
||||
// 构建 ExtractResults 结构用于 processWebsearch
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: {
|
||||
question: [query],
|
||||
links: extractResults.websearch.links
|
||||
question: finalQueries,
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
||||
searchResults.push({
|
||||
query,
|
||||
results: response
|
||||
})
|
||||
console.log('extractResults', extractResults)
|
||||
const response = await webSearchService.processWebsearch(extractResults, requestId)
|
||||
searchResults.push(response)
|
||||
} catch (error) {
|
||||
console.error(`Web search failed for query "${finalQueries}":`, error)
|
||||
return {
|
||||
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
searchResults: [],
|
||||
sources: '',
|
||||
instructions: '',
|
||||
rawResults: []
|
||||
}
|
||||
}
|
||||
|
||||
return { extractedKeywords: extractResults.websearch, searchResults }
|
||||
if (searchResults.length === 0 || !searchResults[0].results) {
|
||||
return {
|
||||
summary: 'No search results found for the given query.',
|
||||
searchResults: [],
|
||||
sources: '',
|
||||
instructions: '',
|
||||
rawResults: []
|
||||
}
|
||||
}
|
||||
|
||||
const results = searchResults[0].results
|
||||
const citationData = results.map((result, index) => ({
|
||||
number: index + 1,
|
||||
title: result.title,
|
||||
content: result.content,
|
||||
url: result.url
|
||||
}))
|
||||
|
||||
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||
|
||||
// 构建完整的引用指导文本
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the search results, please answer the user's question with proper citations."
|
||||
).replace('{references}', referenceContent)
|
||||
|
||||
return {
|
||||
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
|
||||
searchResults,
|
||||
sources: citationData
|
||||
.map((source) => `[${source.number}] ${source.title}\n${source.content}\nURL: ${source.url}`)
|
||||
.join('\n\n'),
|
||||
|
||||
instructions: fullInstructions,
|
||||
|
||||
// 原始数据,便于后续处理
|
||||
rawResults: citationData
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>
|
||||
// export const webSearchToolWithExtraction = (
|
||||
// webSearchProviderId: WebSearchProvider['id'],
|
||||
// requestId: string,
|
||||
// assistant: Assistant
|
||||
// ) => {
|
||||
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
|
||||
// return tool({
|
||||
// name: 'web_search_with_extraction',
|
||||
// description: 'Search the web for information with automatic keyword extraction from user messages',
|
||||
// inputSchema: z.object({
|
||||
// userMessage: z.object({
|
||||
// content: z.string().describe('The main content of the message'),
|
||||
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
// }),
|
||||
// lastAnswer: z.object({
|
||||
// content: z.string().describe('The main content of the message'),
|
||||
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
// })
|
||||
// }),
|
||||
// outputSchema: z.object({
|
||||
// extractedKeywords: z.object({
|
||||
// question: z.array(z.string()),
|
||||
// links: z.array(z.string()).optional()
|
||||
// }),
|
||||
// searchResults: z.array(
|
||||
// z.object({
|
||||
// query: z.string(),
|
||||
// results: WebSearchProviderResult
|
||||
// })
|
||||
// )
|
||||
// }),
|
||||
// execute: async ({ userMessage, lastAnswer }) => {
|
||||
// const lastUserMessage: Message = {
|
||||
// id: requestId,
|
||||
// role: userMessage.role,
|
||||
// assistantId: assistant.id,
|
||||
// topicId: 'temp',
|
||||
// createdAt: new Date().toISOString(),
|
||||
// status: UserMessageStatus.SUCCESS,
|
||||
// blocks: []
|
||||
// }
|
||||
|
||||
// const lastAnswerMessage: Message | undefined = lastAnswer
|
||||
// ? {
|
||||
// id: requestId + '_answer',
|
||||
// role: lastAnswer.role,
|
||||
// assistantId: assistant.id,
|
||||
// topicId: 'temp',
|
||||
// createdAt: new Date().toISOString(),
|
||||
// status: UserMessageStatus.SUCCESS,
|
||||
// blocks: []
|
||||
// }
|
||||
// : undefined
|
||||
|
||||
// const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
||||
// shouldWebSearch: true,
|
||||
// shouldKnowledgeSearch: false,
|
||||
// lastAnswer: lastAnswerMessage
|
||||
// })
|
||||
|
||||
// if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
||||
// return 'No search needed or extraction failed'
|
||||
// }
|
||||
|
||||
// const searchQueries = extractResults.websearch.question
|
||||
// const searchResults: Array<{ query: string; results: any }> = []
|
||||
|
||||
// for (const query of searchQueries) {
|
||||
// // 构建单个查询的ExtractResults结构
|
||||
// const queryExtractResults: ExtractResults = {
|
||||
// websearch: {
|
||||
// question: [query],
|
||||
// links: extractResults.websearch.links
|
||||
// }
|
||||
// }
|
||||
// const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
||||
// searchResults.push({
|
||||
// query,
|
||||
// results: response
|
||||
// })
|
||||
// }
|
||||
|
||||
// return { extractedKeywords: extractResults.websearch, searchResults }
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>
|
||||
|
||||
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>
|
||||
export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>
|
||||
|
||||
@ -0,0 +1,72 @@
|
||||
import { KnowledgeSearchToolInput, KnowledgeSearchToolOutput } from '@renderer/aiCore/tools/KnowledgeSearchTool'
|
||||
import Spinner from '@renderer/components/Spinner'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { MCPToolResponse } from '@renderer/types'
|
||||
import { Typography } from 'antd'
|
||||
import { FileSearch } from 'lucide-react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
const { Text } = Typography
|
||||
export function MessageKnowledgeSearchToolTitle({ toolResponse }: { toolResponse: MCPToolResponse }) {
|
||||
const toolInput = toolResponse.arguments as KnowledgeSearchToolInput
|
||||
const toolOutput = toolResponse.response as KnowledgeSearchToolOutput
|
||||
|
||||
return toolResponse.status !== 'done' ? (
|
||||
<Spinner
|
||||
text={
|
||||
<PrepareToolWrapper>
|
||||
{i18n.t('message.searching')}
|
||||
<span>{toolInput?.rewrite ?? toolInput?.query ?? ''}</span>
|
||||
</PrepareToolWrapper>
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
||||
<FileSearch size={16} style={{ color: 'unset' }} />
|
||||
{i18n.t('message.websearch.fetch_complete', { count: toolOutput.length ?? 0 })}
|
||||
</MessageWebSearchToolTitleTextWrapper>
|
||||
)
|
||||
}
|
||||
|
||||
export function MessageKnowledgeSearchToolBody({ toolResponse }: { toolResponse: MCPToolResponse }) {
|
||||
const toolOutput = toolResponse.response as KnowledgeSearchToolOutput
|
||||
|
||||
return toolResponse.status === 'done' ? (
|
||||
<MessageWebSearchToolBodyUlWrapper>
|
||||
{toolOutput.map((result) => (
|
||||
<li key={result.id}>
|
||||
<span>{result.id}</span>
|
||||
<span>{result.content}</span>
|
||||
</li>
|
||||
))}
|
||||
</MessageWebSearchToolBodyUlWrapper>
|
||||
) : null
|
||||
}
|
||||
|
||||
const PrepareToolWrapper = styled.span`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
font-size: 14px;
|
||||
padding-left: 0;
|
||||
`
|
||||
const MessageWebSearchToolTitleTextWrapper = styled(Text)`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
`
|
||||
|
||||
const MessageWebSearchToolBodyUlWrapper = styled.ul`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
padding: 0;
|
||||
> li {
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
max-width: 70%;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
`
|
||||
@ -1,11 +1,14 @@
|
||||
import { MCPToolResponse } from '@renderer/types'
|
||||
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
||||
import { Collapse } from 'antd'
|
||||
|
||||
import { MessageKnowledgeSearchToolBody, MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
|
||||
import { MessageWebSearchToolBody, MessageWebSearchToolTitle } from './MessageWebSearchTool'
|
||||
|
||||
interface Props {
|
||||
block: ToolMessageBlock
|
||||
}
|
||||
const prefix = 'builtin_'
|
||||
|
||||
// const toolNameMapText = {
|
||||
// web_search: i18n.t('message.searching')
|
||||
@ -41,18 +44,62 @@ interface Props {
|
||||
// return <p>{toolDoneNameText}</p>
|
||||
// }
|
||||
|
||||
// const ToolLabelComponents = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
|
||||
// if (webSearchToolNames.includes(toolResponse.tool.name)) {
|
||||
// return <MessageWebSearchToolTitle toolResponse={toolResponse} />
|
||||
// }
|
||||
// return <MessageWebSearchToolTitle toolResponse={toolResponse} />
|
||||
// }
|
||||
|
||||
// const ToolBodyComponents = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
|
||||
// if (webSearchToolNames.includes(toolResponse.tool.name)) {
|
||||
// return <MessageWebSearchToolBody toolResponse={toolResponse} />
|
||||
// }
|
||||
// return <MessageWebSearchToolBody toolResponse={toolResponse} />
|
||||
// }
|
||||
|
||||
const ChooseTool = (
|
||||
toolResponse: MCPToolResponse
|
||||
): {
|
||||
label: React.ReactNode
|
||||
body: React.ReactNode
|
||||
} => {
|
||||
let toolName = toolResponse.tool.name
|
||||
if (toolName.startsWith(prefix)) {
|
||||
toolName = toolName.slice(prefix.length)
|
||||
}
|
||||
|
||||
switch (toolName) {
|
||||
case 'web_search':
|
||||
case 'web_search_preview':
|
||||
return {
|
||||
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
|
||||
body: <MessageWebSearchToolBody toolResponse={toolResponse} />
|
||||
}
|
||||
case 'knowledge_search':
|
||||
return {
|
||||
label: <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />,
|
||||
body: <MessageKnowledgeSearchToolBody toolResponse={toolResponse} />
|
||||
}
|
||||
default:
|
||||
return {
|
||||
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
|
||||
body: <MessageWebSearchToolBody toolResponse={toolResponse} />
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default function MessageTool({ block }: Props) {
|
||||
const toolResponse = block.metadata?.rawMcpToolResponse
|
||||
if (!toolResponse) return null
|
||||
console.log('toolResponse', toolResponse)
|
||||
|
||||
return (
|
||||
<Collapse
|
||||
items={[
|
||||
{
|
||||
key: '1',
|
||||
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
|
||||
children: <MessageWebSearchToolBody toolResponse={toolResponse} />,
|
||||
label: ChooseTool(toolResponse).label,
|
||||
children: ChooseTool(toolResponse).body,
|
||||
showArrow: false,
|
||||
styles: {
|
||||
header: {
|
||||
|
||||
@ -17,14 +17,16 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
|
||||
text={
|
||||
<PrepareToolWrapper>
|
||||
{i18n.t('message.searching')}
|
||||
<span>{toolInput?.query ?? ''}</span>
|
||||
<span>{toolInput?.additionalContext ?? ''}</span>
|
||||
</PrepareToolWrapper>
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
||||
<Search size={16} style={{ color: 'unset' }} />
|
||||
{i18n.t('message.websearch.fetch_complete', { count: toolOutput.results.length ?? 0 })}
|
||||
{i18n.t('message.websearch.fetch_complete', {
|
||||
count: toolOutput?.searchResults?.reduce((acc, result) => acc + result.results.length, 0) ?? 0
|
||||
})}
|
||||
</MessageWebSearchToolTitleTextWrapper>
|
||||
)
|
||||
}
|
||||
@ -32,15 +34,17 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
|
||||
export const MessageWebSearchToolBody = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
|
||||
const toolOutput = toolResponse.response as WebSearchToolOutput
|
||||
|
||||
return toolResponse.status === 'done' ? (
|
||||
<MessageWebSearchToolBodyUlWrapper>
|
||||
{toolOutput.results.map((result) => (
|
||||
<li key={result.url}>
|
||||
<Link href={result.url}>{result.title}</Link>
|
||||
</li>
|
||||
))}
|
||||
</MessageWebSearchToolBodyUlWrapper>
|
||||
) : null
|
||||
return toolResponse.status === 'done'
|
||||
? toolOutput?.searchResults?.map((result, index) => (
|
||||
<MessageWebSearchToolBodyUlWrapper key={result?.query ?? '' + index}>
|
||||
{result.results.map((item, index) => (
|
||||
<li key={item.url + index}>
|
||||
<Link href={item.url}>{item.title}</Link>
|
||||
</li>
|
||||
))}
|
||||
</MessageWebSearchToolBodyUlWrapper>
|
||||
))
|
||||
: null
|
||||
}
|
||||
|
||||
const PrepareToolWrapper = styled.span`
|
||||
|
||||
@ -881,7 +881,6 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
saveUpdatesToDB,
|
||||
assistant
|
||||
})
|
||||
console.log('callbacks', callbacks)
|
||||
const streamProcessorCallbacks = createStreamProcessor(callbacks)
|
||||
|
||||
const abortController = new AbortController()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user