feat(tools): refactor MemorySearchTool and WebSearchTool for improved response handling

- Updated MemorySearchTool to utilize aiSdk for better integration and removed unused imports.
- Refactored WebSearchTool to streamline search results handling, changing from an array to a structured object for clarity.
- Adjusted MessageTool and MessageWebSearchTool components to reflect changes in tool response structure.
- Enhanced error handling and logging in tool callbacks for improved debugging and user feedback.
This commit is contained in:
MyPrototypeWhat 2025-08-22 19:35:09 +08:00
parent d34b640807
commit ca4e7e3d2b
6 changed files with 33 additions and 38 deletions

View File

@ -1,11 +1,13 @@
import { aiSdk, type InferToolOutput } from '@cherrystudio/ai-core'
import store from '@renderer/store'
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
import type { Assistant } from '@renderer/types'
import { type InferToolOutput, tool } from 'ai'
// import { type InferToolOutput, tool } from 'ai'
import { z } from 'zod'
import { MemoryProcessor } from '../../services/MemoryProcessor'
const { tool } = aiSdk
/**
* 🧠
* AI

View File

@ -54,15 +54,17 @@ Call this tool to execute the search. You can optionally provide additional cont
}
}
const searchResults: WebSearchProviderResponse[] = []
let searchResults: WebSearchProviderResponse = {
query: '',
results: []
}
// 检查是否需要搜索
if (finalQueries[0] === 'not_needed') {
return {
summary: 'No search needed based on the query analysis.',
searchResults: [],
searchResults,
sources: '',
instructions: '',
rawResults: []
instructions: ''
}
}
@ -74,29 +76,24 @@ Call this tool to execute the search. You can optionally provide additional cont
links: extractedKeywords.links
}
}
const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
searchResults.push(response)
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
} catch (error) {
return {
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
searchResults: [],
sources: '',
instructions: '',
rawResults: []
sources: [],
instructions: ''
}
}
if (searchResults.length === 0 || !searchResults[0].results) {
console.log('searchResults', searchResults)
if (searchResults.results.length === 0) {
return {
summary: 'No search results found for the given query.',
searchResults: [],
sources: '',
instructions: '',
rawResults: []
sources: [],
instructions: ''
}
}
const results = searchResults[0].results
const results = searchResults.results
const citationData = results.map((result, index) => ({
number: index + 1,
title: result.title,
@ -105,25 +102,18 @@ Call this tool to execute the search. You can optionally provide additional cont
}))
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
// const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
// 构建完整的引用指导文本
const fullInstructions = REFERENCE_PROMPT.replace(
'{question}',
"Based on the search results, please answer the user's question with proper citations."
).replace('{references}', referenceContent)
).replace('{references}', 'searchResults:')
return {
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
searchResults,
sources: citationData
.map((source) => `[${source.number}] ${source.title}\n${source.content}\nURL: ${source.url}`)
.join('\n\n'),
instructions: fullInstructions,
// 原始数据,便于后续处理
rawResults: citationData
instructions: fullInstructions
}
}
})

View File

@ -2,8 +2,8 @@ import { MCPToolResponse } from '@renderer/types'
import type { ToolMessageBlock } from '@renderer/types/newMessage'
import { Collapse } from 'antd'
import { MessageKnowledgeSearchToolBody, MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
import { MessageWebSearchToolBody, MessageWebSearchToolTitle } from './MessageWebSearchTool'
import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
import { MessageWebSearchToolTitle } from './MessageWebSearchTool'
interface Props {
block: ToolMessageBlock
@ -69,12 +69,12 @@ const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; bo
case 'web_search_preview':
return {
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
body: <MessageWebSearchToolBody toolResponse={toolResponse} />
body: null
}
case 'knowledge_search':
return {
label: <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />,
body: <MessageKnowledgeSearchToolBody toolResponse={toolResponse} />
body: null
}
default:
return null
@ -82,7 +82,7 @@ const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; bo
}
export default function MessageTool({ block }: Props) {
// FIXME: 语义错误,这里已经不是 MCP tool 了
// FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留
const toolResponse = block.metadata?.rawMcpToolResponse
if (!toolResponse) return null
@ -91,7 +91,7 @@ export default function MessageTool({ block }: Props) {
if (!toolRenderer) return null
return (
return toolRenderer.body ? (
<Collapse
items={[
{
@ -109,6 +109,8 @@ export default function MessageTool({ block }: Props) {
size="small"
ghost
/>
) : (
toolRenderer.label
)
}
// const PrepareToolWrapper = styled.span`

View File

@ -25,7 +25,7 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
<MessageWebSearchToolTitleTextWrapper type="secondary">
<Search size={16} style={{ color: 'unset' }} />
{i18n.t('message.websearch.fetch_complete', {
count: toolOutput?.searchResults?.reduce((acc, result) => acc + result.results.length, 0) ?? 0
count: toolOutput?.searchResults?.results?.length ?? 0
})}
</MessageWebSearchToolTitleTextWrapper>
)

View File

@ -380,7 +380,7 @@ export async function fetchGenerate({
// }
const middlewareConfig: AiSdkMiddlewareConfig = {
streamOutput: assistant.settings?.streamOutput ?? true,
streamOutput: assistant.settings?.streamOutput ?? false,
enableReasoning: false,
isPromptToolUse: false,
isSupportedToolUse: false,
@ -393,6 +393,7 @@ export async function fetchGenerate({
const result = await AI.completions(
model.id,
{
system: prompt,
prompt: content
},
{

View File

@ -103,11 +103,11 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
// Handle citation block creation for web search results
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.rawResults) {
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.searchResults) {
const citationBlock = createCitationBlock(
assistantMsgId,
{
response: { results: toolResponse.response.rawResults, source: WebSearchSource.AISDK }
response: { results: toolResponse.response.searchResults, source: WebSearchSource.WEBSEARCH }
},
{
status: MessageBlockStatus.SUCCESS