feat(aiCore): add MemorySearchTool and WebSearchTool components

- Introduced MessageMemorySearch and MessageWebSearch components for handling memory and web search tool responses.
- Updated MemorySearchTool and WebSearchTool to improve response handling and integrate with the new components.
- Removed unused console logs and streamlined code for better readability and maintainability.
- Added new dependencies in package.json for enhanced functionality.
This commit is contained in:
MyPrototypeWhat 2025-08-26 17:59:52 +08:00
parent 82d4637c9d
commit 0c7e221b4e
11 changed files with 142 additions and 109 deletions

View File

@ -169,6 +169,7 @@
"@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
"ai": "^5.0.24",
"antd": "patch:antd@npm%3A5.26.7#~/.yarn/patches/antd-npm-5.26.7-029c5c381a.patch",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",

View File

@ -57,12 +57,10 @@ Call this tool to execute the search. You can optionally provide additional cont
if (additionalContext?.trim()) {
// 如果大模型提供了额外上下文,使用更具体的描述
console.log(`🔍 AI enhanced knowledge search with: ${additionalContext}`)
const cleanContext = additionalContext.trim()
if (cleanContext) {
finalQueries = [cleanContext]
finalRewrite = cleanContext
console.log(` Added additional context: ${cleanContext}`)
}
}
@ -101,8 +99,6 @@ Call this tool to execute the search. You can optionally provide additional cont
knowledge: searchCriteria
}
console.log('Knowledge search extractResults:', extractResults)
// 执行知识库搜索
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId)
const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({
@ -131,8 +127,6 @@ Call this tool to execute the search. You can optionally provide additional cont
// rawResults: citationData
}
} catch (error) {
console.error('🔍 [KnowledgeSearchTool] Search failed:', error)
// 返回空对象而不是抛出错误,避免中断对话流程
return {
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,

View File

@ -1,13 +1,11 @@
import { aiSdk, type InferToolOutput } from '@cherrystudio/ai-core'
import store from '@renderer/store'
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
import type { Assistant } from '@renderer/types'
// import { type InferToolOutput, tool } from 'ai'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import { z } from 'zod'
import { MemoryProcessor } from '../../services/MemoryProcessor'
const { tool } = aiSdk
/**
* 🧠
* AI
@ -21,7 +19,7 @@ export const memorySearchTool = () => {
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
}),
execute: async ({ query, limit = 5 }) => {
console.log('🧠 [memorySearchTool] Searching memories:', { query, limit })
// console.log('🧠 [memorySearchTool] Searching memories:', { query, limit })
try {
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
@ -31,7 +29,7 @@ export const memorySearchTool = () => {
const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
console.warn('Memory search skipped: embedding or LLM model not configured')
// console.warn('Memory search skipped: embedding or LLM model not configured')
return []
}
@ -42,18 +40,29 @@ export const memorySearchTool = () => {
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
if (relevantMemories?.length > 0) {
console.log('🧠 [memorySearchTool] Found memories:', relevantMemories.length)
// console.log('🧠 [memorySearchTool] Found memories:', relevantMemories.length)
return relevantMemories
}
return []
} catch (error) {
console.error('🧠 [memorySearchTool] Error:', error)
// console.error('🧠 [memorySearchTool] Error:', error)
return []
}
}
})
}
// 方案4: 为第二个工具也使用类型断言
type MessageRole = 'user' | 'assistant' | 'system'
type MessageType = {
content: string
role: MessageRole
}
type MemorySearchWithExtractionInput = {
userMessage: MessageType
lastAnswer?: MessageType
}
/**
* 🧠
*
@ -73,9 +82,9 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
})
.optional()
}),
execute: async ({ userMessage, lastAnswer }) => {
console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer })
}) as z.ZodSchema<MemorySearchWithExtractionInput>,
execute: async ({ userMessage }) => {
// console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer })
try {
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
@ -88,7 +97,7 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
const memoryConfig = selectMemoryConfig(store.getState())
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
console.warn('Memory search skipped: embedding or LLM model not configured')
// console.warn('Memory search skipped: embedding or LLM model not configured')
return {
extractedKeywords: 'Memory models not configured',
searchResults: []
@ -116,7 +125,7 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
)
if (relevantMemories?.length > 0) {
console.log('🧠 [memorySearchToolWithExtraction] Found memories:', relevantMemories.length)
// console.log('🧠 [memorySearchToolWithExtraction] Found memories:', relevantMemories.length)
return {
extractedKeywords: content,
searchResults: relevantMemories
@ -128,7 +137,7 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
searchResults: []
}
} catch (error) {
console.error('🧠 [memorySearchToolWithExtraction] Error:', error)
// console.error('🧠 [memorySearchToolWithExtraction] Error:', error)
return {
extractedKeywords: 'Search failed',
searchResults: []
@ -137,6 +146,6 @@ export const memorySearchToolWithExtraction = (assistant: Assistant) => {
}
})
}
export type MemorySearchToolInput = InferToolInput<ReturnType<typeof memorySearchTool>>
export type MemorySearchToolOutput = InferToolOutput<ReturnType<typeof memorySearchTool>>
export type MemorySearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof memorySearchToolWithExtraction>>

View File

@ -1,12 +1,10 @@
import { aiSdk, InferToolInput, InferToolOutput } from '@cherrystudio/ai-core'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import WebSearchService from '@renderer/services/WebSearchService'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { ExtractResults } from '@renderer/utils/extract'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import { z } from 'zod'
const { tool } = aiSdk
/**
* 使
* 使
@ -84,7 +82,6 @@ Call this tool to execute the search. You can optionally provide additional cont
instructions: ''
}
}
console.log('searchResults', searchResults)
if (searchResults.results.length === 0) {
return {
summary: 'No search results found for the given query.',

View File

@ -1,8 +0,0 @@
// import { Tool } from '@cherrystudio/ai-core'
// export type ToolCallResult = {
// success: boolean
// data: any
// }
// export type AiSdkTool = Tool<any, ToolCallResult>

View File

@ -37,8 +37,8 @@ const SearchWrapper = styled.div`
display: flex;
align-items: center;
gap: 4px;
font-size: 14px;
padding: 10px;
padding-left: 0;
/* font-size: 14px; */
padding: 0px;
/* padding-left: 0; */
`
const Searching = motion.create(SearchWrapper)

View File

@ -0,0 +1,41 @@
import { MemorySearchToolInput, MemorySearchToolOutput } from '@renderer/aiCore/tools/MemorySearchTool'
import Spinner from '@renderer/components/Spinner'
import { MCPToolResponse } from '@renderer/types'
import { Typography } from 'antd'
import { ChevronRight } from 'lucide-react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
const { Text } = Typography
export const MessageMemorySearchToolTitle = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
const { t } = useTranslation()
const toolInput = toolResponse.arguments as MemorySearchToolInput
const toolOutput = toolResponse.response as MemorySearchToolOutput
return toolResponse.status !== 'done' ? (
<Spinner
text={
<MessageWebSearchToolTitleTextWrapper>
{t('memory.search_placeholder')}
<span>{toolInput?.query ?? ''}</span>
</MessageWebSearchToolTitleTextWrapper>
}
/>
) : toolOutput?.length ? (
<MessageWebSearchToolTitleTextWrapper type="secondary">
<ChevronRight size={16} style={{ color: 'unset' }} />
{/* <Search size={16} style={{ color: 'unset' }} /> */}
<span>{toolOutput?.length ?? 0}</span>
{t('memory.memory')}
</MessageWebSearchToolTitleTextWrapper>
) : null
}
const MessageWebSearchToolTitleTextWrapper = styled(Text)`
display: flex;
align-items: center;
gap: 4px;
padding: 5px;
padding-left: 0;
`

View File

@ -3,61 +3,14 @@ import type { ToolMessageBlock } from '@renderer/types/newMessage'
import { Collapse } from 'antd'
import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
import { MessageWebSearchToolTitle } from './MessageWebSearchTool'
import { MessageMemorySearchToolTitle } from './MessageMemorySearch'
import { MessageWebSearchToolTitle } from './MessageWebSearch'
interface Props {
block: ToolMessageBlock
}
const prefix = 'builtin_'
// const toolNameMapText = {
// web_search: i18n.t('message.searching')
// }
// const toolDoneNameMapText = (args: Record<string, any>) => {
// const count = args.count ?? 0
// return i18n.t('message.websearch.fetch_complete', { count })
// }
// const PrepareTool = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
// const toolNameText = useMemo(
// () => toolNameMapText[toolResponse.tool.name] || toolResponse.tool.name,
// [toolResponse.tool]
// )
// return (
// <Spinner
// text={
// <PrepareToolWrapper>
// {toolNameText}
// <span>{JSON.stringify(toolResponse.arguments)}</span>
// </PrepareToolWrapper>
// }
// />
// )
// }
// const DoneTool = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
// const toolDoneNameText = useMemo(
// () => toolDoneNameMapText({ count: toolResponse.response?.data?.length ?? 0 }),
// [toolResponse.response]
// )
// return <p>{toolDoneNameText}</p>
// }
// const ToolLabelComponents = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
// if (webSearchToolNames.includes(toolResponse.tool.name)) {
// return <MessageWebSearchToolTitle toolResponse={toolResponse} />
// }
// return <MessageWebSearchToolTitle toolResponse={toolResponse} />
// }
// const ToolBodyComponents = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
// if (webSearchToolNames.includes(toolResponse.tool.name)) {
// return <MessageWebSearchToolBody toolResponse={toolResponse} />
// }
// return <MessageWebSearchToolBody toolResponse={toolResponse} />
// }
const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; body: React.ReactNode } | null => {
let toolName = toolResponse.tool.name
if (toolName.startsWith(prefix)) {
@ -76,6 +29,11 @@ const ChooseTool = (toolResponse: MCPToolResponse): { label: React.ReactNode; bo
label: <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />,
body: null
}
case 'memory_search':
return {
label: <MessageMemorySearchToolTitle toolResponse={toolResponse} />,
body: null
}
default:
return null
}

View File

@ -6,7 +6,7 @@ import MessageTool from './MessageTool'
interface Props {
block: ToolMessageBlock
}
// TODO: 知识库tool
export default function MessageTools({ block }: Props) {
const toolResponse = block.metadata?.rawMcpToolResponse
if (!toolResponse) return null

View File

@ -1,14 +1,15 @@
import { WebSearchToolInput, WebSearchToolOutput } from '@renderer/aiCore/tools/WebSearchTool'
import Spinner from '@renderer/components/Spinner'
import i18n from '@renderer/i18n'
import { MCPToolResponse } from '@renderer/types'
import { Typography } from 'antd'
import { Search } from 'lucide-react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
const { Text, Link } = Typography
const { Text } = Typography
export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
const { t } = useTranslation()
const toolInput = toolResponse.arguments as WebSearchToolInput
const toolOutput = toolResponse.response as WebSearchToolOutput
@ -16,7 +17,7 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
<Spinner
text={
<PrepareToolWrapper>
{i18n.t('message.searching')}
{t('message.searching')}
<span>{toolInput?.additionalContext ?? ''}</span>
</PrepareToolWrapper>
}
@ -24,28 +25,28 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
) : (
<MessageWebSearchToolTitleTextWrapper type="secondary">
<Search size={16} style={{ color: 'unset' }} />
{i18n.t('message.websearch.fetch_complete', {
{t('message.websearch.fetch_complete', {
count: toolOutput?.searchResults?.results?.length ?? 0
})}
</MessageWebSearchToolTitleTextWrapper>
)
}
export const MessageWebSearchToolBody = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
const toolOutput = toolResponse.response as WebSearchToolOutput
// export const MessageWebSearchToolBody = ({ toolResponse }: { toolResponse: MCPToolResponse }) => {
// const toolOutput = toolResponse.response as WebSearchToolOutput
return toolResponse.status === 'done'
? toolOutput?.searchResults?.map((result, index) => (
<MessageWebSearchToolBodyUlWrapper key={result?.query ?? '' + index}>
{result.results.map((item, index) => (
<li key={item.url + index}>
<Link href={item.url}>{item.title}</Link>
</li>
))}
</MessageWebSearchToolBodyUlWrapper>
))
: null
}
// return toolResponse.status === 'done'
// ? toolOutput?.searchResults?.map((result, index) => (
// <MessageWebSearchToolBodyUlWrapper key={result?.query ?? '' + index}>
// {result.results.map((item, index) => (
// <li key={item.url + index}>
// <Link href={item.url}>{item.title}</Link>
// </li>
// ))}
// </MessageWebSearchToolBodyUlWrapper>
// ))
// : null
// }
const PrepareToolWrapper = styled.span`
display: flex;
@ -61,9 +62,9 @@ const MessageWebSearchToolTitleTextWrapper = styled(Text)`
gap: 4px;
`
const MessageWebSearchToolBodyUlWrapper = styled.ul`
display: flex;
flex-direction: column;
gap: 4px;
padding: 0;
`
// const MessageWebSearchToolBodyUlWrapper = styled.ul`
// display: flex;
// flex-direction: column;
// gap: 4px;
// padding: 0;
// `

View File

@ -140,6 +140,18 @@ __metadata:
languageName: node
linkType: hard
"@ai-sdk/gateway@npm:1.0.13":
version: 1.0.13
resolution: "@ai-sdk/gateway@npm:1.0.13"
dependencies:
"@ai-sdk/provider": "npm:2.0.0"
"@ai-sdk/provider-utils": "npm:3.0.6"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/cfd655440d68e1e99204cdf3a4c30fcd9d9dd21d2b6851932584fc228617e9f844b0419c5b74e8e8ee9c00256b64e34c9346369cebe43f9c76e6f31203ca4b8c
languageName: node
linkType: hard
"@ai-sdk/gateway@npm:1.0.8":
version: 1.0.8
resolution: "@ai-sdk/gateway@npm:1.0.8"
@ -281,6 +293,19 @@ __metadata:
languageName: node
linkType: hard
"@ai-sdk/provider-utils@npm:3.0.6":
version: 3.0.6
resolution: "@ai-sdk/provider-utils@npm:3.0.6"
dependencies:
"@ai-sdk/provider": "npm:2.0.0"
"@standard-schema/spec": "npm:^1.0.0"
eventsource-parser: "npm:^3.0.3"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/00f8d5a4e76f66aad6ebde97d3512b8295c144c6802465ec0d7b38c6eaa0fc97fbcaeb51284f66bb7e42bcf38a92ed0edba95c1ccddf50af3040d2b64227352a
languageName: node
linkType: hard
"@ai-sdk/provider@npm:2.0.0, @ai-sdk/provider@npm:^2.0.0":
version: 2.0.0
resolution: "@ai-sdk/provider@npm:2.0.0"
@ -8966,6 +8991,7 @@ __metadata:
"@viz-js/lang-dot": "npm:^1.0.5"
"@viz-js/viz": "npm:^3.14.0"
"@xyflow/react": "npm:^12.4.4"
ai: "npm:^5.0.24"
antd: "patch:antd@npm%3A5.26.7#~/.yarn/patches/antd-npm-5.26.7-029c5c381a.patch"
archiver: "npm:^7.0.1"
async-mutex: "npm:^0.5.0"
@ -9190,6 +9216,20 @@ __metadata:
languageName: node
linkType: hard
"ai@npm:^5.0.24":
version: 5.0.24
resolution: "ai@npm:5.0.24"
dependencies:
"@ai-sdk/gateway": "npm:1.0.13"
"@ai-sdk/provider": "npm:2.0.0"
"@ai-sdk/provider-utils": "npm:3.0.6"
"@opentelemetry/api": "npm:1.9.0"
peerDependencies:
zod: ^3.25.76 || ^4
checksum: 10c0/a5735ee935a9499f1b87fa99d0d2aa87732db0e01dfdfd38ffa233c9e7242572f4442c15fe3f6e441b1ca62491a72ca603ca782f3f0ed7bf42c57fa7d6f2e891
languageName: node
linkType: hard
"ajv-formats@npm:^2.1.1":
version: 2.1.1
resolution: "ajv-formats@npm:2.1.1"