diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
index 10a4d5938..b2a796bd3 100644
--- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
+++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
@@ -7,7 +7,6 @@ import type { Chunk } from '@renderer/types/chunk'
 import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
 import type { LanguageModelMiddleware } from 'ai'
 import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
-import { isEmpty } from 'lodash'
 
 import { getAiSdkProviderId } from '../provider/factory'
 import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
@@ -16,7 +15,6 @@ import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMidd
 import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
 import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
 import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
-import { toolChoiceMiddleware } from './toolChoiceMiddleware'
 
 const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
 
@@ -136,15 +134,6 @@ export class AiSdkMiddlewareBuilder {
 export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] {
   const builder = new AiSdkMiddlewareBuilder()
 
-  // 0. Knowledge-base forced-call middleware (must come first so the first round is forced to call the knowledge base)
-  if (!isEmpty(config.assistant?.knowledge_bases?.map((base) => base.id)) && config.knowledgeRecognition !== 'on') {
-    builder.add({
-      name: 'force-knowledge-first',
-      middleware: toolChoiceMiddleware('builtin_knowledge_search')
-    })
-    logger.debug('Added toolChoice middleware to force knowledge base search on first round')
-  }
-
   // 1. Add provider-specific middlewares
   if (config.provider) {
     addProviderSpecificMiddlewares(builder, config)
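For context on the block removed above: the deleted `toolChoiceMiddleware` presumably forced the model to call the knowledge-search tool on the first round by overriding `toolChoice`. A rough sketch of that idea, assuming only the AI SDK `transformParams` hook (this is not the repository's deleted implementation):

```ts
import type { LanguageModelMiddleware } from 'ai'

// Hypothetical reconstruction for illustration: force one named tool by overriding
// toolChoice before the call goes out. The toolChoice object shape matches the
// commented-out line in searchOrchestrationPlugin.ts later in this diff.
export function forceToolChoiceMiddleware(toolName: string): LanguageModelMiddleware {
  return {
    transformParams: async ({ params }) => ({
      ...params,
      toolChoice: { type: 'tool', toolName }
    })
  }
}
```

With this change the forcing step is gone: when knowledge recognition is off, retrieval now happens up front via `injectUserMessageWithKnowledgeSearchPrompt` in `KnowledgeService.ts` (added later in this diff) instead of by coercing the first tool call.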
diff --git a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts
index 6be577f19..5b095a446 100644
--- a/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts
+++ b/src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts
@@ -31,7 +31,7 @@ import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
 
 const logger = loggerService.withContext('SearchOrchestrationPlugin')
 
-const getMessageContent = (message: ModelMessage) => {
+export const getMessageContent = (message: ModelMessage) => {
   if (typeof message.content === 'string') return message.content
   return message.content.reduce((acc, part) => {
     if (part.type === 'text') {
@@ -266,14 +266,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
       // Decide which kinds of search are needed
       const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
       const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
-      const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
+      const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
       const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
       const shouldWebSearch = !!assistant.webSearchProviderId
       const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
      const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
 
       // Run intent analysis
-      if (shouldWebSearch || hasKnowledgeBase) {
+      if (shouldWebSearch || shouldKnowledgeSearch) {
         const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, {
           shouldWebSearch,
           shouldKnowledgeSearch,
@@ -330,41 +330,25 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
       // 📚 Knowledge base search tool configuration
       const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
       const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
-      const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
+      const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
+      const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
 
-      if (hasKnowledgeBase) {
-        if (knowledgeRecognition === 'off') {
-          // off mode: add the knowledge search tool directly and use the user message as the search keywords
+      if (shouldKnowledgeSearch) {
+        // on mode: decide from the intent-analysis result whether to add the tool
+        const needsKnowledgeSearch =
+          analysisResult?.knowledge &&
+          analysisResult.knowledge.question &&
+          analysisResult.knowledge.question[0] !== 'not_needed'
+
+        if (needsKnowledgeSearch && analysisResult.knowledge) {
+          // logger.info('📚 Adding knowledge search tool (intent-based)')
           const userMessage = userMessages[context.requestId]
-          const fallbackKeywords = {
-            question: [getMessageContent(userMessage) || 'search'],
-            rewrite: getMessageContent(userMessage) || 'search'
-          }
-          // logger.info('📚 Adding knowledge search tool (force mode)')
           params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
             assistant,
-            fallbackKeywords,
+            analysisResult.knowledge,
             getMessageContent(userMessage),
             topicId
           )
-          // params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' }
-        } else {
-          // on mode: decide from the intent-analysis result whether to add the tool
-          const needsKnowledgeSearch =
-            analysisResult?.knowledge &&
-            analysisResult.knowledge.question &&
-            analysisResult.knowledge.question[0] !== 'not_needed'
-
-          if (needsKnowledgeSearch && analysisResult.knowledge) {
-            // logger.info('📚 Adding knowledge search tool (intent-based)')
-            const userMessage = userMessages[context.requestId]
-            params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
-              assistant,
-              analysisResult.knowledge,
-              getMessageContent(userMessage),
-              topicId
-            )
-          }
         }
       }
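The behavioural change above is that `knowledgeRecognition` now defaults to `'off'`, and the knowledge tool is only registered when recognition is `'on'`, a knowledge base is attached, and intent analysis produced a usable question. A condensed sketch of that gating, with an assumed minimal type for the analysis result:

```ts
// Assumed minimal shape of the intent-analysis result used in the plugin.
type KnowledgeAnalysis = { question: string[]; rewrite: string }

// Mirrors the new logic: no tool unless recognition is 'on', a knowledge base exists,
// and the analysis did not come back as 'not_needed'.
function shouldAddKnowledgeTool(
  hasKnowledgeBase: boolean,
  knowledgeRecognition: 'on' | 'off',
  knowledge: KnowledgeAnalysis | undefined
): boolean {
  if (!hasKnowledgeBase || knowledgeRecognition !== 'on') return false
  return !!knowledge?.question?.length && knowledge.question[0] !== 'not_needed'
}

// Examples:
// shouldAddKnowledgeTool(true, 'off', { question: ['q'], rewrite: 'q' })        -> false (off mode now uses prompt injection)
// shouldAddKnowledgeTool(true, 'on', { question: ['not_needed'], rewrite: '' }) -> false
// shouldAddKnowledgeTool(true, 'on', { question: ['how to deploy?'], rewrite: 'deploy' }) -> true
```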
diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts
index 3c081c3da..0cd57a353 100644
--- a/src/renderer/src/services/ApiService.ts
+++ b/src/renderer/src/services/ApiService.ts
@@ -34,6 +34,10 @@ import {
   getProviderByModel,
   getQuickModel
 } from './AssistantService'
+import { ConversationService } from './ConversationService'
+import { injectUserMessageWithKnowledgeSearchPrompt } from './KnowledgeService'
+import type { BlockManager } from './messageStreaming'
+import type { StreamProcessorCallbacks } from './StreamProcessingService'
 // import { processKnowledgeSearch } from './KnowledgeService'
 // import {
 //   filterContextMessages,
@@ -79,6 +83,59 @@ export async function fetchMcpTools(assistant: Assistant) {
   return mcpTools
 }
 
+/**
+ * Convert user messages into a format the LLM can understand and send the request
+ * @param request - Request object containing the message content and assistant info
+ * @param onChunkReceived - Callback that receives streamed response data
+ */
+// Written as a function for now; switch back to a class later if one is needed
+export async function transformMessagesAndFetch(
+  request: {
+    messages: Message[]
+    assistant: Assistant
+    blockManager: BlockManager
+    assistantMsgId: string
+    callbacks: StreamProcessorCallbacks
+    topicId?: string // topicId is used for tracing
+    options: {
+      signal?: AbortSignal
+      timeout?: number
+      headers?: Record<string, string>
+    }
+  },
+  onChunkReceived: (chunk: Chunk) => void
+) {
+  const { messages, assistant } = request
+
+  try {
+    const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
+
+    // replace prompt variables
+    assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name)
+
+    // inject knowledge search prompt into model messages
+    await injectUserMessageWithKnowledgeSearchPrompt({
+      modelMessages,
+      assistant,
+      assistantMsgId: request.assistantMsgId,
+      topicId: request.topicId,
+      blockManager: request.blockManager,
+      setCitationBlockId: request.callbacks.setCitationBlockId!
+    })
+
+    await fetchChatCompletion({
+      messages: modelMessages,
+      assistant: assistant,
+      topicId: request.topicId,
+      requestOptions: request.options,
+      uiMessages,
+      onChunkReceived
+    })
+  } catch (error: any) {
+    onChunkReceived({ type: ChunkType.ERROR, error })
+  }
+}
+
 export async function fetchChatCompletion({
   messages,
   prompt,
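A sketch of how a caller might wire the new `transformMessagesAndFetch` signature. It mirrors the `messageThunk.ts` change at the end of this diff; everything in `deps` is assumed to be supplied by the caller (in the real code path, `fetchAndProcessAssistantResponseImpl`).

```ts
import { transformMessagesAndFetch } from '@renderer/services/ApiService'
import type { BlockManager } from '@renderer/services/messageStreaming/BlockManager'
import type { StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService'
import type { Assistant, Message } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'

// Illustrative call site only: the request now carries blockManager, assistantMsgId
// and callbacks so KnowledgeService can emit a citation block before streaming starts.
async function runAssistantTurn(deps: {
  messages: Message[]
  assistant: Assistant
  topicId: string
  assistantMsgId: string
  blockManager: BlockManager
  callbacks: StreamProcessorCallbacks
  onChunk: (chunk: Chunk) => void
}) {
  const abortController = new AbortController()

  await transformMessagesAndFetch(
    {
      messages: deps.messages,
      assistant: deps.assistant,
      topicId: deps.topicId,
      assistantMsgId: deps.assistantMsgId,
      blockManager: deps.blockManager,
      callbacks: deps.callbacks,
      options: { signal: abortController.signal, timeout: 30000 }
    },
    deps.onChunk
  )
}
```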
diff --git a/src/renderer/src/services/KnowledgeService.ts b/src/renderer/src/services/KnowledgeService.ts
index e78cfa62e..ce9577c68 100644
--- a/src/renderer/src/services/KnowledgeService.ts
+++ b/src/renderer/src/services/KnowledgeService.ts
@@ -2,10 +2,13 @@ import { loggerService } from '@logger'
 import type { Span } from '@opentelemetry/api'
 import { ModernAiProvider } from '@renderer/aiCore'
 import AiProvider from '@renderer/aiCore/legacy'
+import { getMessageContent } from '@renderer/aiCore/plugins/searchOrchestrationPlugin'
 import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
 import { getEmbeddingMaxContext } from '@renderer/config/embedings'
+import { REFERENCE_PROMPT } from '@renderer/config/prompts'
 import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
 import store from '@renderer/store'
+import type { Assistant } from '@renderer/types'
 import {
   type FileMetadata,
   type KnowledgeBase,
@@ -16,13 +19,17 @@
 } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
 import { ChunkType } from '@renderer/types/chunk'
+import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
 import { routeToEndpoint } from '@renderer/utils'
 import type { ExtractResults } from '@renderer/utils/extract'
+import { createCitationBlock } from '@renderer/utils/messageUtils/create'
 import { isAzureOpenAIProvider, isGeminiProvider } from '@renderer/utils/provider'
+import type { ModelMessage, UserModelMessage } from 'ai'
 import { isEmpty } from 'lodash'
 
 import { getProviderByModel } from './AssistantService'
 import FileManager from './FileManager'
+import type { BlockManager } from './messageStreaming'
 
 const logger = loggerService.withContext('RendererKnowledgeService')
@@ -338,3 +345,128 @@ export function processKnowledgeReferences(
     }
   }
 }
+
+export const injectUserMessageWithKnowledgeSearchPrompt = async ({
+  modelMessages,
+  assistant,
+  assistantMsgId,
+  topicId,
+  blockManager,
+  setCitationBlockId
+}: {
+  modelMessages: ModelMessage[]
+  assistant: Assistant
+  assistantMsgId: string
+  topicId?: string
+  blockManager: BlockManager
+  setCitationBlockId: (blockId: string) => void
+}) => {
+  if (assistant.knowledge_bases?.length && modelMessages.length > 0) {
+    const lastUserMessage = modelMessages[modelMessages.length - 1]
+    const isUserMessage = lastUserMessage.role === 'user'
+
+    if (!isUserMessage) {
+      return
+    }
+
+    const knowledgeReferences = await getKnowledgeReferences({
+      assistant,
+      lastUserMessage,
+      topicId: topicId
+    })
+
+    if (knowledgeReferences.length === 0) {
+      return
+    }
+
+    await createKnowledgeReferencesBlock({
+      assistantMsgId,
+      knowledgeReferences,
+      blockManager,
+      setCitationBlockId
+    })
+
+    const question = getMessageContent(lastUserMessage) || ''
+    const references = JSON.stringify(knowledgeReferences, null, 2)
+
+    const knowledgeSearchPrompt = REFERENCE_PROMPT.replace('{question}', question).replace('{references}', references)
+
+    if (typeof lastUserMessage.content === 'string') {
+      lastUserMessage.content = knowledgeSearchPrompt
+    } else if (Array.isArray(lastUserMessage.content)) {
+      const textPart = lastUserMessage.content.find((part) => part.type === 'text')
+      if (textPart) {
+        textPart.text = knowledgeSearchPrompt
+      } else {
+        lastUserMessage.content.push({
+          type: 'text',
+          text: knowledgeSearchPrompt
+        })
+      }
+    }
+  }
+}
+
+export const getKnowledgeReferences = async ({
+  assistant,
+  lastUserMessage,
+  topicId
+}: {
+  assistant: Assistant
+  lastUserMessage: UserModelMessage
+  topicId?: string
+}) => {
+  // If the assistant has no knowledge bases, return an empty array
+  if (!assistant || isEmpty(assistant.knowledge_bases)) {
+    return []
+  }
+
+  // Get the knowledge base IDs
+  const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
+
+  // Get the user message content
+  const question = getMessageContent(lastUserMessage) || ''
+
+  // Search the knowledge bases for references
+  const knowledgeReferences = await processKnowledgeSearch(
+    {
+      knowledge: {
+        question: [question],
+        rewrite: ''
+      }
+    },
+    knowledgeBaseIds,
+    topicId!
+  )
+
+  // Return the references
+  return knowledgeReferences
+}
+
+export const createKnowledgeReferencesBlock = async ({
+  assistantMsgId,
+  knowledgeReferences,
+  blockManager,
+  setCitationBlockId
+}: {
+  assistantMsgId: string
+  knowledgeReferences: KnowledgeReference[]
+  blockManager: BlockManager
+  setCitationBlockId: (blockId: string) => void
+}) => {
+  // Create the citation block
+  const citationBlock = createCitationBlock(
+    assistantMsgId,
+    { knowledge: knowledgeReferences },
+    { status: MessageBlockStatus.SUCCESS }
+  )
+
+  // Hand the block to the block manager
+  blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
+
+  // Record the citation block ID
+  setCitationBlockId(citationBlock.id)
+
+  // Return the citation block
+  return citationBlock
+}
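The injection above rewrites the last user message with `REFERENCE_PROMPT`, substituting `{question}` and `{references}`. The real template lives in `@renderer/config/prompts`; the placeholder below only illustrates the substitution mechanics.

```ts
// Placeholder template, NOT the project's REFERENCE_PROMPT. It just needs the same
// {question} / {references} markers to demonstrate the replacement done above.
const EXAMPLE_REFERENCE_TEMPLATE = [
  'Answer the question using the references below.',
  'Question: {question}',
  'References: {references}'
].join('\n')

function buildKnowledgePrompt(template: string, question: string, references: unknown[]): string {
  return template
    .replace('{question}', question)
    .replace('{references}', JSON.stringify(references, null, 2))
}

// buildKnowledgePrompt(EXAMPLE_REFERENCE_TEMPLATE, 'How do I add a knowledge base?', [
//   { id: 1, content: 'Open Settings and create a knowledge base ...' }
// ])
```

Note that `String.prototype.replace` with a string pattern only substitutes the first occurrence, which matches how the service uses the template above.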
diff --git a/src/renderer/src/services/OrchestrateService.ts b/src/renderer/src/services/OrchestrateService.ts
deleted file mode 100644
index 71f17d680..000000000
--- a/src/renderer/src/services/OrchestrateService.ts
+++ /dev/null
@@ -1,91 +0,0 @@
-import type { Assistant, Message } from '@renderer/types'
-import type { Chunk } from '@renderer/types/chunk'
-import { ChunkType } from '@renderer/types/chunk'
-import { replacePromptVariables } from '@renderer/utils/prompt'
-
-import { fetchChatCompletion } from './ApiService'
-import { ConversationService } from './ConversationService'
-
-/**
- * The request object for handling a user message.
- */
-export interface OrchestrationRequest {
-  messages: Message[]
-  assistant: Assistant
-  options: {
-    signal?: AbortSignal
-    timeout?: number
-    headers?: Record<string, string>
-  }
-  topicId?: string // topicId is used for tracing
-}
-
-/**
- * The OrchestrationService is responsible for orchestrating the different services
- * to handle a user's message. It contains the core logic of the application.
- */
-// NOTE: this class is not currently used
-export class OrchestrationService {
-  constructor() {
-    // In the future, this could be a singleton, but for now, a new instance is fine.
-    // this.conversationService = new ConversationService()
-  }
-
-  /**
-   * This is the core method to handle user messages.
-   * It takes the message context and an events object for callbacks,
-   * and orchestrates the call to the LLM.
-   * The logic is moved from `messageThunk.ts`.
-   * @param request The orchestration request containing messages and assistant info.
-   * @param events A set of callbacks to report progress and results to the UI layer.
-   */
-  async transformMessagesAndFetch(request: OrchestrationRequest, onChunkReceived: (chunk: Chunk) => void) {
-    const { messages, assistant } = request
-
-    try {
-      const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
-
-      await fetchChatCompletion({
-        messages: modelMessages,
-        assistant: assistant,
-        requestOptions: request.options,
-        onChunkReceived,
-        topicId: request.topicId,
-        uiMessages: uiMessages
-      })
-    } catch (error: any) {
-      onChunkReceived({ type: ChunkType.ERROR, error })
-    }
-  }
-}
-
-/**
- * Convert user messages into a format the LLM can understand and send the request
- * @param request - Request object containing the message content and assistant info
- * @param onChunkReceived - Callback that receives streamed response data
- */
-// Written as a function for now; switch back to a class later if one is needed
-export async function transformMessagesAndFetch(
-  request: OrchestrationRequest,
-  onChunkReceived: (chunk: Chunk) => void
-) {
-  const { messages, assistant } = request
-
-  try {
-    const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
-
-    // replace prompt variables
-    assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name)
-
-    await fetchChatCompletion({
-      messages: modelMessages,
-      assistant: assistant,
-      requestOptions: request.options,
-      onChunkReceived,
-      topicId: request.topicId,
-      uiMessages
-    })
-  } catch (error: any) {
-    onChunkReceived({ type: ChunkType.ERROR, error })
-  }
-}
diff --git a/src/renderer/src/services/StreamProcessingService.ts b/src/renderer/src/services/StreamProcessingService.ts
index 26f52b803..7e80672d5 100644
--- a/src/renderer/src/services/StreamProcessingService.ts
+++ b/src/renderer/src/services/StreamProcessingService.ts
@@ -34,6 +34,10 @@ export interface StreamProcessorCallbacks {
   onLLMWebSearchInProgress?: () => void
   // LLM Web search complete
   onLLMWebSearchComplete?: (llmWebSearchResult: WebSearchResponse) => void
+  // Get citation block ID
+  getCitationBlockId?: () => string | null
+  // Set citation block ID
+  setCitationBlockId?: (blockId: string) => void
   // Image generation chunk received
   onImageCreated?: () => void
   onImageDelta?: (imageData: GenerateImageResponse) => void
diff --git a/src/renderer/src/services/messageStreaming/callbacks/citationCallbacks.ts b/src/renderer/src/services/messageStreaming/callbacks/citationCallbacks.ts
index 9e99fe752..324549363 100644
--- a/src/renderer/src/services/messageStreaming/callbacks/citationCallbacks.ts
+++ b/src/renderer/src/services/messageStreaming/callbacks/citationCallbacks.ts
@@ -121,6 +121,11 @@ export const createCitationCallbacks = (deps: CitationCallbacksDependencies) =>
     },
 
     // Method exposed externally so textCallbacks can read the citationBlockId
-    getCitationBlockId: () => citationBlockId
+    getCitationBlockId: () => citationBlockId,
+
+    // Method exposed externally so KnowledgeService can set the citationBlockId
+    setCitationBlockId: (blockId: string) => {
+      citationBlockId = blockId
+    }
   }
 }
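The two small changes above let `KnowledgeService` register a citation block it creates outside the streaming callbacks: `createCitationCallbacks` keeps `citationBlockId` in a closure and now exposes a setter next to the existing getter. A minimal sketch of that pattern in isolation:

```ts
// Standalone illustration of the closure used by createCitationCallbacks: the getter
// and the new setter capture the same private citationBlockId variable.
function createCitationBlockIdStore() {
  let citationBlockId: string | null = null

  return {
    getCitationBlockId: () => citationBlockId,
    setCitationBlockId: (blockId: string) => {
      citationBlockId = blockId
    }
  }
}

const citationIds = createCitationBlockIdStore()
citationIds.setCitationBlockId('citation-block-1') // e.g. the id returned by createCitationBlock
console.log(citationIds.getCitationBlockId()) // 'citation-block-1'
```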
diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts
index a70fdf572..8219fa0cc 100644
--- a/src/renderer/src/store/thunk/messageThunk.ts
+++ b/src/renderer/src/store/thunk/messageThunk.ts
@@ -2,12 +2,11 @@ import { loggerService } from '@logger'
 import { AiSdkToChunkAdapter } from '@renderer/aiCore/chunk/AiSdkToChunkAdapter'
 import { AgentApiClient } from '@renderer/api/agent'
 import db from '@renderer/databases'
-import { fetchMessagesSummary } from '@renderer/services/ApiService'
+import { fetchMessagesSummary, transformMessagesAndFetch } from '@renderer/services/ApiService'
 import { DbService } from '@renderer/services/db/DbService'
 import FileManager from '@renderer/services/FileManager'
 import { BlockManager } from '@renderer/services/messageStreaming/BlockManager'
 import { createCallbacks } from '@renderer/services/messageStreaming/callbacks'
-import { transformMessagesAndFetch } from '@renderer/services/OrchestrateService'
 import { endSpan } from '@renderer/services/SpanManagerService'
 import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService'
 import store from '@renderer/store'
@@ -814,6 +813,9 @@ const fetchAndProcessAssistantResponseImpl = async (
       messages: messagesForContext,
       assistant,
       topicId,
+      blockManager,
+      assistantMsgId,
+      callbacks,
       options: {
         signal: abortController.signal,
         timeout: 30000,