diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts index 20b89cf2e5..e2b253374b 100644 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -4,6 +4,8 @@ import type { MCPTool, Message, Model, Provider } from '@renderer/types' import type { Chunk } from '@renderer/types/chunk' import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai' +import { toolChoiceMiddleware } from './toolChoiceMiddleware' + const logger = loggerService.withContext('AiSdkMiddlewareBuilder') /** @@ -29,6 +31,8 @@ export interface AiSdkMiddlewareConfig { uiMessages?: Message[] // 内置搜索配置 webSearchPluginConfig?: WebSearchPluginConfig + // 知识库识别开关,默认开启 + knowledgeRecognition?: 'off' | 'on' } /** @@ -119,6 +123,15 @@ export class AiSdkMiddlewareBuilder { export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] { const builder = new AiSdkMiddlewareBuilder() + // 0. 知识库强制调用中间件(必须在最前面,确保第一轮强制调用知识库) + if (config.knowledgeRecognition === 'off') { + builder.add({ + name: 'force-knowledge-first', + middleware: toolChoiceMiddleware('builtin_knowledge_search') + }) + logger.debug('Added toolChoice middleware to force knowledge base search on first round') + } + // 1. 根据provider添加特定中间件 if (config.provider) { addProviderSpecificMiddlewares(builder, config) diff --git a/src/renderer/src/aiCore/middleware/toolChoiceMiddleware.ts b/src/renderer/src/aiCore/middleware/toolChoiceMiddleware.ts new file mode 100644 index 0000000000..6d3ba37d1d --- /dev/null +++ b/src/renderer/src/aiCore/middleware/toolChoiceMiddleware.ts @@ -0,0 +1,45 @@ +import { loggerService } from '@logger' +import { LanguageModelMiddleware } from 'ai' + +const logger = loggerService.withContext('toolChoiceMiddleware') + +/** + * Tool Choice Middleware + * Controls tool selection strategy across multiple rounds of tool calls: + * - First round: Forces the model to call a specific tool (e.g., knowledge base search) + * - Subsequent rounds: Allows the model to automatically choose any available tool + * + * This ensures knowledge base is consulted first while still enabling MCP tools + * and other capabilities in follow-up interactions. + * + * @param forceFirstToolName - The tool name to force on the first round + * @returns LanguageModelMiddleware + */ +export function toolChoiceMiddleware(forceFirstToolName: string): LanguageModelMiddleware { + let toolCallRound = 0 + + return { + middlewareVersion: 'v2', + + transformParams: async ({ params }) => { + toolCallRound++ + + const transformedParams = { ...params } + + if (toolCallRound === 1) { + // First round: force the specified tool + logger.debug(`Round ${toolCallRound}: Forcing tool choice to '${forceFirstToolName}'`) + transformedParams.toolChoice = { + type: 'tool', + toolName: forceFirstToolName + } + } else { + // Subsequent rounds: allow automatic tool selection + logger.debug(`Round ${toolCallRound}: Using automatic tool choice`) + transformedParams.toolChoice = { type: 'auto' } + } + + return transformedParams + } + } +} diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 6ab07662cb..403fdc2cfe 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -138,7 +138,8 @@ export async function fetchChatCompletion({ enableGenerateImage: capabilities.enableGenerateImage, enableUrlContext: capabilities.enableUrlContext, mcpTools, - uiMessages + uiMessages, + knowledgeRecognition: assistant.knowledgeRecognition } // --- Call AI Completions ---