diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts index 9f98c22efa..cb840e2d79 100644 --- a/packages/aiCore/src/index.ts +++ b/packages/aiCore/src/index.ts @@ -21,6 +21,7 @@ export { createModel, type ModelConfig } from './core/models' // ==================== 插件系统 ==================== export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './core/plugins' export { createContext, definePlugin, PluginManager } from './core/plugins' +export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in' export { PluginEngine } from './core/runtime/pluginEngine' // ==================== 低级 API ==================== diff --git a/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts index 9cc9960807..3cb4df1d19 100644 --- a/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/AiSdkToChunkAdapter.ts @@ -4,7 +4,7 @@ */ import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core' -import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types' +import { BaseTool, WebSearchResults, WebSearchSource } from '@renderer/types' import { Chunk, ChunkType } from '@renderer/types/chunk' import { ToolCallChunkHandler } from './chunk/handleTooCallChunk' @@ -27,7 +27,7 @@ export class AiSdkToChunkAdapter { toolCallHandler: ToolCallChunkHandler constructor( private onChunk: (chunk: Chunk) => void, - private mcpTools: MCPTool[] = [] + private mcpTools: BaseTool[] = [] ) { this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools) } diff --git a/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts b/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts index 70134c5e43..2140950e19 100644 --- a/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts +++ b/src/renderer/src/aiCore/chunk/handleTooCallChunk.ts @@ -4,22 +4,15 @@ * 提供工具调用相关的处理API,每个交互使用一个新的实例 */ -import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core/index' +import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core' import Logger from '@renderer/config/logger' -import { MCPTool, MCPToolResponse } from '@renderer/types' +import { BaseTool, MCPToolResponse } from '@renderer/types' import { Chunk, ChunkType } from '@renderer/types/chunk' // import type { // AnthropicSearchOutput, // WebSearchPluginConfig // } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin' -// 为 Provider 执行的工具创建一个通用类型 -// 这避免了污染 MCPTool 的定义,同时提供了 UI 显示所需的基本信息 -type GenericProviderTool = { - name: string - description: string - type: 'provider' -} /** * 工具调用处理器类 */ @@ -32,12 +25,12 @@ export class ToolCallChunkHandler { toolName: string args: any // mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型 - mcpTool: MCPTool | GenericProviderTool + mcpTool: BaseTool } >() constructor( private onChunk: (chunk: Chunk) => void, - private mcpTools: MCPTool[] + private mcpTools: BaseTool[] ) {} // /** @@ -62,13 +55,14 @@ export class ToolCallChunkHandler { return } - let tool: MCPTool | GenericProviderTool + let tool: BaseTool // 根据 providerExecuted 标志区分处理逻辑 if (providerExecuted) { // 如果是 Provider 执行的工具(如 web_search) Logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`) tool = { + id: toolCallId, name: toolName, description: toolName, type: 'provider' diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 597dcd9835..7514961154 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -17,7 +17,7 @@ import { type ProviderSettingsMap, StreamTextParams } from '@cherrystudio/ai-core' -import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/core/plugins/built-in' +import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core' import { isDedicatedImageGenerationModel } from '@renderer/config/models' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { getProviderByModel } from '@renderer/services/AssistantService' diff --git a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts index dd5461a639..0f18d4e195 100644 --- a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts @@ -3,7 +3,7 @@ import { LanguageModelV2Middleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core' -import type { MCPTool, Model, Provider } from '@renderer/types' +import type { BaseTool, Model, Provider } from '@renderer/types' import type { Chunk } from '@renderer/types/chunk' /** @@ -18,7 +18,7 @@ export interface AiSdkMiddlewareConfig { // 是否开启提示词工具调用 enableTool?: boolean enableWebSearch?: boolean - mcpTools?: MCPTool[] + mcpTools?: BaseTool[] } /** diff --git a/src/renderer/src/aiCore/tools/WebSearchTool.ts b/src/renderer/src/aiCore/tools/WebSearchTool.ts new file mode 100644 index 0000000000..52a6d743a3 --- /dev/null +++ b/src/renderer/src/aiCore/tools/WebSearchTool.ts @@ -0,0 +1,34 @@ +import WebSearchService from '@renderer/services/WebSearchService' +import { WebSearchProvider } from '@renderer/types' +import aiSdk from 'ai' + +import { AiSdkTool, ToolCallResult } from './types' + +export const webSearchTool = (webSearchProviderId: WebSearchProvider['id'], requestId: string): AiSdkTool => { + const webSearchService = WebSearchService.getInstance(webSearchProviderId) + return { + name: 'web_search', + description: 'Search the web for information', + inputSchema: aiSdk.jsonSchema({ + type: 'object', + properties: { + query: { type: 'string', description: 'The query to search for' } + }, + required: ['query'] + }), + execute: async ({ query }): Promise => { + try { + const response = await webSearchService.processWebsearch(query, requestId) + return { + success: true, + data: response + } + } catch (error) { + return { + success: false, + data: error + } + } + } + } +} diff --git a/src/renderer/src/aiCore/tools/types.ts b/src/renderer/src/aiCore/tools/types.ts new file mode 100644 index 0000000000..23591c7f00 --- /dev/null +++ b/src/renderer/src/aiCore/tools/types.ts @@ -0,0 +1,8 @@ +import { Tool } from '@cherrystudio/ai-core' + +export type ToolCallResult = { + success: boolean + data: any +} + +export type AiSdkTool = Tool diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index af8618f69e..82e46c1cd9 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -39,6 +39,7 @@ import { import { buildSystemPrompt } from '@renderer/utils/prompt' import { defaultTimeout } from '@shared/config/constant' +import { webSearchTool } from './tools/WebSearchTool' // import { jsonSchemaToZod } from 'json-schema-to-zod' import { setupToolsConfig } from './utils/mcp' import { buildProviderOptions } from './utils/options' @@ -246,6 +247,7 @@ export async function buildStreamTextParams( mcpTools?: MCPTool[] enableTools?: boolean enableWebSearch?: boolean + webSearchProviderId?: string requestOptions?: { signal?: AbortSignal timeout?: number @@ -257,7 +259,7 @@ export async function buildStreamTextParams( modelId: string capabilities: { enableReasoning?: boolean; enableWebSearch?: boolean; enableGenerateImage?: boolean } }> { - const { mcpTools, enableTools } = options + const { mcpTools, enableTools, webSearchProviderId } = options const model = assistant.model || getDefaultModel() @@ -286,6 +288,13 @@ export async function buildStreamTextParams( enableToolUse: enableTools }) + if (webSearchProviderId) { + // 生成requestId用于网络搜索工具 + const requestId = `request_${Date.now()}_${Math.random().toString(36).substring(2, 9)}` + + tools['builtin_web_search'] = webSearchTool(webSearchProviderId, requestId) + } + // 构建真正的 providerOptions const providerOptions = buildProviderOptions(assistant, model, provider, { enableReasoning, diff --git a/src/renderer/src/aiCore/utils/mcp.ts b/src/renderer/src/aiCore/utils/mcp.ts index 39c36deb4f..26a77c4f93 100644 --- a/src/renderer/src/aiCore/utils/mcp.ts +++ b/src/renderer/src/aiCore/utils/mcp.ts @@ -1,17 +1,11 @@ import { aiSdk, Tool } from '@cherrystudio/ai-core' +import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types' import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant' import { isFunctionCallingModel } from '@renderer/config/models' -import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types' +import { MCPTool, MCPToolResponse, Model } from '@renderer/types' import { callMCPTool } from '@renderer/utils/mcp-tools' import { JSONSchema7 } from 'json-schema' -type ToolCallResult = { - success: boolean - data: MCPCallToolResponse -} - -type AiSdkTool = Tool - // Setup tools configuration based on provided parameters export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): { tools: Record diff --git a/src/renderer/src/services/WebSearchService.ts b/src/renderer/src/services/WebSearchService.ts index efb726a0aa..ea66fe7d54 100644 --- a/src/renderer/src/services/WebSearchService.ts +++ b/src/renderer/src/services/WebSearchService.ts @@ -36,7 +36,21 @@ interface RequestState { /** * 提供网络搜索相关功能的服务类 */ -class WebSearchService { +export default class WebSearchService { + private static instance: WebSearchService + private webSearchProviderId: WebSearchProvider['id'] + + private constructor(webSearchProviderId: WebSearchProvider['id']) { + this.webSearchProviderId = webSearchProviderId + } + + public static getInstance(webSearchProviderId: WebSearchProvider['id']): WebSearchService { + if (!WebSearchService.instance) { + WebSearchService.instance = new WebSearchService(webSearchProviderId) + } + return WebSearchService.instance + } + /** * 是否暂停 */ @@ -154,17 +168,16 @@ class WebSearchService { /** * 使用指定的提供商执行网络搜索 * @public - * @param provider 搜索提供商 * @param query 搜索查询 * @returns 搜索响应 */ - public async search( - provider: WebSearchProvider, - query: string, - httpOptions?: RequestInit - ): Promise { + public async search(query: string, httpOptions?: RequestInit): Promise { const websearch = this.getWebSearchState() - const webSearchEngine = new WebSearchEngineProvider(provider) + const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId) + if (!webSearchProvider) { + throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`) + } + const webSearchEngine = new WebSearchEngineProvider(webSearchProvider) let formattedQuery = query // FIXME: 有待商榷,效果一般 @@ -186,9 +199,9 @@ class WebSearchService { * @param provider 要检查的搜索提供商 * @returns 如果提供商可用返回true,否则返回false */ - public async checkSearch(provider: WebSearchProvider): Promise<{ valid: boolean; error?: any }> { + public async checkSearch(): Promise<{ valid: boolean; error?: any }> { try { - const response = await this.search(provider, 'test query') + const response = await this.search('test query') Logger.log('[checkSearch] Search response:', response) // 优化的判断条件:检查结果是否有效且没有错误 return { valid: response.results !== undefined, error: undefined } @@ -423,11 +436,7 @@ class WebSearchService { * * @returns 包含搜索结果的响应对象 */ - public async processWebsearch( - webSearchProvider: WebSearchProvider, - extractResults: ExtractResults, - requestId: string - ): Promise { + public async processWebsearch(extractResults: ExtractResults, requestId: string): Promise { // 重置状态 await this.setWebSearchStatus(requestId, { phase: 'default' }) @@ -449,7 +458,7 @@ class WebSearchService { return { query: 'summaries', results: contents } } - const searchPromises = questions.map((q) => this.search(webSearchProvider, q, { signal })) + const searchPromises = questions.map((q) => this.search(q, { signal })) const searchResults = await Promise.allSettled(searchPromises) // 统计成功完成的搜索数量 @@ -532,5 +541,3 @@ class WebSearchService { } } } - -export default new WebSearchService() diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 38a038c8ac..258cb2a804 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -4,12 +4,7 @@ import type OpenAI from 'openai' import type { CSSProperties } from 'react' import type { Message } from './newMessage' - -export type GenericProviderTool = { - name: string - description: string - type: 'provider' -} +import type { BaseTool, MCPTool } from './tool' export type Assistant = { id: string @@ -616,15 +611,6 @@ export interface MCPToolInputSchema { properties: Record } -export interface MCPTool { - id: string - serverId: string - serverName: string - name: string - description?: string - inputSchema: MCPToolInputSchema -} - export interface MCPPromptArguments { name: string description?: string @@ -661,7 +647,7 @@ export interface MCPConfig { interface BaseToolResponse { id: string // unique id - tool: MCPTool | GenericProviderTool + tool: BaseTool arguments: Record | undefined status: string // 'invoking' | 'done' response?: any @@ -755,3 +741,4 @@ export type S3Config = { } export type { Message } from './newMessage' +export * from './tool' diff --git a/src/renderer/src/types/tool.ts b/src/renderer/src/types/tool.ts new file mode 100644 index 0000000000..5f41e4a85c --- /dev/null +++ b/src/renderer/src/types/tool.ts @@ -0,0 +1,26 @@ +import type { MCPToolInputSchema } from './index' + +export type ToolType = 'builtin' | 'provider' | 'mcp' + +export interface BaseTool { + id: string + name: string + description?: string + type: ToolType +} + +export interface GenericProviderTool extends BaseTool { + type: 'provider' +} + +export interface BuiltinTool extends BaseTool { + inputSchema: MCPToolInputSchema + type: 'builtin' +} + +export interface MCPTool extends BaseTool { + serverId: string + serverName: string + inputSchema: MCPToolInputSchema + type: 'mcp' +}