mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-08 14:29:15 +08:00
feat(tools): refactor MemorySearchTool and WebSearchTool for improved response handling
- Updated MemorySearchTool to utilize aiSdk for better integration and removed unused imports. - Refactored WebSearchTool to streamline search results handling, changing from an array to a structured object for clarity. - Adjusted MessageTool and MessageWebSearchTool components to reflect changes in tool response structure. - Enhanced error handling and logging in tool callbacks for improved debugging and user feedback.
This commit is contained in:
parent
d34b640807
commit
ca4e7e3d2b
@ -1,11 +1,13 @@
|
|||||||
|
import { aiSdk, type InferToolOutput } from '@cherrystudio/ai-core'
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||||
import type { Assistant } from '@renderer/types'
|
import type { Assistant } from '@renderer/types'
|
||||||
import { type InferToolOutput, tool } from 'ai'
|
// import { type InferToolOutput, tool } from 'ai'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||||
|
|
||||||
|
const { tool } = aiSdk
|
||||||
/**
|
/**
|
||||||
* 🧠 基础记忆搜索工具
|
* 🧠 基础记忆搜索工具
|
||||||
* AI 可以主动调用的简单记忆搜索
|
* AI 可以主动调用的简单记忆搜索
|
||||||
|
|||||||
@ -54,15 +54,17 @@ Call this tool to execute the search. You can optionally provide additional cont
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const searchResults: WebSearchProviderResponse[] = []
|
let searchResults: WebSearchProviderResponse = {
|
||||||
|
query: '',
|
||||||
|
results: []
|
||||||
|
}
|
||||||
// 检查是否需要搜索
|
// 检查是否需要搜索
|
||||||
if (finalQueries[0] === 'not_needed') {
|
if (finalQueries[0] === 'not_needed') {
|
||||||
return {
|
return {
|
||||||
summary: 'No search needed based on the query analysis.',
|
summary: 'No search needed based on the query analysis.',
|
||||||
searchResults: [],
|
searchResults,
|
||||||
sources: '',
|
sources: '',
|
||||||
instructions: '',
|
instructions: ''
|
||||||
rawResults: []
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,29 +76,24 @@ Call this tool to execute the search. You can optionally provide additional cont
|
|||||||
links: extractedKeywords.links
|
links: extractedKeywords.links
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||||
searchResults.push(response)
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
return {
|
return {
|
||||||
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||||
searchResults: [],
|
sources: [],
|
||||||
sources: '',
|
instructions: ''
|
||||||
instructions: '',
|
|
||||||
rawResults: []
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
console.log('searchResults', searchResults)
|
||||||
if (searchResults.length === 0 || !searchResults[0].results) {
|
if (searchResults.results.length === 0) {
|
||||||
return {
|
return {
|
||||||
summary: 'No search results found for the given query.',
|
summary: 'No search results found for the given query.',
|
||||||
searchResults: [],
|
sources: [],
|
||||||
sources: '',
|
instructions: ''
|
||||||
instructions: '',
|
|
||||||
rawResults: []
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const results = searchResults[0].results
|
const results = searchResults.results
|
||||||
const citationData = results.map((result, index) => ({
|
const citationData = results.map((result, index) => ({
|
||||||
number: index + 1,
|
number: index + 1,
|
||||||
title: result.title,
|
title: result.title,
|
||||||
@ -105,25 +102,18 @@ Call this tool to execute the search. You can optionally provide additional cont
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
|
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
|
||||||
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
// const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||||
|
|
||||||
// 构建完整的引用指导文本
|
// 构建完整的引用指导文本
|
||||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||||
'{question}',
|
'{question}',
|
||||||
"Based on the search results, please answer the user's question with proper citations."
|
"Based on the search results, please answer the user's question with proper citations."
|
||||||
).replace('{references}', referenceContent)
|
).replace('{references}', 'searchResults:')
|
||||||
|
|
||||||
return {
|
return {
|
||||||
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
|
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
|
||||||
searchResults,
|
searchResults,
|
||||||
sources: citationData
|
instructions: fullInstructions
|
||||||
.map((source) => `[${source.number}] ${source.title}\n${source.content}\nURL: ${source.url}`)
|
|
||||||
.join('\n\n'),
|
|
||||||
|
|
||||||
instructions: fullInstructions,
|
|
||||||
|
|
||||||
// 原始数据,便于后续处理
|
|
||||||
rawResults: citationData
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@ -2,8 +2,8 @@ import { MCPToolResponse } from '@renderer/types'
|
|||||||
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
||||||
import { Collapse } from 'antd'
|
import { Collapse } from 'antd'
|
||||||
|
|
||||||
import { MessageKnowledgeSearchToolBody, MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
|
import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
|
||||||
import { MessageWebSearchToolBody, MessageWebSearchToolTitle } from './MessageWebSearchTool'
|
import { MessageWebSearchToolTitle } from './MessageWebSearchTool'
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
block: ToolMessageBlock
|
block: ToolMessageBlock
|
||||||
@ -69,12 +69,12 @@ const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; bo
|
|||||||
case 'web_search_preview':
|
case 'web_search_preview':
|
||||||
return {
|
return {
|
||||||
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
|
label: <MessageWebSearchToolTitle toolResponse={toolResponse} />,
|
||||||
body: <MessageWebSearchToolBody toolResponse={toolResponse} />
|
body: null
|
||||||
}
|
}
|
||||||
case 'knowledge_search':
|
case 'knowledge_search':
|
||||||
return {
|
return {
|
||||||
label: <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />,
|
label: <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />,
|
||||||
body: <MessageKnowledgeSearchToolBody toolResponse={toolResponse} />
|
body: null
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return null
|
return null
|
||||||
@ -82,7 +82,7 @@ const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function MessageTool({ block }: Props) {
|
export default function MessageTool({ block }: Props) {
|
||||||
// FIXME: 语义错误,这里已经不是 MCP tool 了
|
// FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留
|
||||||
const toolResponse = block.metadata?.rawMcpToolResponse
|
const toolResponse = block.metadata?.rawMcpToolResponse
|
||||||
|
|
||||||
if (!toolResponse) return null
|
if (!toolResponse) return null
|
||||||
@ -91,7 +91,7 @@ export default function MessageTool({ block }: Props) {
|
|||||||
|
|
||||||
if (!toolRenderer) return null
|
if (!toolRenderer) return null
|
||||||
|
|
||||||
return (
|
return toolRenderer.body ? (
|
||||||
<Collapse
|
<Collapse
|
||||||
items={[
|
items={[
|
||||||
{
|
{
|
||||||
@ -109,6 +109,8 @@ export default function MessageTool({ block }: Props) {
|
|||||||
size="small"
|
size="small"
|
||||||
ghost
|
ghost
|
||||||
/>
|
/>
|
||||||
|
) : (
|
||||||
|
toolRenderer.label
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
// const PrepareToolWrapper = styled.span`
|
// const PrepareToolWrapper = styled.span`
|
||||||
|
|||||||
@ -25,7 +25,7 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
|
|||||||
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
||||||
<Search size={16} style={{ color: 'unset' }} />
|
<Search size={16} style={{ color: 'unset' }} />
|
||||||
{i18n.t('message.websearch.fetch_complete', {
|
{i18n.t('message.websearch.fetch_complete', {
|
||||||
count: toolOutput?.searchResults?.reduce((acc, result) => acc + result.results.length, 0) ?? 0
|
count: toolOutput?.searchResults?.results?.length ?? 0
|
||||||
})}
|
})}
|
||||||
</MessageWebSearchToolTitleTextWrapper>
|
</MessageWebSearchToolTitleTextWrapper>
|
||||||
)
|
)
|
||||||
|
|||||||
@ -380,7 +380,7 @@ export async function fetchGenerate({
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
const middlewareConfig: AiSdkMiddlewareConfig = {
|
const middlewareConfig: AiSdkMiddlewareConfig = {
|
||||||
streamOutput: assistant.settings?.streamOutput ?? true,
|
streamOutput: assistant.settings?.streamOutput ?? false,
|
||||||
enableReasoning: false,
|
enableReasoning: false,
|
||||||
isPromptToolUse: false,
|
isPromptToolUse: false,
|
||||||
isSupportedToolUse: false,
|
isSupportedToolUse: false,
|
||||||
@ -393,6 +393,7 @@ export async function fetchGenerate({
|
|||||||
const result = await AI.completions(
|
const result = await AI.completions(
|
||||||
model.id,
|
model.id,
|
||||||
{
|
{
|
||||||
|
system: prompt,
|
||||||
prompt: content
|
prompt: content
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@ -103,11 +103,11 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
|||||||
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
|
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
|
||||||
|
|
||||||
// Handle citation block creation for web search results
|
// Handle citation block creation for web search results
|
||||||
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.rawResults) {
|
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.searchResults) {
|
||||||
const citationBlock = createCitationBlock(
|
const citationBlock = createCitationBlock(
|
||||||
assistantMsgId,
|
assistantMsgId,
|
||||||
{
|
{
|
||||||
response: { results: toolResponse.response.rawResults, source: WebSearchSource.AISDK }
|
response: { results: toolResponse.response.searchResults, source: WebSearchSource.WEBSEARCH }
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
status: MessageBlockStatus.SUCCESS
|
status: MessageBlockStatus.SUCCESS
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user