mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-05 12:29:44 +08:00
fix: support toolchoice for knowledge (#10763)
* fix: support toolchoice for knowledge * fix: ci
This commit is contained in:
parent
e46a45f409
commit
cafd40bc1c
@ -4,6 +4,8 @@ import type { MCPTool, Message, Model, Provider } from '@renderer/types'
|
|||||||
import type { Chunk } from '@renderer/types/chunk'
|
import type { Chunk } from '@renderer/types/chunk'
|
||||||
import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai'
|
import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai'
|
||||||
|
|
||||||
|
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
|
||||||
|
|
||||||
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
|
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -29,6 +31,8 @@ export interface AiSdkMiddlewareConfig {
|
|||||||
uiMessages?: Message[]
|
uiMessages?: Message[]
|
||||||
// 内置搜索配置
|
// 内置搜索配置
|
||||||
webSearchPluginConfig?: WebSearchPluginConfig
|
webSearchPluginConfig?: WebSearchPluginConfig
|
||||||
|
// 知识库识别开关,默认开启
|
||||||
|
knowledgeRecognition?: 'off' | 'on'
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -119,6 +123,15 @@ export class AiSdkMiddlewareBuilder {
|
|||||||
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] {
|
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelMiddleware[] {
|
||||||
const builder = new AiSdkMiddlewareBuilder()
|
const builder = new AiSdkMiddlewareBuilder()
|
||||||
|
|
||||||
|
// 0. 知识库强制调用中间件(必须在最前面,确保第一轮强制调用知识库)
|
||||||
|
if (config.knowledgeRecognition === 'off') {
|
||||||
|
builder.add({
|
||||||
|
name: 'force-knowledge-first',
|
||||||
|
middleware: toolChoiceMiddleware('builtin_knowledge_search')
|
||||||
|
})
|
||||||
|
logger.debug('Added toolChoice middleware to force knowledge base search on first round')
|
||||||
|
}
|
||||||
|
|
||||||
// 1. 根据provider添加特定中间件
|
// 1. 根据provider添加特定中间件
|
||||||
if (config.provider) {
|
if (config.provider) {
|
||||||
addProviderSpecificMiddlewares(builder, config)
|
addProviderSpecificMiddlewares(builder, config)
|
||||||
|
|||||||
45
src/renderer/src/aiCore/middleware/toolChoiceMiddleware.ts
Normal file
45
src/renderer/src/aiCore/middleware/toolChoiceMiddleware.ts
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { LanguageModelMiddleware } from 'ai'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('toolChoiceMiddleware')
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tool Choice Middleware
|
||||||
|
* Controls tool selection strategy across multiple rounds of tool calls:
|
||||||
|
* - First round: Forces the model to call a specific tool (e.g., knowledge base search)
|
||||||
|
* - Subsequent rounds: Allows the model to automatically choose any available tool
|
||||||
|
*
|
||||||
|
* This ensures knowledge base is consulted first while still enabling MCP tools
|
||||||
|
* and other capabilities in follow-up interactions.
|
||||||
|
*
|
||||||
|
* @param forceFirstToolName - The tool name to force on the first round
|
||||||
|
* @returns LanguageModelMiddleware
|
||||||
|
*/
|
||||||
|
export function toolChoiceMiddleware(forceFirstToolName: string): LanguageModelMiddleware {
|
||||||
|
let toolCallRound = 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
middlewareVersion: 'v2',
|
||||||
|
|
||||||
|
transformParams: async ({ params }) => {
|
||||||
|
toolCallRound++
|
||||||
|
|
||||||
|
const transformedParams = { ...params }
|
||||||
|
|
||||||
|
if (toolCallRound === 1) {
|
||||||
|
// First round: force the specified tool
|
||||||
|
logger.debug(`Round ${toolCallRound}: Forcing tool choice to '${forceFirstToolName}'`)
|
||||||
|
transformedParams.toolChoice = {
|
||||||
|
type: 'tool',
|
||||||
|
toolName: forceFirstToolName
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Subsequent rounds: allow automatic tool selection
|
||||||
|
logger.debug(`Round ${toolCallRound}: Using automatic tool choice`)
|
||||||
|
transformedParams.toolChoice = { type: 'auto' }
|
||||||
|
}
|
||||||
|
|
||||||
|
return transformedParams
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -138,7 +138,8 @@ export async function fetchChatCompletion({
|
|||||||
enableGenerateImage: capabilities.enableGenerateImage,
|
enableGenerateImage: capabilities.enableGenerateImage,
|
||||||
enableUrlContext: capabilities.enableUrlContext,
|
enableUrlContext: capabilities.enableUrlContext,
|
||||||
mcpTools,
|
mcpTools,
|
||||||
uiMessages
|
uiMessages,
|
||||||
|
knowledgeRecognition: assistant.knowledgeRecognition
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Call AI Completions ---
|
// --- Call AI Completions ---
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user