feat(knowledge): use prompt injection for forced knowledge base search

Change the default knowledge base retrieval behavior from tool call to prompt injection mode.
This provides faster response times when knowledge base search is forced.
Intent recognition mode (tool call) is still available as an opt-in option.

- Remove toolChoiceMiddleware for forced knowledge base search
- Add prompt injection for knowledge base references in KnowledgeService
- Move transformMessagesAndFetch to ApiService, delete OrchestrateService
- Export getMessageContent from searchOrchestrationPlugin
- Add setCitationBlockId callback to citationCallbacks
- Default knowledgeRecognition to 'off' (prompt mode)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
kangfenmao 2025-12-16 12:18:11 +08:00
parent c676a93595
commit ef25eef0eb
8 changed files with 218 additions and 136 deletions

View File

@ -7,7 +7,6 @@ import type { Chunk } from '@renderer/types/chunk'
import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash'
import { getAiSdkProviderId } from '../provider/factory'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
@ -16,7 +15,6 @@ import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMidd
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
@ -136,15 +134,6 @@ export class AiSdkMiddlewareBuilder {
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] {
const builder = new AiSdkMiddlewareBuilder()
// 0. 知识库强制调用中间件(必须在最前面,确保第一轮强制调用知识库)
if (!isEmpty(config.assistant?.knowledge_bases?.map((base) => base.id)) && config.knowledgeRecognition !== 'on') {
builder.add({
name: 'force-knowledge-first',
middleware: toolChoiceMiddleware('builtin_knowledge_search')
})
logger.debug('Added toolChoice middleware to force knowledge base search on first round')
}
// 1. 根据provider添加特定中间件
if (config.provider) {
addProviderSpecificMiddlewares(builder, config)

View File

@ -31,7 +31,7 @@ import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
const logger = loggerService.withContext('SearchOrchestrationPlugin')
const getMessageContent = (message: ModelMessage) => {
export const getMessageContent = (message: ModelMessage) => {
if (typeof message.content === 'string') return message.content
return message.content.reduce((acc, part) => {
if (part.type === 'text') {
@ -266,14 +266,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
// 判断是否需要各种搜索
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
const shouldWebSearch = !!assistant.webSearchProviderId
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
// 执行意图分析
if (shouldWebSearch || hasKnowledgeBase) {
if (shouldWebSearch || shouldKnowledgeSearch) {
const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, {
shouldWebSearch,
shouldKnowledgeSearch,
@ -330,25 +330,10 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
// 📚 知识库搜索工具配置
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
if (hasKnowledgeBase) {
if (knowledgeRecognition === 'off') {
// off 模式:直接添加知识库搜索工具,使用用户消息作为搜索关键词
const userMessage = userMessages[context.requestId]
const fallbackKeywords = {
question: [getMessageContent(userMessage) || 'search'],
rewrite: getMessageContent(userMessage) || 'search'
}
// logger.info('📚 Adding knowledge search tool (force mode)')
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
assistant,
fallbackKeywords,
getMessageContent(userMessage),
topicId
)
// params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' }
} else {
if (shouldKnowledgeSearch) {
// on 模式:根据意图识别结果决定是否添加工具
const needsKnowledgeSearch =
analysisResult?.knowledge &&
@ -366,7 +351,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
)
}
}
}
// 🧠 记忆搜索工具配置
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())

View File

@ -34,6 +34,10 @@ import {
getProviderByModel,
getQuickModel
} from './AssistantService'
import { ConversationService } from './ConversationService'
import { injectUserMessageWithKnowledgeSearchPrompt } from './KnowledgeService'
import type { BlockManager } from './messageStreaming'
import type { StreamProcessorCallbacks } from './StreamProcessingService'
// import { processKnowledgeSearch } from './KnowledgeService'
// import {
// filterContextMessages,
@ -79,6 +83,59 @@ export async function fetchMcpTools(assistant: Assistant) {
return mcpTools
}
/**
* LLM可以理解的格式并发送请求
* @param request -
* @param onChunkReceived -
*/
// 目前先按照函数来写,后续如果有需要到class的地方就改回来
export async function transformMessagesAndFetch(
request: {
messages: Message[]
assistant: Assistant
blockManager: BlockManager
assistantMsgId: string
callbacks: StreamProcessorCallbacks
topicId?: string // 添加 topicId 用于 trace
options: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string>
}
},
onChunkReceived: (chunk: Chunk) => void
) {
const { messages, assistant } = request
try {
const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
// replace prompt variables
assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name)
// inject knowledge search prompt into model messages
await injectUserMessageWithKnowledgeSearchPrompt({
modelMessages,
assistant,
assistantMsgId: request.assistantMsgId,
topicId: request.topicId,
blockManager: request.blockManager,
setCitationBlockId: request.callbacks.setCitationBlockId!
})
await fetchChatCompletion({
messages: modelMessages,
assistant: assistant,
topicId: request.topicId,
requestOptions: request.options,
uiMessages,
onChunkReceived
})
} catch (error: any) {
onChunkReceived({ type: ChunkType.ERROR, error })
}
}
export async function fetchChatCompletion({
messages,
prompt,

View File

@ -2,10 +2,13 @@ import { loggerService } from '@logger'
import type { Span } from '@opentelemetry/api'
import { ModernAiProvider } from '@renderer/aiCore'
import AiProvider from '@renderer/aiCore/legacy'
import { getMessageContent } from '@renderer/aiCore/plugins/searchOrchestrationPlugin'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import store from '@renderer/store'
import type { Assistant } from '@renderer/types'
import {
type FileMetadata,
type KnowledgeBase,
@ -16,13 +19,17 @@ import {
} from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
import { routeToEndpoint } from '@renderer/utils'
import type { ExtractResults } from '@renderer/utils/extract'
import { createCitationBlock } from '@renderer/utils/messageUtils/create'
import { isAzureOpenAIProvider, isGeminiProvider } from '@renderer/utils/provider'
import type { ModelMessage, UserModelMessage } from 'ai'
import { isEmpty } from 'lodash'
import { getProviderByModel } from './AssistantService'
import FileManager from './FileManager'
import type { BlockManager } from './messageStreaming'
const logger = loggerService.withContext('RendererKnowledgeService')
@ -338,3 +345,128 @@ export function processKnowledgeReferences(
}
}
}
export const injectUserMessageWithKnowledgeSearchPrompt = async ({
modelMessages,
assistant,
assistantMsgId,
topicId,
blockManager,
setCitationBlockId
}: {
modelMessages: ModelMessage[]
assistant: Assistant
assistantMsgId: string
topicId?: string
blockManager: BlockManager
setCitationBlockId: (blockId: string) => void
}) => {
if (assistant.knowledge_bases?.length && modelMessages.length > 0) {
const lastUserMessage = modelMessages[modelMessages.length - 1]
const isUserMessage = lastUserMessage.role === 'user'
if (!isUserMessage) {
return
}
const knowledgeReferences = await getKnowledgeReferences({
assistant,
lastUserMessage,
topicId: topicId
})
if (knowledgeReferences.length === 0) {
return
}
await createKnowledgeReferencesBlock({
assistantMsgId,
knowledgeReferences,
blockManager,
setCitationBlockId
})
const question = getMessageContent(lastUserMessage) || ''
const references = JSON.stringify(knowledgeReferences, null, 2)
const knowledgeSearchPrompt = REFERENCE_PROMPT.replace('{question}', question).replace('{references}', references)
if (typeof lastUserMessage.content === 'string') {
lastUserMessage.content = knowledgeSearchPrompt
} else if (Array.isArray(lastUserMessage.content)) {
const textPart = lastUserMessage.content.find((part) => part.type === 'text')
if (textPart) {
textPart.text = knowledgeSearchPrompt
} else {
lastUserMessage.content.push({
type: 'text',
text: knowledgeSearchPrompt
})
}
}
}
}
export const getKnowledgeReferences = async ({
assistant,
lastUserMessage,
topicId
}: {
assistant: Assistant
lastUserMessage: UserModelMessage
topicId?: string
}) => {
// 如果助手没有知识库,返回空字符串
if (!assistant || isEmpty(assistant.knowledge_bases)) {
return []
}
// 获取知识库ID
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
// 获取用户消息内容
const question = getMessageContent(lastUserMessage) || ''
// 获取知识库引用
const knowledgeReferences = await processKnowledgeSearch(
{
knowledge: {
question: [question],
rewrite: ''
}
},
knowledgeBaseIds,
topicId!
)
// 返回提示词
return knowledgeReferences
}
export const createKnowledgeReferencesBlock = async ({
assistantMsgId,
knowledgeReferences,
blockManager,
setCitationBlockId
}: {
assistantMsgId: string
knowledgeReferences: KnowledgeReference[]
blockManager: BlockManager
setCitationBlockId: (blockId: string) => void
}) => {
// 创建引用块
const citationBlock = createCitationBlock(
assistantMsgId,
{ knowledge: knowledgeReferences },
{ status: MessageBlockStatus.SUCCESS }
)
// 处理引用块
blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
// 设置引用块ID
setCitationBlockId(citationBlock.id)
// 返回引用块
return citationBlock
}

View File

@ -1,91 +0,0 @@
import type { Assistant, Message } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { replacePromptVariables } from '@renderer/utils/prompt'
import { fetchChatCompletion } from './ApiService'
import { ConversationService } from './ConversationService'
/**
* The request object for handling a user message.
*/
export interface OrchestrationRequest {
messages: Message[]
assistant: Assistant
options: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string>
}
topicId?: string // 添加 topicId 用于 trace
}
/**
* The OrchestrationService is responsible for orchestrating the different services
* to handle a user's message. It contains the core logic of the application.
*/
// NOTE暂时没有用到这个类
export class OrchestrationService {
constructor() {
// In the future, this could be a singleton, but for now, a new instance is fine.
// this.conversationService = new ConversationService()
}
/**
* This is the core method to handle user messages.
* It takes the message context and an events object for callbacks,
* and orchestrates the call to the LLM.
* The logic is moved from `messageThunk.ts`.
* @param request The orchestration request containing messages and assistant info.
* @param events A set of callbacks to report progress and results to the UI layer.
*/
async transformMessagesAndFetch(request: OrchestrationRequest, onChunkReceived: (chunk: Chunk) => void) {
const { messages, assistant } = request
try {
const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
await fetchChatCompletion({
messages: modelMessages,
assistant: assistant,
requestOptions: request.options,
onChunkReceived,
topicId: request.topicId,
uiMessages: uiMessages
})
} catch (error: any) {
onChunkReceived({ type: ChunkType.ERROR, error })
}
}
}
/**
* LLM可以理解的格式并发送请求
* @param request -
* @param onChunkReceived -
*/
// 目前先按照函数来写,后续如果有需要到class的地方就改回来
export async function transformMessagesAndFetch(
request: OrchestrationRequest,
onChunkReceived: (chunk: Chunk) => void
) {
const { messages, assistant } = request
try {
const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
// replace prompt variables
assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name)
await fetchChatCompletion({
messages: modelMessages,
assistant: assistant,
requestOptions: request.options,
onChunkReceived,
topicId: request.topicId,
uiMessages
})
} catch (error: any) {
onChunkReceived({ type: ChunkType.ERROR, error })
}
}

View File

@ -34,6 +34,10 @@ export interface StreamProcessorCallbacks {
onLLMWebSearchInProgress?: () => void
// LLM Web search complete
onLLMWebSearchComplete?: (llmWebSearchResult: WebSearchResponse) => void
// Get citation block ID
getCitationBlockId?: () => string | null
// Set citation block ID
setCitationBlockId?: (blockId: string) => void
// Image generation chunk received
onImageCreated?: () => void
onImageDelta?: (imageData: GenerateImageResponse) => void

View File

@ -121,6 +121,11 @@ export const createCitationCallbacks = (deps: CitationCallbacksDependencies) =>
},
// 暴露给外部的方法用于textCallbacks中获取citationBlockId
getCitationBlockId: () => citationBlockId
getCitationBlockId: () => citationBlockId,
// 暴露给外部的方法,用于 KnowledgeService 中设置 citationBlockId
setCitationBlockId: (blockId: string) => {
citationBlockId = blockId
}
}
}

View File

@ -2,12 +2,11 @@ import { loggerService } from '@logger'
import { AiSdkToChunkAdapter } from '@renderer/aiCore/chunk/AiSdkToChunkAdapter'
import { AgentApiClient } from '@renderer/api/agent'
import db from '@renderer/databases'
import { fetchMessagesSummary } from '@renderer/services/ApiService'
import { fetchMessagesSummary, transformMessagesAndFetch } from '@renderer/services/ApiService'
import { DbService } from '@renderer/services/db/DbService'
import FileManager from '@renderer/services/FileManager'
import { BlockManager } from '@renderer/services/messageStreaming/BlockManager'
import { createCallbacks } from '@renderer/services/messageStreaming/callbacks'
import { transformMessagesAndFetch } from '@renderer/services/OrchestrateService'
import { endSpan } from '@renderer/services/SpanManagerService'
import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService'
import store from '@renderer/store'
@ -814,6 +813,9 @@ const fetchAndProcessAssistantResponseImpl = async (
messages: messagesForContext,
assistant,
topicId,
blockManager,
assistantMsgId,
callbacks,
options: {
signal: abortController.signal,
timeout: 30000,