mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat: enhance web search functionality and tool integration
- Introduced `extractSearchKeywords` function to facilitate keyword extraction from user messages for web searches.
- Updated `webSearchTool` to streamline the execution of web searches without requiring a request ID.
- Enhanced `WebSearchService` methods to be static for improved accessibility and clarity.
- Modified `ApiService` to pass `webSearchProviderId` for better integration with the web search functionality.
- Improved `ToolCallChunkHandler` to handle built-in tools more effectively.
This commit is contained in:
parent
da455997ad
commit
0456094512
@ -67,6 +67,15 @@ export class ToolCallChunkHandler {
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
}
|
||||
} else if (toolName.startsWith('builtin_')) {
|
||||
// 如果是内置工具,沿用现有逻辑
|
||||
Logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'builtin'
|
||||
}
|
||||
} else {
|
||||
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
|
||||
Logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
import { extractSearchKeywords } from '@renderer/aiCore/transformParameters'
|
||||
import WebSearchService from '@renderer/services/WebSearchService'
|
||||
import { WebSearchProvider } from '@renderer/types'
|
||||
import aiSdk from 'ai'
|
||||
import { Assistant, Message, WebSearchProvider } from '@renderer/types'
|
||||
import { UserMessageStatus } from '@renderer/types/newMessage'
|
||||
import { ExtractResults } from '@renderer/utils/extract'
|
||||
import * as aiSdk from 'ai'
|
||||
|
||||
import { AiSdkTool, ToolCallResult } from './types'
|
||||
|
||||
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requestId: string): AiSdkTool => {
|
||||
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']): AiSdkTool => {
|
||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
return {
|
||||
name: 'web_search',
|
||||
name: 'builtin_web_search',
|
||||
description: 'Search the web for information',
|
||||
inputSchema: aiSdk.jsonSchema({
|
||||
type: 'object',
|
||||
@ -18,7 +21,9 @@ export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requ
|
||||
}),
|
||||
execute: async ({ query }): Promise<ToolCallResult> => {
|
||||
try {
|
||||
const response = await webSearchService.processWebsearch(query, requestId)
|
||||
console.log('webSearchTool', query)
|
||||
const response = await webSearchService.search(query)
|
||||
console.log('webSearchTool response', response)
|
||||
return {
|
||||
success: true,
|
||||
data: response
|
||||
@ -32,3 +37,109 @@ export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requ
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Builds the `web_search_with_extraction` tool: it first asks
 * `extractSearchKeywords` for search questions, then runs one
 * `processWebsearch` call per extracted question.
 *
 * @param webSearchProviderId - Id used to obtain the `WebSearchService` instance.
 * @param requestId - Used as the id of the placeholder user message (and, with an
 *   `_answer` suffix, of the placeholder answer message) and forwarded to
 *   `processWebsearch`.
 * @param assistant - Assistant whose id is stamped on the placeholder messages and
 *   which is passed to `extractSearchKeywords`.
 * @returns An `AiSdkTool` whose `execute` resolves to `{ success, data }`.
 */
export const webSearchToolWithExtraction = (
  webSearchProviderId: WebSearchProvider['id'],
  requestId: string,
  assistant: Assistant
): AiSdkTool => {
  const webSearchService = WebSearchService.getInstance(webSearchProviderId)

  // Creates a placeholder message for keyword extraction.
  // NOTE(review): the `content` supplied in the tool input is not copied into the
  // message (`blocks` is empty) — confirm extraction still sees the user's text.
  const toPlaceholderMessage = (id: string, role: string): Message => ({
    id,
    role: role as 'user' | 'assistant' | 'system',
    assistantId: assistant.id,
    topicId: 'temp',
    createdAt: new Date().toISOString(),
    status: UserMessageStatus.SUCCESS,
    blocks: []
  })

  return {
    name: 'web_search_with_extraction',
    description: 'Search the web for information with automatic keyword extraction from user messages',
    inputSchema: aiSdk.jsonSchema({
      type: 'object',
      properties: {
        userMessage: {
          type: 'object',
          description: 'The user message to extract keywords from',
          properties: {
            content: { type: 'string', description: 'The main content of the message' },
            role: { type: 'string', description: 'Message role (user/assistant/system)' }
          },
          required: ['content', 'role']
        },
        lastAnswer: {
          type: 'object',
          description: 'Optional last assistant response for context',
          properties: {
            content: { type: 'string', description: 'The main content of the message' },
            role: { type: 'string', description: 'Message role (user/assistant/system)' }
          },
          required: ['content', 'role']
        }
      },
      required: ['userMessage']
    }),
    execute: async ({ userMessage, lastAnswer }): Promise<ToolCallResult> => {
      try {
        const lastUserMessage = toPlaceholderMessage(requestId, userMessage.role)
        const lastAnswerMessage: Message | undefined = lastAnswer
          ? toPlaceholderMessage(requestId + '_answer', lastAnswer.role)
          : undefined

        const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
          shouldWebSearch: true,
          shouldKnowledgeSearch: false,
          lastAnswer: lastAnswerMessage
        })

        // Bail out when extraction produced nothing or reported that no search is needed.
        const websearch = extractResults?.websearch
        if (!websearch || websearch.question[0] === 'not_needed') {
          return { success: false, data: 'No search needed or extraction failed' }
        }

        const searchResults: Array<{ query: string; results: any }> = []

        // Queries are processed one after another, each awaited before the next starts.
        for (const query of websearch.question) {
          // Build an ExtractResults structure holding just this single query.
          const singleQuery: ExtractResults = {
            websearch: {
              question: [query],
              links: websearch.links
            }
          }
          const results = await webSearchService.processWebsearch(singleQuery, requestId)
          searchResults.push({ query, results })
        }

        return {
          success: true,
          data: {
            extractedKeywords: websearch,
            searchResults
          }
        }
      } catch (error) {
        // Failures are reported through the result object rather than thrown.
        return { success: false, data: error }
      }
    }
  }
}
|
||||
|
||||
@ -13,6 +13,8 @@ import {
|
||||
TextPart,
|
||||
UserModelMessage
|
||||
} from '@cherrystudio/ai-core'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
@ -26,10 +28,18 @@ import {
|
||||
isVisionModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import {
|
||||
SEARCH_SUMMARY_PROMPT,
|
||||
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
|
||||
SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
} from '@renderer/config/prompts'
|
||||
import { getAssistantSettings, getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { getDefaultAssistant } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types'
|
||||
import { FileTypes } from '@renderer/types'
|
||||
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
|
||||
// import { getWebSearchTools } from './utils/websearch'
|
||||
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
|
||||
import {
|
||||
findFileBlocks,
|
||||
findImageBlocks,
|
||||
@ -38,12 +48,12 @@ import {
|
||||
} from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { webSearchTool } from './tools/WebSearchTool'
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { setupToolsConfig } from './utils/mcp'
|
||||
import { buildProviderOptions } from './utils/options'
|
||||
// import { getWebSearchTools } from './utils/websearch'
|
||||
|
||||
/**
|
||||
* 获取温度参数
|
||||
@ -289,10 +299,7 @@ export async function buildStreamTextParams(
|
||||
})
|
||||
|
||||
if (webSearchProviderId) {
|
||||
// 生成requestId用于网络搜索工具
|
||||
const requestId = `request_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`
|
||||
|
||||
tools['builtin_web_search'] = webSearchTool(webSearchProviderId, requestId)
|
||||
tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
||||
}
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
@ -336,3 +343,103 @@ export async function buildGenerateTextParams(
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, provider, options)
|
||||
}
|
||||
|
||||
/**
 * Extracts keywords and questions for external search tools.
 *
 * Derives, from the user's message, the questions to use for web search
 * and/or knowledge-base search.
 *
 * @param lastUserMessage - Message to extract from; a falsy value yields `undefined`.
 * @param assistant - Its `model` (or the default model) is used for the extraction call.
 * @param options - `shouldWebSearch` / `shouldKnowledgeSearch` select which results
 *   are produced (both default to `false`); `lastAnswer`, when given, is sent
 *   ahead of the user message as context.
 * @returns `undefined` when there is nothing to extract; otherwise the extracted
 *   results filtered to the requested kinds. If the summary call returns nothing
 *   or throws, a fallback built from the message's main text (or the literal
 *   `'search'`) is returned instead.
 */
export async function extractSearchKeywords(
  lastUserMessage: Message,
  assistant: Assistant,
  options: {
    shouldWebSearch?: boolean
    shouldKnowledgeSearch?: boolean
    lastAnswer?: Message
  } = {}
): Promise<ExtractResults | undefined> {
  const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer } = options

  // No message, or neither kind of search requested: nothing to do.
  if (!lastUserMessage || (!shouldWebSearch && !shouldKnowledgeSearch)) {
    return undefined
  }

  // Fallback used when the summary request yields nothing or fails.
  const buildFallback = (): ExtractResults => {
    const text = getMainTextContent(lastUserMessage) || 'search'
    return {
      websearch: shouldWebSearch ? { question: [text] } : undefined,
      knowledge: shouldKnowledgeSearch ? { question: [text], rewrite: text } : undefined
    }
  }

  // Assistant configuration used only for the extraction request.
  const summaryAssistant = getDefaultAssistant()
  summaryAssistant.model = assistant.model || getDefaultModel()
  // Pick the prompt matching the requested combination of searches.
  summaryAssistant.prompt =
    shouldWebSearch && shouldKnowledgeSearch
      ? SEARCH_SUMMARY_PROMPT
      : shouldWebSearch
        ? SEARCH_SUMMARY_PROMPT_WEB_ONLY
        : SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY

  try {
    const summary = await fetchSearchSummary({
      messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
      assistant: summaryAssistant
    })

    if (!summary) {
      return buildFallback()
    }

    const parsed = extractInfoFromXML(summary.getText())
    // Keep only the sections the caller asked for.
    return {
      websearch: shouldWebSearch ? parsed?.websearch : undefined,
      knowledge: shouldKnowledgeSearch ? parsed?.knowledge : undefined
    }
  } catch (e: unknown) {
    console.error('extract error', e)
    return buildFallback()
  }
}
|
||||
|
||||
/**
 * Fetches the search summary — internal helper.
 *
 * Resolves the provider for the assistant's model (or the default model) and
 * issues a non-streaming completion with `callType: 'search'`.
 *
 * @returns `null` when `hasApiKey` rejects the provider; otherwise whatever
 *   `AiProvider#completions` resolves to.
 */
async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
  const provider = getProviderByModel(assistant.model || getDefaultModel())

  // Skip the request entirely when the provider has no usable credentials.
  if (!hasApiKey(provider)) return null

  const request: CompletionsParams = {
    callType: 'search',
    messages,
    assistant,
    streamOutput: false
  }

  return await new AiProvider(provider).completions(request)
}
|
||||
|
||||
/**
 * Tells whether a provider may be called with its current credentials.
 *
 * @param provider - Provider to inspect; a falsy value yields `false`.
 * @returns `true` for providers with id `'ollama'` or `'lmstudio'`, or type
 *   `'vertexai'`, regardless of `apiKey`; otherwise `true` only when `apiKey`
 *   is not empty according to lodash `isEmpty`.
 */
function hasApiKey(provider: Provider) {
  if (!provider) {
    return false
  }

  // These providers are accepted without checking `apiKey`.
  const acceptedWithoutKey =
    provider.id === 'ollama' || provider.id === 'lmstudio' || provider.type === 'vertexai'

  return acceptedWithoutKey || !isEmpty(provider.apiKey)
}
|
||||
|
||||
@ -307,6 +307,7 @@ export async function fetchChatCompletion({
|
||||
} = await buildStreamTextParams(messages, assistant, provider, {
|
||||
mcpTools: mcpTools,
|
||||
enableTools: isEnabledToolUse(assistant),
|
||||
webSearchProviderId: assistant.webSearchProviderId,
|
||||
requestOptions: options
|
||||
})
|
||||
|
||||
|
||||
@ -109,7 +109,7 @@ export default class WebSearchService {
|
||||
* @private
|
||||
* @returns 网络搜索状态
|
||||
*/
|
||||
private getWebSearchState(): WebSearchState {
|
||||
private static getWebSearchState(): WebSearchState {
|
||||
return store.getState().websearch
|
||||
}
|
||||
|
||||
@ -118,8 +118,8 @@ export default class WebSearchService {
|
||||
* @public
|
||||
* @returns 如果默认搜索提供商已启用则返回true,否则返回false
|
||||
*/
|
||||
public isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
|
||||
const { providers } = this.getWebSearchState()
|
||||
public static isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
|
||||
const { providers } = WebSearchService.getWebSearchState()
|
||||
const provider = providers.find((provider) => provider.id === providerId)
|
||||
|
||||
if (!provider) {
|
||||
@ -149,7 +149,7 @@ export default class WebSearchService {
|
||||
* @returns 如果启用覆盖搜索则返回true,否则返回false
|
||||
*/
|
||||
public isOverwriteEnabled(): boolean {
|
||||
const { overwrite } = this.getWebSearchState()
|
||||
const { overwrite } = WebSearchService.getWebSearchState()
|
||||
return overwrite
|
||||
}
|
||||
|
||||
@ -159,7 +159,7 @@ export default class WebSearchService {
|
||||
* @returns 网络搜索提供商
|
||||
*/
|
||||
public getWebSearchProvider(providerId?: string): WebSearchProvider | undefined {
|
||||
const { providers } = this.getWebSearchState()
|
||||
const { providers } = WebSearchService.getWebSearchState()
|
||||
const provider = providers.find((provider) => provider.id === providerId)
|
||||
|
||||
return provider
|
||||
@ -172,7 +172,7 @@ export default class WebSearchService {
|
||||
* @returns 搜索响应
|
||||
*/
|
||||
public async search(query: string, httpOptions?: RequestInit): Promise<WebSearchProviderResponse> {
|
||||
const websearch = this.getWebSearchState()
|
||||
const websearch = WebSearchService.getWebSearchState()
|
||||
const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId)
|
||||
if (!webSearchProvider) {
|
||||
throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`)
|
||||
@ -495,7 +495,7 @@ export default class WebSearchService {
|
||||
}
|
||||
}
|
||||
|
||||
const { compressionConfig } = this.getWebSearchState()
|
||||
const { compressionConfig } = WebSearchService.getWebSearchState()
|
||||
|
||||
// RAG压缩处理
|
||||
if (compressionConfig?.method === 'rag' && requestId) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user