feat(knowledge): use prompt injection for forced knowledge base search
Change the default knowledge base retrieval behavior from tool call to prompt
injection mode. This provides faster response times when knowledge base search
is forced. Intent recognition mode (tool call) is still available as an opt-in
option.

- Remove toolChoiceMiddleware for forced knowledge base search
- Add prompt injection for knowledge base references in KnowledgeService
- Move transformMessagesAndFetch to ApiService, delete OrchestrateService
- Export getMessageContent from searchOrchestrationPlugin
- Add setCitationBlockId callback to citationCallbacks
- Default knowledgeRecognition to 'off' (prompt mode)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
parent c676a93595
commit ef25eef0eb
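The trade-off behind this commit, as a minimal standalone sketch: forcing a tool call costs two model round-trips (the forced tool call, then a follow-up completion over the tool results), while prompt injection retrieves references first and answers in a single call. Here searchKnowledge, callModel, and the message shapes are illustrative placeholders, not Cherry Studio or AI SDK APIs:

type ChatMessage = { role: 'user' | 'tool'; content: string }
declare function searchKnowledge(question: string): Promise<unknown[]>
declare function callModel(messages: ChatMessage[], options?: { toolChoice?: unknown }): Promise<string>

// Old default: the first round is forced to emit a knowledge-search tool call,
// and a second completion is needed to turn the tool results into an answer.
async function answerWithForcedToolCall(question: string): Promise<string> {
  await callModel([{ role: 'user', content: question }], {
    toolChoice: { type: 'tool', toolName: 'builtin_knowledge_search' }
  })
  const references = await searchKnowledge(question)
  return callModel([
    { role: 'user', content: question },
    { role: 'tool', content: JSON.stringify(references) }
  ])
}

// New default: search the knowledge base up front and inject the references
// into the user message, so a single completion produces the answer.
async function answerWithPromptInjection(question: string): Promise<string> {
  const references = await searchKnowledge(question)
  const prompt = `Question: ${question}\n\nReferences:\n${JSON.stringify(references, null, 2)}`
  return callModel([{ role: 'user', content: prompt }])
}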
@@ -7,7 +7,6 @@ import type { Chunk } from '@renderer/types/chunk'
 import { isOllamaProvider, isSupportEnableThinkingProvider } from '@renderer/utils/provider'
 import type { LanguageModelMiddleware } from 'ai'
 import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
-import { isEmpty } from 'lodash'
 
 import { getAiSdkProviderId } from '../provider/factory'
 import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
@@ -16,7 +15,6 @@ import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMidd
 import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
 import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
 import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
-import { toolChoiceMiddleware } from './toolChoiceMiddleware'
 
 const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
 
@@ -136,15 +134,6 @@ export class AiSdkMiddlewareBuilder {
 export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] {
   const builder = new AiSdkMiddlewareBuilder()
 
-  // 0. Forced knowledge base middleware (must come first so the first round is forced to call the knowledge base)
-  if (!isEmpty(config.assistant?.knowledge_bases?.map((base) => base.id)) && config.knowledgeRecognition !== 'on') {
-    builder.add({
-      name: 'force-knowledge-first',
-      middleware: toolChoiceMiddleware('builtin_knowledge_search')
-    })
-    logger.debug('Added toolChoice middleware to force knowledge base search on first round')
-  }
-
   // 1. Add provider-specific middlewares
   if (config.provider) {
     addProviderSpecificMiddlewares(builder, config)
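The removed middleware's implementation is not part of this diff; a hypothetical sketch of what toolChoiceMiddleware('builtin_knowledge_search') plausibly did, assuming the AI SDK's transformParams middleware hook and the toolChoice shape visible in a commented-out line elsewhere in this commit:

// Hypothetical sketch only; the real implementation lives in ./toolChoiceMiddleware.
import type { LanguageModelMiddleware } from 'ai'

export const toolChoiceMiddlewareSketch = (toolName: string): LanguageModelMiddleware => ({
  transformParams: async ({ params }) => ({
    ...params,
    // Force the model to call the named tool on the next generation round.
    toolChoice: { type: 'tool', toolName }
  })
})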
@@ -31,7 +31,7 @@ import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
 
 const logger = loggerService.withContext('SearchOrchestrationPlugin')
 
-const getMessageContent = (message: ModelMessage) => {
+export const getMessageContent = (message: ModelMessage) => {
   if (typeof message.content === 'string') return message.content
   return message.content.reduce((acc, part) => {
     if (part.type === 'text') {
@@ -266,14 +266,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
       // Decide which kinds of search are needed
       const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
       const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
-      const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
+      const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
       const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
       const shouldWebSearch = !!assistant.webSearchProviderId
       const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
       const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
 
       // Run intent analysis
-      if (shouldWebSearch || hasKnowledgeBase) {
+      if (shouldWebSearch || shouldKnowledgeSearch) {
         const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, {
           shouldWebSearch,
           shouldKnowledgeSearch,
@@ -330,41 +330,25 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
       // 📚 Knowledge base search tool configuration
       const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
       const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
-      const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
+      const knowledgeRecognition = assistant.knowledgeRecognition || 'off'
+      const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
 
-      if (hasKnowledgeBase) {
-        if (knowledgeRecognition === 'off') {
-          // off mode: add the knowledge search tool directly, using the user message as the search keywords
+      if (shouldKnowledgeSearch) {
+        // on mode: add the tool only when intent recognition says it is needed
+        const needsKnowledgeSearch =
+          analysisResult?.knowledge &&
+          analysisResult.knowledge.question &&
+          analysisResult.knowledge.question[0] !== 'not_needed'
+
+        if (needsKnowledgeSearch && analysisResult.knowledge) {
+          // logger.info('📚 Adding knowledge search tool (intent-based)')
           const userMessage = userMessages[context.requestId]
-          const fallbackKeywords = {
-            question: [getMessageContent(userMessage) || 'search'],
-            rewrite: getMessageContent(userMessage) || 'search'
-          }
-          // logger.info('📚 Adding knowledge search tool (force mode)')
           params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
             assistant,
-            fallbackKeywords,
+            analysisResult.knowledge,
             getMessageContent(userMessage),
             topicId
           )
-          // params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' }
-        } else {
-          // on mode: add the tool only when intent recognition says it is needed
-          const needsKnowledgeSearch =
-            analysisResult?.knowledge &&
-            analysisResult.knowledge.question &&
-            analysisResult.knowledge.question[0] !== 'not_needed'
-
-          if (needsKnowledgeSearch && analysisResult.knowledge) {
-            // logger.info('📚 Adding knowledge search tool (intent-based)')
-            const userMessage = userMessages[context.requestId]
-            params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
-              assistant,
-              analysisResult.knowledge,
-              getMessageContent(userMessage),
-              topicId
-            )
-          }
-        }
       }
 
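Usage sketch for the newly exported getMessageContent helper (values illustrative): it returns string content as-is and, judging from the visible body, concatenates the text parts of multi-part content while skipping non-text parts.

import { getMessageContent } from '@renderer/aiCore/plugins/searchOrchestrationPlugin'

const plain = getMessageContent({ role: 'user', content: 'What is RAG?' })
// => 'What is RAG?'

const multiPart = getMessageContent({
  role: 'user',
  content: [
    { type: 'text', text: 'Summarize ' },
    { type: 'text', text: 'this document.' }
  ]
})
// => 'Summarize this document.'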
@@ -34,6 +34,10 @@ import {
   getProviderByModel,
   getQuickModel
 } from './AssistantService'
+import { ConversationService } from './ConversationService'
+import { injectUserMessageWithKnowledgeSearchPrompt } from './KnowledgeService'
+import type { BlockManager } from './messageStreaming'
+import type { StreamProcessorCallbacks } from './StreamProcessingService'
 // import { processKnowledgeSearch } from './KnowledgeService'
 // import {
 //   filterContextMessages,
@@ -79,6 +83,59 @@ export async function fetchMcpTools(assistant: Assistant) {
   return mcpTools
 }
 
+/**
+ * Transform user messages into a format the LLM can understand and send the request
+ * @param request - Request object containing the message content and assistant info
+ * @param onChunkReceived - Callback that receives streamed response chunks
+ */
+// Written as a function for now; switch back to a class if one is ever needed
+export async function transformMessagesAndFetch(
+  request: {
+    messages: Message[]
+    assistant: Assistant
+    blockManager: BlockManager
+    assistantMsgId: string
+    callbacks: StreamProcessorCallbacks
+    topicId?: string // topicId is used for tracing
+    options: {
+      signal?: AbortSignal
+      timeout?: number
+      headers?: Record<string, string>
+    }
+  },
+  onChunkReceived: (chunk: Chunk) => void
+) {
+  const { messages, assistant } = request
+
+  try {
+    const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
+
+    // replace prompt variables
+    assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name)
+
+    // inject knowledge search prompt into model messages
+    await injectUserMessageWithKnowledgeSearchPrompt({
+      modelMessages,
+      assistant,
+      assistantMsgId: request.assistantMsgId,
+      topicId: request.topicId,
+      blockManager: request.blockManager,
+      setCitationBlockId: request.callbacks.setCitationBlockId!
+    })
+
+    await fetchChatCompletion({
+      messages: modelMessages,
+      assistant: assistant,
+      topicId: request.topicId,
+      requestOptions: request.options,
+      uiMessages,
+      onChunkReceived
+    })
+  } catch (error: any) {
+    onChunkReceived({ type: ChunkType.ERROR, error })
+  }
+}
+
 export async function fetchChatCompletion({
   messages,
   prompt,
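For reference, this is how the renderer invokes the relocated function; the shape mirrors the messageThunk hunk at the bottom of this commit (surrounding variables abbreviated, and the chunk-handler name here is illustrative):

await transformMessagesAndFetch(
  {
    messages: messagesForContext,
    assistant,
    topicId,
    blockManager,
    assistantMsgId,
    callbacks,
    options: { signal: abortController.signal, timeout: 30000 }
  },
  onChunkReceived // the stream processor's chunk handler
)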
@@ -2,10 +2,13 @@ import { loggerService } from '@logger'
 import type { Span } from '@opentelemetry/api'
 import { ModernAiProvider } from '@renderer/aiCore'
 import AiProvider from '@renderer/aiCore/legacy'
+import { getMessageContent } from '@renderer/aiCore/plugins/searchOrchestrationPlugin'
 import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
 import { getEmbeddingMaxContext } from '@renderer/config/embedings'
+import { REFERENCE_PROMPT } from '@renderer/config/prompts'
 import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
 import store from '@renderer/store'
+import type { Assistant } from '@renderer/types'
 import {
   type FileMetadata,
   type KnowledgeBase,
@@ -16,13 +19,17 @@ import {
 } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
 import { ChunkType } from '@renderer/types/chunk'
+import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
 import { routeToEndpoint } from '@renderer/utils'
 import type { ExtractResults } from '@renderer/utils/extract'
+import { createCitationBlock } from '@renderer/utils/messageUtils/create'
 import { isAzureOpenAIProvider, isGeminiProvider } from '@renderer/utils/provider'
+import type { ModelMessage, UserModelMessage } from 'ai'
 import { isEmpty } from 'lodash'
 
 import { getProviderByModel } from './AssistantService'
 import FileManager from './FileManager'
+import type { BlockManager } from './messageStreaming'
 
 const logger = loggerService.withContext('RendererKnowledgeService')
 
@@ -338,3 +345,128 @@ export function processKnowledgeReferences(
     }
   }
 }
+
+export const injectUserMessageWithKnowledgeSearchPrompt = async ({
+  modelMessages,
+  assistant,
+  assistantMsgId,
+  topicId,
+  blockManager,
+  setCitationBlockId
+}: {
+  modelMessages: ModelMessage[]
+  assistant: Assistant
+  assistantMsgId: string
+  topicId?: string
+  blockManager: BlockManager
+  setCitationBlockId: (blockId: string) => void
+}) => {
+  if (assistant.knowledge_bases?.length && modelMessages.length > 0) {
+    const lastUserMessage = modelMessages[modelMessages.length - 1]
+    const isUserMessage = lastUserMessage.role === 'user'
+
+    if (!isUserMessage) {
+      return
+    }
+
+    const knowledgeReferences = await getKnowledgeReferences({
+      assistant,
+      lastUserMessage,
+      topicId: topicId
+    })
+
+    if (knowledgeReferences.length === 0) {
+      return
+    }
+
+    await createKnowledgeReferencesBlock({
+      assistantMsgId,
+      knowledgeReferences,
+      blockManager,
+      setCitationBlockId
+    })
+
+    const question = getMessageContent(lastUserMessage) || ''
+    const references = JSON.stringify(knowledgeReferences, null, 2)
+
+    const knowledgeSearchPrompt = REFERENCE_PROMPT.replace('{question}', question).replace('{references}', references)
+
+    if (typeof lastUserMessage.content === 'string') {
+      lastUserMessage.content = knowledgeSearchPrompt
+    } else if (Array.isArray(lastUserMessage.content)) {
+      const textPart = lastUserMessage.content.find((part) => part.type === 'text')
+      if (textPart) {
+        textPart.text = knowledgeSearchPrompt
+      } else {
+        lastUserMessage.content.push({
+          type: 'text',
+          text: knowledgeSearchPrompt
+        })
+      }
+    }
+  }
+}
+
+export const getKnowledgeReferences = async ({
+  assistant,
+  lastUserMessage,
+  topicId
+}: {
+  assistant: Assistant
+  lastUserMessage: UserModelMessage
+  topicId?: string
+}) => {
+  // If the assistant has no knowledge bases, return an empty array
+  if (!assistant || isEmpty(assistant.knowledge_bases)) {
+    return []
+  }
+
+  // Get the knowledge base IDs
+  const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
+
+  // Get the user message content
+  const question = getMessageContent(lastUserMessage) || ''
+
+  // Retrieve the knowledge base references
+  const knowledgeReferences = await processKnowledgeSearch(
+    {
+      knowledge: {
+        question: [question],
+        rewrite: ''
+      }
+    },
+    knowledgeBaseIds,
+    topicId!
+  )
+
+  // Return the references
+  return knowledgeReferences
+}
+
+export const createKnowledgeReferencesBlock = async ({
+  assistantMsgId,
+  knowledgeReferences,
+  blockManager,
+  setCitationBlockId
+}: {
+  assistantMsgId: string
+  knowledgeReferences: KnowledgeReference[]
+  blockManager: BlockManager
+  setCitationBlockId: (blockId: string) => void
+}) => {
+  // Create the citation block
  const citationBlock = createCitationBlock(
+    assistantMsgId,
+    { knowledge: knowledgeReferences },
+    { status: MessageBlockStatus.SUCCESS }
+  )
+
+  // Hand the block to the block manager
+  blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
+
+  // Set the citation block ID
+  setCitationBlockId(citationBlock.id)
+
+  // Return the citation block
+  return citationBlock
+}
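Illustrative only: REFERENCE_PROMPT is defined in @renderer/config/prompts and, as used above, contains {question} and {references} placeholders; its exact wording and the reference shape below are not part of this diff.

const question = 'How do I rotate my API key?'
const references = JSON.stringify(
  [{ id: 1, content: 'Keys can be rotated under Settings > API.', sourceUrl: 'kb://doc/42' }],
  null,
  2
)
const injected = REFERENCE_PROMPT.replace('{question}', question).replace('{references}', references)
// `injected` replaces the text of the last user message, so the model sees the
// retrieved references inline instead of being forced through a tool call.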
@@ -1,91 +0,0 @@
-import type { Assistant, Message } from '@renderer/types'
-import type { Chunk } from '@renderer/types/chunk'
-import { ChunkType } from '@renderer/types/chunk'
-import { replacePromptVariables } from '@renderer/utils/prompt'
-
-import { fetchChatCompletion } from './ApiService'
-import { ConversationService } from './ConversationService'
-
-/**
- * The request object for handling a user message.
- */
-export interface OrchestrationRequest {
-  messages: Message[]
-  assistant: Assistant
-  options: {
-    signal?: AbortSignal
-    timeout?: number
-    headers?: Record<string, string>
-  }
-  topicId?: string // topicId is used for tracing
-}
-
-/**
- * The OrchestrationService is responsible for orchestrating the different services
- * to handle a user's message. It contains the core logic of the application.
- */
-// NOTE: this class is currently unused
-export class OrchestrationService {
-  constructor() {
-    // In the future, this could be a singleton, but for now, a new instance is fine.
-    // this.conversationService = new ConversationService()
-  }
-
-  /**
-   * This is the core method to handle user messages.
-   * It takes the message context and an events object for callbacks,
-   * and orchestrates the call to the LLM.
-   * The logic is moved from `messageThunk.ts`.
-   * @param request The orchestration request containing messages and assistant info.
-   * @param events A set of callbacks to report progress and results to the UI layer.
-   */
-  async transformMessagesAndFetch(request: OrchestrationRequest, onChunkReceived: (chunk: Chunk) => void) {
-    const { messages, assistant } = request
-
-    try {
-      const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
-
-      await fetchChatCompletion({
-        messages: modelMessages,
-        assistant: assistant,
-        requestOptions: request.options,
-        onChunkReceived,
-        topicId: request.topicId,
-        uiMessages: uiMessages
-      })
-    } catch (error: any) {
-      onChunkReceived({ type: ChunkType.ERROR, error })
-    }
-  }
-}
-
-/**
- * Transform user messages into a format the LLM can understand and send the request
- * @param request - Request object containing the message content and assistant info
- * @param onChunkReceived - Callback that receives streamed response chunks
- */
-// Written as a function for now; switch back to a class if one is ever needed
-export async function transformMessagesAndFetch(
-  request: OrchestrationRequest,
-  onChunkReceived: (chunk: Chunk) => void
-) {
-  const { messages, assistant } = request
-
-  try {
-    const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
-
-    // replace prompt variables
-    assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name)
-
-    await fetchChatCompletion({
-      messages: modelMessages,
-      assistant: assistant,
-      requestOptions: request.options,
-      onChunkReceived,
-      topicId: request.topicId,
-      uiMessages
-    })
-  } catch (error: any) {
-    onChunkReceived({ type: ChunkType.ERROR, error })
-  }
-}
@@ -34,6 +34,10 @@ export interface StreamProcessorCallbacks {
   onLLMWebSearchInProgress?: () => void
   // LLM Web search complete
   onLLMWebSearchComplete?: (llmWebSearchResult: WebSearchResponse) => void
+  // Get citation block ID
+  getCitationBlockId?: () => string | null
+  // Set citation block ID
+  setCitationBlockId?: (blockId: string) => void
   // Image generation chunk received
   onImageCreated?: () => void
   onImageDelta?: (imageData: GenerateImageResponse) => void
@@ -121,6 +121,11 @@ export const createCitationCallbacks = (deps: CitationCallbacksDependencies) =>
     },
 
     // Exposed so textCallbacks can read the citationBlockId
-    getCitationBlockId: () => citationBlockId
+    getCitationBlockId: () => citationBlockId,
+
+    // Exposed so KnowledgeService can set the citationBlockId
+    setCitationBlockId: (blockId: string) => {
+      citationBlockId = blockId
+    }
   }
 }
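The get/set pair shares one closure variable, which is what lets KnowledgeService register the citation block before streaming starts and textCallbacks read it later. A minimal sketch of the same pattern (simplified; not the actual createCitationCallbacks):

const makeCitationState = () => {
  let citationBlockId: string | null = null
  return {
    getCitationBlockId: () => citationBlockId,
    setCitationBlockId: (blockId: string) => {
      citationBlockId = blockId
    }
  }
}

const state = makeCitationState()
state.setCitationBlockId('block-123') // e.g. called from createKnowledgeReferencesBlock
state.getCitationBlockId() // => 'block-123'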
@@ -2,12 +2,11 @@ import { loggerService } from '@logger'
 import { AiSdkToChunkAdapter } from '@renderer/aiCore/chunk/AiSdkToChunkAdapter'
 import { AgentApiClient } from '@renderer/api/agent'
 import db from '@renderer/databases'
-import { fetchMessagesSummary } from '@renderer/services/ApiService'
+import { fetchMessagesSummary, transformMessagesAndFetch } from '@renderer/services/ApiService'
 import { DbService } from '@renderer/services/db/DbService'
 import FileManager from '@renderer/services/FileManager'
 import { BlockManager } from '@renderer/services/messageStreaming/BlockManager'
 import { createCallbacks } from '@renderer/services/messageStreaming/callbacks'
-import { transformMessagesAndFetch } from '@renderer/services/OrchestrateService'
 import { endSpan } from '@renderer/services/SpanManagerService'
 import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService'
 import store from '@renderer/store'
@@ -814,6 +813,9 @@ const fetchAndProcessAssistantResponseImpl = async (
       messages: messagesForContext,
       assistant,
       topicId,
+      blockManager,
+      assistantMsgId,
+      callbacks,
       options: {
         signal: abortController.signal,
         timeout: 30000,