feat: enhance web search functionality and tool integration

- Introduced `extractSearchKeywords` function to facilitate keyword extraction from user messages for web searches.
- Updated `webSearchTool` to streamline the execution of web searches without requiring a request ID.
- Enhanced `WebSearchService` methods to be static for improved accessibility and clarity.
- Modified `ApiService` to pass `webSearchProviderId` for better integration with the web search functionality.
- Improved `ToolCallChunkHandler` to handle built-in tools more effectively.
This commit is contained in:
suyao 2025-07-15 23:39:49 +08:00
parent da455997ad
commit 0456094512
No known key found for this signature in database
5 changed files with 246 additions and 18 deletions

View File

@ -67,6 +67,15 @@ export class ToolCallChunkHandler {
description: toolName,
type: 'provider'
}
} else if (toolName.startsWith('builtin_')) {
// 如果是内置工具,沿用现有逻辑
Logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
tool = {
id: toolCallId,
name: toolName,
description: toolName,
type: 'builtin'
}
} else {
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
Logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)

View File

@ -1,13 +1,16 @@
import { extractSearchKeywords } from '@renderer/aiCore/transformParameters'
import WebSearchService from '@renderer/services/WebSearchService'
import { WebSearchProvider } from '@renderer/types'
import aiSdk from 'ai'
import { Assistant, Message, WebSearchProvider } from '@renderer/types'
import { UserMessageStatus } from '@renderer/types/newMessage'
import { ExtractResults } from '@renderer/utils/extract'
import * as aiSdk from 'ai'
import { AiSdkTool, ToolCallResult } from './types'
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requestId: string): AiSdkTool => {
export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']): AiSdkTool => {
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
return {
name: 'web_search',
name: 'builtin_web_search',
description: 'Search the web for information',
inputSchema: aiSdk.jsonSchema({
type: 'object',
@ -18,7 +21,9 @@ export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requ
}),
execute: async ({ query }): Promise<ToolCallResult> => {
try {
const response = await webSearchService.processWebsearch(query, requestId)
console.log('webSearchTool', query)
const response = await webSearchService.search(query)
console.log('webSearchTool response', response)
return {
success: true,
data: response
@ -32,3 +37,109 @@ export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requ
}
}
}
export const webSearchToolWithExtraction = (
webSearchProviderId: WebSearchProvider['id'],
requestId: string,
assistant: Assistant
): AiSdkTool => {
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
return {
name: 'web_search_with_extraction',
description: 'Search the web for information with automatic keyword extraction from user messages',
inputSchema: aiSdk.jsonSchema({
type: 'object',
properties: {
userMessage: {
type: 'object',
description: 'The user message to extract keywords from',
properties: {
content: { type: 'string', description: 'The main content of the message' },
role: { type: 'string', description: 'Message role (user/assistant/system)' }
},
required: ['content', 'role']
},
lastAnswer: {
type: 'object',
description: 'Optional last assistant response for context',
properties: {
content: { type: 'string', description: 'The main content of the message' },
role: { type: 'string', description: 'Message role (user/assistant/system)' }
},
required: ['content', 'role']
}
},
required: ['userMessage']
}),
execute: async ({ userMessage, lastAnswer }): Promise<ToolCallResult> => {
try {
const lastUserMessage: Message = {
id: requestId,
role: userMessage.role as 'user' | 'assistant' | 'system',
assistantId: assistant.id,
topicId: 'temp',
createdAt: new Date().toISOString(),
status: UserMessageStatus.SUCCESS,
blocks: []
}
const lastAnswerMessage: Message | undefined = lastAnswer
? {
id: requestId + '_answer',
role: lastAnswer.role as 'user' | 'assistant' | 'system',
assistantId: assistant.id,
topicId: 'temp',
createdAt: new Date().toISOString(),
status: UserMessageStatus.SUCCESS,
blocks: []
}
: undefined
const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
shouldWebSearch: true,
shouldKnowledgeSearch: false,
lastAnswer: lastAnswerMessage
})
if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
return {
success: false,
data: 'No search needed or extraction failed'
}
}
const searchQueries = extractResults.websearch.question
const searchResults: Array<{ query: string; results: any }> = []
for (const query of searchQueries) {
// 构建单个查询的ExtractResults结构
const queryExtractResults: ExtractResults = {
websearch: {
question: [query],
links: extractResults.websearch.links
}
}
const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
searchResults.push({
query,
results: response
})
}
return {
success: true,
data: {
extractedKeywords: extractResults.websearch,
searchResults
}
}
} catch (error) {
return {
success: false,
data: error
}
}
}
}
}

View File

@ -13,6 +13,8 @@ import {
TextPart,
UserModelMessage
} from '@cherrystudio/ai-core'
import AiProvider from '@renderer/aiCore'
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
isGenerateImageModel,
@ -26,10 +28,18 @@ import {
isVisionModel,
isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import {
SEARCH_SUMMARY_PROMPT,
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
SEARCH_SUMMARY_PROMPT_WEB_ONLY
} from '@renderer/config/prompts'
import { getAssistantSettings, getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
import { getDefaultAssistant } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
// import { getWebSearchTools } from './utils/websearch'
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
import {
findFileBlocks,
findImageBlocks,
@ -38,12 +48,12 @@ import {
} from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout } from '@shared/config/constant'
import { isEmpty } from 'lodash'
import { webSearchTool } from './tools/WebSearchTool'
// import { jsonSchemaToZod } from 'json-schema-to-zod'
import { setupToolsConfig } from './utils/mcp'
import { buildProviderOptions } from './utils/options'
// import { getWebSearchTools } from './utils/websearch'
/**
*
@ -289,10 +299,7 @@ export async function buildStreamTextParams(
})
if (webSearchProviderId) {
// 生成requestId用于网络搜索工具
const requestId = `request_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`
tools['builtin_web_search'] = webSearchTool(webSearchProviderId, requestId)
tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
}
// 构建真正的 providerOptions
@ -336,3 +343,103 @@ export async function buildGenerateTextParams(
// 复用流式参数的构建逻辑
return await buildStreamTextParams(messages, assistant, provider, options)
}
/**
*
*
*/
export async function extractSearchKeywords(
lastUserMessage: Message,
assistant: Assistant,
options: {
shouldWebSearch?: boolean
shouldKnowledgeSearch?: boolean
lastAnswer?: Message
} = {}
): Promise<ExtractResults | undefined> {
const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer } = options
if (!lastUserMessage) return undefined
// 根据配置决定是否需要提取
const needWebExtract = shouldWebSearch
const needKnowledgeExtract = shouldKnowledgeSearch
if (!needWebExtract && !needKnowledgeExtract) return undefined
// 选择合适的提示词
let prompt: string
if (needWebExtract && !needKnowledgeExtract) {
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
} else if (!needWebExtract && needKnowledgeExtract) {
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
} else {
prompt = SEARCH_SUMMARY_PROMPT
}
// 构建用于提取的助手配置
const summaryAssistant = getDefaultAssistant()
summaryAssistant.model = assistant.model || getDefaultModel()
summaryAssistant.prompt = prompt
try {
const result = await fetchSearchSummary({
messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
assistant: summaryAssistant
})
if (!result) return getFallbackResult()
const extracted = extractInfoFromXML(result.getText())
// 根据需求过滤结果
return {
websearch: needWebExtract ? extracted?.websearch : undefined,
knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
}
} catch (e: any) {
console.error('extract error', e)
return getFallbackResult()
}
function getFallbackResult(): ExtractResults {
const fallbackContent = getMainTextContent(lastUserMessage)
return {
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
knowledge: shouldKnowledgeSearch
? {
question: [fallbackContent || 'search'],
rewrite: fallbackContent || 'search'
}
: undefined
}
}
}
/**
* -
*/
async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
const model = assistant.model || getDefaultModel()
const provider = getProviderByModel(model)
if (!hasApiKey(provider)) {
return null
}
const AI = new AiProvider(provider)
const params: CompletionsParams = {
callType: 'search',
messages: messages,
assistant,
streamOutput: false
}
return await AI.completions(params)
}
function hasApiKey(provider: Provider) {
if (!provider) return false
if (provider.id === 'ollama' || provider.id === 'lmstudio' || provider.type === 'vertexai') return true
return !isEmpty(provider.apiKey)
}

View File

@ -307,6 +307,7 @@ export async function fetchChatCompletion({
} = await buildStreamTextParams(messages, assistant, provider, {
mcpTools: mcpTools,
enableTools: isEnabledToolUse(assistant),
webSearchProviderId: assistant.webSearchProviderId,
requestOptions: options
})

View File

@ -109,7 +109,7 @@ export default class WebSearchService {
* @private
* @returns
*/
private getWebSearchState(): WebSearchState {
private static getWebSearchState(): WebSearchState {
return store.getState().websearch
}
@ -118,8 +118,8 @@ export default class WebSearchService {
* @public
* @returns truefalse
*/
public isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
const { providers } = this.getWebSearchState()
public static isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
const { providers } = WebSearchService.getWebSearchState()
const provider = providers.find((provider) => provider.id === providerId)
if (!provider) {
@ -149,7 +149,7 @@ export default class WebSearchService {
* @returns truefalse
*/
public isOverwriteEnabled(): boolean {
const { overwrite } = this.getWebSearchState()
const { overwrite } = WebSearchService.getWebSearchState()
return overwrite
}
@ -159,7 +159,7 @@ export default class WebSearchService {
* @returns
*/
public getWebSearchProvider(providerId?: string): WebSearchProvider | undefined {
const { providers } = this.getWebSearchState()
const { providers } = WebSearchService.getWebSearchState()
const provider = providers.find((provider) => provider.id === providerId)
return provider
@ -172,7 +172,7 @@ export default class WebSearchService {
* @returns
*/
public async search(query: string, httpOptions?: RequestInit): Promise<WebSearchProviderResponse> {
const websearch = this.getWebSearchState()
const websearch = WebSearchService.getWebSearchState()
const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId)
if (!webSearchProvider) {
throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`)
@ -495,7 +495,7 @@ export default class WebSearchService {
}
}
const { compressionConfig } = this.getWebSearchState()
const { compressionConfig } = WebSearchService.getWebSearchState()
// RAG压缩处理
if (compressionConfig?.method === 'rag' && requestId) {