mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat(tools): refactor MemorySearchTool and WebSearchTool for improved response handling
- Updated MemorySearchTool to utilize aiSdk for better integration and removed unused imports. - Refactored WebSearchTool to streamline search results handling, changing from an array to a structured object for clarity. - Adjusted MessageTool and MessageWebSearchTool components to reflect changes in tool response structure. - Enhanced error handling and logging in tool callbacks for improved debugging and user feedback.
This commit is contained in:
parent
d34b640807
commit
ca4e7e3d2b
@ -1,11 +1,13 @@
|
||||
import { aiSdk, type InferToolOutput } from '@cherrystudio/ai-core'
|
||||
import store from '@renderer/store'
|
||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import { type InferToolOutput, tool } from 'ai'
|
||||
// import { type InferToolOutput, tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
|
||||
const { tool } = aiSdk
|
||||
/**
|
||||
* 🧠 基础记忆搜索工具
|
||||
* AI 可以主动调用的简单记忆搜索
|
||||
|
||||
@ -54,15 +54,17 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
}
|
||||
}
|
||||
|
||||
const searchResults: WebSearchProviderResponse[] = []
|
||||
let searchResults: WebSearchProviderResponse = {
|
||||
query: '',
|
||||
results: []
|
||||
}
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return {
|
||||
summary: 'No search needed based on the query analysis.',
|
||||
searchResults: [],
|
||||
searchResults,
|
||||
sources: '',
|
||||
instructions: '',
|
||||
rawResults: []
|
||||
instructions: ''
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,29 +76,24 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
searchResults.push(response)
|
||||
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
} catch (error) {
|
||||
return {
|
||||
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
searchResults: [],
|
||||
sources: '',
|
||||
instructions: '',
|
||||
rawResults: []
|
||||
sources: [],
|
||||
instructions: ''
|
||||
}
|
||||
}
|
||||
|
||||
if (searchResults.length === 0 || !searchResults[0].results) {
|
||||
console.log('searchResults', searchResults)
|
||||
if (searchResults.results.length === 0) {
|
||||
return {
|
||||
summary: 'No search results found for the given query.',
|
||||
searchResults: [],
|
||||
sources: '',
|
||||
instructions: '',
|
||||
rawResults: []
|
||||
sources: [],
|
||||
instructions: ''
|
||||
}
|
||||
}
|
||||
|
||||
const results = searchResults[0].results
|
||||
const results = searchResults.results
|
||||
const citationData = results.map((result, index) => ({
|
||||
number: index + 1,
|
||||
title: result.title,
|
||||
@ -105,25 +102,18 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
}))
|
||||
|
||||
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||
// const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||
|
||||
// 构建完整的引用指导文本
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the search results, please answer the user's question with proper citations."
|
||||
).replace('{references}', referenceContent)
|
||||
).replace('{references}', 'searchResults:')
|
||||
|
||||
return {
|
||||
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
|
||||
searchResults,
|
||||
sources: citationData
|
||||
.map((source) => `[${source.number}] ${source.title}\n${source.content}\nURL: ${source.url}`)
|
||||
.join('\n\n'),
|
||||
|
||||
instructions: fullInstructions,
|
||||
|
||||
// 原始数据,便于后续处理
|
||||
rawResults: citationData
|
||||
instructions: fullInstructions
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@ -2,8 +2,8 @@ import { MCPToolResponse } from '@renderer/types'
|
||||
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
||||
import { Collapse } from 'antd'
|
||||
|
||||
import { MessageKnowledgeSearchToolBody, MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
|
||||
import { MessageWebSearchToolBody, MessageWebSearchToolTitle } from './MessageWebSearchTool'
|
||||
import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
|
||||
import { MessageWebSearchToolTitle } from './MessageWebSearchTool'
|
||||
|
||||
interface Props {
|
||||
block: ToolMessageBlock
|
||||
@ -69,12 +69,12 @@ const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; bo
|
||||
case 'web_search_preview':
|
||||
return {
|
||||
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
|
||||
body: <MessageWebSearchToolBody toolResponse={toolResponse} />
|
||||
body: null
|
||||
}
|
||||
case 'knowledge_search':
|
||||
return {
|
||||
label: <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />,
|
||||
body: <MessageKnowledgeSearchToolBody toolResponse={toolResponse} />
|
||||
body: null
|
||||
}
|
||||
default:
|
||||
return null
|
||||
@ -82,7 +82,7 @@ const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; bo
|
||||
}
|
||||
|
||||
export default function MessageTool({ block }: Props) {
|
||||
// FIXME: 语义错误,这里已经不是 MCP tool 了
|
||||
// FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留
|
||||
const toolResponse = block.metadata?.rawMcpToolResponse
|
||||
|
||||
if (!toolResponse) return null
|
||||
@ -91,7 +91,7 @@ export default function MessageTool({ block }: Props) {
|
||||
|
||||
if (!toolRenderer) return null
|
||||
|
||||
return (
|
||||
return toolRenderer.body ? (
|
||||
<Collapse
|
||||
items={[
|
||||
{
|
||||
@ -109,6 +109,8 @@ export default function MessageTool({ block }: Props) {
|
||||
size="small"
|
||||
ghost
|
||||
/>
|
||||
) : (
|
||||
toolRenderer.label
|
||||
)
|
||||
}
|
||||
// const PrepareToolWrapper = styled.span`
|
||||
|
||||
@ -25,7 +25,7 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
|
||||
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
||||
<Search size={16} style={{ color: 'unset' }} />
|
||||
{i18n.t('message.websearch.fetch_complete', {
|
||||
count: toolOutput?.searchResults?.reduce((acc, result) => acc + result.results.length, 0) ?? 0
|
||||
count: toolOutput?.searchResults?.results?.length ?? 0
|
||||
})}
|
||||
</MessageWebSearchToolTitleTextWrapper>
|
||||
)
|
||||
|
||||
@ -380,7 +380,7 @@ export async function fetchGenerate({
|
||||
// }
|
||||
|
||||
const middlewareConfig: AiSdkMiddlewareConfig = {
|
||||
streamOutput: assistant.settings?.streamOutput ?? true,
|
||||
streamOutput: assistant.settings?.streamOutput ?? false,
|
||||
enableReasoning: false,
|
||||
isPromptToolUse: false,
|
||||
isSupportedToolUse: false,
|
||||
@ -393,6 +393,7 @@ export async function fetchGenerate({
|
||||
const result = await AI.completions(
|
||||
model.id,
|
||||
{
|
||||
system: prompt,
|
||||
prompt: content
|
||||
},
|
||||
{
|
||||
|
||||
@ -103,11 +103,11 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
|
||||
|
||||
// Handle citation block creation for web search results
|
||||
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.rawResults) {
|
||||
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.searchResults) {
|
||||
const citationBlock = createCitationBlock(
|
||||
assistantMsgId,
|
||||
{
|
||||
response: { results: toolResponse.response.rawResults, source: WebSearchSource.AISDK }
|
||||
response: { results: toolResponse.response.searchResults, source: WebSearchSource.WEBSEARCH }
|
||||
},
|
||||
{
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
|
||||
Loading…
Reference in New Issue
Block a user