feat(tools): refactor MemorySearchTool and WebSearchTool for improved response handling

- Updated MemorySearchTool to utilize aiSdk for better integration and removed unused imports.
- Refactored WebSearchTool to streamline search results handling, changing from an array to a structured object for clarity.
- Adjusted MessageTool and MessageWebSearchTool components to reflect changes in tool response structure.
- Enhanced error handling and logging in tool callbacks for improved debugging and user feedback.
This commit is contained in:
MyPrototypeWhat 2025-08-22 19:35:09 +08:00
parent d34b640807
commit ca4e7e3d2b
6 changed files with 33 additions and 38 deletions

View File

@ -1,11 +1,13 @@
import { aiSdk, type InferToolOutput } from '@cherrystudio/ai-core'
import store from '@renderer/store' import store from '@renderer/store'
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory' import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
import type { Assistant } from '@renderer/types' import type { Assistant } from '@renderer/types'
import { type InferToolOutput, tool } from 'ai' // import { type InferToolOutput, tool } from 'ai'
import { z } from 'zod' import { z } from 'zod'
import { MemoryProcessor } from '../../services/MemoryProcessor' import { MemoryProcessor } from '../../services/MemoryProcessor'
const { tool } = aiSdk
/** /**
* 🧠 * 🧠
* AI * AI

View File

@ -54,15 +54,17 @@ Call this tool to execute the search. You can optionally provide additional cont
} }
} }
const searchResults: WebSearchProviderResponse[] = [] let searchResults: WebSearchProviderResponse = {
query: '',
results: []
}
// 检查是否需要搜索 // 检查是否需要搜索
if (finalQueries[0] === 'not_needed') { if (finalQueries[0] === 'not_needed') {
return { return {
summary: 'No search needed based on the query analysis.', summary: 'No search needed based on the query analysis.',
searchResults: [], searchResults,
sources: '', sources: '',
instructions: '', instructions: ''
rawResults: []
} }
} }
@ -74,29 +76,24 @@ Call this tool to execute the search. You can optionally provide additional cont
links: extractedKeywords.links links: extractedKeywords.links
} }
} }
const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId) searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
searchResults.push(response)
} catch (error) { } catch (error) {
return { return {
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`, summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
searchResults: [], sources: [],
sources: '', instructions: ''
instructions: '',
rawResults: []
} }
} }
console.log('searchResults', searchResults)
if (searchResults.length === 0 || !searchResults[0].results) { if (searchResults.results.length === 0) {
return { return {
summary: 'No search results found for the given query.', summary: 'No search results found for the given query.',
searchResults: [], sources: [],
sources: '', instructions: ''
instructions: '',
rawResults: []
} }
} }
const results = searchResults[0].results const results = searchResults.results
const citationData = results.map((result, index) => ({ const citationData = results.map((result, index) => ({
number: index + 1, number: index + 1,
title: result.title, title: result.title,
@ -105,25 +102,18 @@ Call this tool to execute the search. You can optionally provide additional cont
})) }))
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑 // 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\`` // const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
// 构建完整的引用指导文本 // 构建完整的引用指导文本
const fullInstructions = REFERENCE_PROMPT.replace( const fullInstructions = REFERENCE_PROMPT.replace(
'{question}', '{question}',
"Based on the search results, please answer the user's question with proper citations." "Based on the search results, please answer the user's question with proper citations."
).replace('{references}', referenceContent) ).replace('{references}', 'searchResults:')
return { return {
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`, summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
searchResults, searchResults,
sources: citationData instructions: fullInstructions
.map((source) => `[${source.number}] ${source.title}\n${source.content}\nURL: ${source.url}`)
.join('\n\n'),
instructions: fullInstructions,
// 原始数据,便于后续处理
rawResults: citationData
} }
} }
}) })

View File

@ -2,8 +2,8 @@ import { MCPToolResponse } from '@renderer/types'
import type { ToolMessageBlock } from '@renderer/types/newMessage' import type { ToolMessageBlock } from '@renderer/types/newMessage'
import { Collapse } from 'antd' import { Collapse } from 'antd'
import { MessageKnowledgeSearchToolBody, MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch' import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
import { MessageWebSearchToolBody, MessageWebSearchToolTitle } from './MessageWebSearchTool' import { MessageWebSearchToolTitle } from './MessageWebSearchTool'
interface Props { interface Props {
block: ToolMessageBlock block: ToolMessageBlock
@ -69,12 +69,12 @@ const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; bo
case 'web_search_preview': case 'web_search_preview':
return { return {
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />, label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
body: <MessageWebSearchToolBody toolResponse={toolResponse} /> body: null
} }
case 'knowledge_search': case 'knowledge_search':
return { return {
label: <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />, label: <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />,
body: <MessageKnowledgeSearchToolBody toolResponse={toolResponse} /> body: null
} }
default: default:
return null return null
@ -82,7 +82,7 @@ const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; bo
} }
export default function MessageTool({ block }: Props) { export default function MessageTool({ block }: Props) {
// FIXME: 语义错误,这里已经不是 MCP tool 了 // FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留
const toolResponse = block.metadata?.rawMcpToolResponse const toolResponse = block.metadata?.rawMcpToolResponse
if (!toolResponse) return null if (!toolResponse) return null
@ -91,7 +91,7 @@ export default function MessageTool({ block }: Props) {
if (!toolRenderer) return null if (!toolRenderer) return null
return ( return toolRenderer.body ? (
<Collapse <Collapse
items={[ items={[
{ {
@ -109,6 +109,8 @@ export default function MessageTool({ block }: Props) {
size="small" size="small"
ghost ghost
/> />
) : (
toolRenderer.label
) )
} }
// const PrepareToolWrapper = styled.span` // const PrepareToolWrapper = styled.span`

View File

@ -25,7 +25,7 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
<MessageWebSearchToolTitleTextWrapper type="secondary"> <MessageWebSearchToolTitleTextWrapper type="secondary">
<Search size={16} style={{ color: 'unset' }} /> <Search size={16} style={{ color: 'unset' }} />
{i18n.t('message.websearch.fetch_complete', { {i18n.t('message.websearch.fetch_complete', {
count: toolOutput?.searchResults?.reduce((acc, result) => acc + result.results.length, 0) ?? 0 count: toolOutput?.searchResults?.results?.length ?? 0
})} })}
</MessageWebSearchToolTitleTextWrapper> </MessageWebSearchToolTitleTextWrapper>
) )

View File

@ -380,7 +380,7 @@ export async function fetchGenerate({
// } // }
const middlewareConfig: AiSdkMiddlewareConfig = { const middlewareConfig: AiSdkMiddlewareConfig = {
streamOutput: assistant.settings?.streamOutput ?? true, streamOutput: assistant.settings?.streamOutput ?? false,
enableReasoning: false, enableReasoning: false,
isPromptToolUse: false, isPromptToolUse: false,
isSupportedToolUse: false, isSupportedToolUse: false,
@ -393,6 +393,7 @@ export async function fetchGenerate({
const result = await AI.completions( const result = await AI.completions(
model.id, model.id,
{ {
system: prompt,
prompt: content prompt: content
}, },
{ {

View File

@ -103,11 +103,11 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true) blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
// Handle citation block creation for web search results // Handle citation block creation for web search results
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.rawResults) { if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.searchResults) {
const citationBlock = createCitationBlock( const citationBlock = createCitationBlock(
assistantMsgId, assistantMsgId,
{ {
response: { results: toolResponse.response.rawResults, source: WebSearchSource.AISDK } response: { results: toolResponse.response.searchResults, source: WebSearchSource.WEBSEARCH }
}, },
{ {
status: MessageBlockStatus.SUCCESS status: MessageBlockStatus.SUCCESS