mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
<type>: <subject>
<body> <footer> 用來簡要描述影響本次變動,概述即可
This commit is contained in:
parent
addd5ffdfa
commit
4b62384fc5
@ -42,6 +42,15 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
})
|
||||
}
|
||||
|
||||
createConfigureContextPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_configureContext',
|
||||
configureContext: async (context: AiRequestContext) => {
|
||||
context.executor = this
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// === 高阶重载:直接使用模型 ===
|
||||
|
||||
/**
|
||||
@ -51,10 +60,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof streamText>[0], 'model'>
|
||||
): Promise<ReturnType<typeof streamText>>
|
||||
|
||||
/**
|
||||
* 流式文本生成 - 使用modelId + 可选middleware(灵活用法)
|
||||
*/
|
||||
async streamText(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||
@ -62,10 +67,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>>
|
||||
|
||||
/**
|
||||
* 流式文本生成 - 内部实现(统一处理重载)
|
||||
*/
|
||||
async streamText(
|
||||
modelOrId: LanguageModel,
|
||||
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||
@ -73,7 +74,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>> {
|
||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
|
||||
// 2. 执行插件处理
|
||||
return this.pluginEngine.executeStreamWithPlugins(
|
||||
@ -102,10 +106,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateText>>
|
||||
|
||||
/**
|
||||
* 生成文本 - 使用modelId + 可选middleware
|
||||
*/
|
||||
async generateText(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||
@ -113,10 +113,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>>
|
||||
|
||||
/**
|
||||
* 生成文本 - 内部实现
|
||||
*/
|
||||
async generateText(
|
||||
modelOrId: LanguageModel | string,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||
@ -124,7 +120,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>> {
|
||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateText',
|
||||
@ -143,10 +142,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateObject>>
|
||||
|
||||
/**
|
||||
* 生成结构化对象 - 使用modelId + 可选middleware
|
||||
*/
|
||||
async generateObject(
|
||||
modelOrId: string,
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||
@ -154,10 +149,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>>
|
||||
|
||||
/**
|
||||
* 生成结构化对象 - 内部实现
|
||||
*/
|
||||
async generateObject(
|
||||
modelOrId: LanguageModel | string,
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||
@ -165,7 +156,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>> {
|
||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateObject',
|
||||
@ -184,10 +178,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
): Promise<ReturnType<typeof streamObject>>
|
||||
|
||||
/**
|
||||
* 流式生成结构化对象 - 使用modelId + 可选middleware
|
||||
*/
|
||||
async streamObject(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||
@ -195,10 +185,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>>
|
||||
|
||||
/**
|
||||
* 流式生成结构化对象 - 内部实现
|
||||
*/
|
||||
async streamObject(
|
||||
modelOrId: LanguageModel | string,
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||
@ -206,7 +192,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>> {
|
||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'streamObject',
|
||||
|
||||
@ -14,7 +14,13 @@ import type { ProviderSettingsMap } from './core/providers/types'
|
||||
import { createExecutor } from './core/runtime'
|
||||
|
||||
// ==================== 主要用户接口 ====================
|
||||
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
|
||||
export {
|
||||
createExecutor,
|
||||
createOpenAICompatibleExecutor,
|
||||
generateObject,
|
||||
generateText,
|
||||
streamText
|
||||
} from './core/runtime'
|
||||
|
||||
// ==================== 高级API ====================
|
||||
export { createModel } from './core/models'
|
||||
@ -139,11 +145,6 @@ export const AiCore = {
|
||||
return createExecutor(providerId, options, plugins)
|
||||
},
|
||||
|
||||
// 创建底层客户端(高级用法)
|
||||
createClient(providerId: ProviderId, options: ProviderSettingsMap[ProviderId], plugins: any[] = []) {
|
||||
return createExecutor(providerId, options, plugins)
|
||||
},
|
||||
|
||||
// 获取支持的providers
|
||||
getSupportedProviders() {
|
||||
return factoryGetSupportedProviders()
|
||||
|
||||
@ -153,11 +153,11 @@ export class AiSdkToChunkAdapter {
|
||||
break
|
||||
|
||||
// === 步骤相关事件 ===
|
||||
case 'start':
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_CREATED
|
||||
})
|
||||
break
|
||||
// case 'start':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
// TODO: 需要区分接口开始和步骤开始
|
||||
// case 'start-step':
|
||||
// this.onChunk({
|
||||
|
||||
@ -21,7 +21,7 @@ import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-cor
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep } from 'lodash'
|
||||
|
||||
@ -30,6 +30,7 @@ import LegacyAiProvider from './index'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsResult } from './middleware/schemas'
|
||||
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
|
||||
import { searchOrchestrationPlugin } from './plugins/searchOrchestrationPlugin'
|
||||
import { createAihubmixProvider } from './provider/aihubmix'
|
||||
import { getAiSdkProviderId } from './provider/factory'
|
||||
|
||||
@ -162,6 +163,7 @@ export default class ModernAiProvider {
|
||||
// 内置了默认搜索参数,如果改的话可以传config进去
|
||||
plugins.push(webSearchPlugin())
|
||||
}
|
||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant as Assistant))
|
||||
|
||||
// 2. 推理模型时添加推理插件
|
||||
if (middlewareConfig.enableReasoning) {
|
||||
|
||||
@ -3,7 +3,7 @@ import {
|
||||
LanguageModelV2Middleware,
|
||||
simulateStreamingMiddleware
|
||||
} from '@cherrystudio/ai-core'
|
||||
import type { BaseTool, Model, Provider } from '@renderer/types'
|
||||
import type { Assistant, BaseTool, Model, Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
/**
|
||||
@ -19,6 +19,8 @@ export interface AiSdkMiddlewareConfig {
|
||||
enableTool?: boolean
|
||||
enableWebSearch?: boolean
|
||||
mcpTools?: BaseTool[]
|
||||
// TODO assistant
|
||||
assistant?: Assistant
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
376
src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts
Normal file
376
src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts
Normal file
@ -0,0 +1,376 @@
|
||||
/**
|
||||
* 搜索编排插件
|
||||
*
|
||||
* 功能:
|
||||
* 1. onRequestStart: 智能意图识别 - 分析是否需要网络搜索、知识库搜索、记忆搜索
|
||||
* 2. transformParams: 根据意图分析结果动态添加对应的工具
|
||||
* 3. onRequestEnd: 自动记忆存储
|
||||
*/
|
||||
import type { AiRequestContext } from '@cherrystudio/ai-core'
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
import { RuntimeExecutor } from '@cherrystudio/ai-core/core/runtime/executor'
|
||||
// import { generateObject } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
SEARCH_SUMMARY_PROMPT,
|
||||
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
|
||||
SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
} from '@renderer/config/prompts'
|
||||
import { getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import store from '@renderer/store'
|
||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||
import type { Assistant, Message } from '@renderer/types'
|
||||
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { isEmpty } from 'lodash'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
import { memorySearchTool } from '../tools/MemorySearchTool'
|
||||
import { webSearchTool } from '../tools/WebSearchTool'
|
||||
|
||||
// const getMessageContent = (message: Message) => {
|
||||
// if (typeof message.content === 'string') return message.content
|
||||
// return message.content.reduce((acc, part) => {
|
||||
// if (part.type === 'text') {
|
||||
// return acc + part.text + '\n'
|
||||
// }
|
||||
// return acc
|
||||
// }, '')
|
||||
// }
|
||||
|
||||
// === Schema Definitions ===
|
||||
|
||||
const WebSearchSchema = z.object({
|
||||
question: z
|
||||
.array(z.string())
|
||||
.describe('Search queries for web search. Use "not_needed" if no web search is required.'),
|
||||
links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.')
|
||||
})
|
||||
|
||||
const KnowledgeSearchSchema = z.object({
|
||||
question: z
|
||||
.array(z.string())
|
||||
.describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'),
|
||||
rewrite: z
|
||||
.string()
|
||||
.describe('Rewritten query with alternative phrasing while preserving original intent and meaning.')
|
||||
})
|
||||
|
||||
const SearchIntentAnalysisSchema = z.object({
|
||||
websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'),
|
||||
knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.')
|
||||
})
|
||||
|
||||
type SearchIntentResult = z.infer<typeof SearchIntentAnalysisSchema>
|
||||
|
||||
/**
|
||||
* 🧠 意图分析函数 - 使用结构化输出重构
|
||||
*/
|
||||
async function analyzeSearchIntent(
|
||||
lastUserMessage: Message,
|
||||
assistant: Assistant,
|
||||
options: {
|
||||
shouldWebSearch?: boolean
|
||||
shouldKnowledgeSearch?: boolean
|
||||
shouldMemorySearch?: boolean
|
||||
lastAnswer?: Message
|
||||
context?:
|
||||
| AiRequestContext
|
||||
| {
|
||||
executor: RuntimeExecutor
|
||||
}
|
||||
} = {}
|
||||
): Promise<SearchIntentResult | undefined> {
|
||||
const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer, context } = options
|
||||
|
||||
if (!lastUserMessage) return undefined
|
||||
|
||||
// 根据配置决定是否需要提取
|
||||
const needWebExtract = shouldWebSearch
|
||||
const needKnowledgeExtract = shouldKnowledgeSearch
|
||||
|
||||
if (!needWebExtract && !needKnowledgeExtract) return undefined
|
||||
|
||||
// 选择合适的提示词和schema
|
||||
let prompt: string
|
||||
let schema: z.Schema
|
||||
|
||||
if (needWebExtract && !needKnowledgeExtract) {
|
||||
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
schema = z.object({ websearch: WebSearchSchema })
|
||||
} else if (!needWebExtract && needKnowledgeExtract) {
|
||||
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
||||
schema = z.object({ knowledge: KnowledgeSearchSchema })
|
||||
} else {
|
||||
prompt = SEARCH_SUMMARY_PROMPT
|
||||
schema = SearchIntentAnalysisSchema
|
||||
}
|
||||
|
||||
// 构建消息上下文
|
||||
const messages = lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage]
|
||||
console.log('messagesmessagesmessagesmessagesmessagesmessagesmessages', messages)
|
||||
// 格式化消息为提示词期望的格式
|
||||
// const chatHistory =
|
||||
// messages.length > 1
|
||||
// ? messages
|
||||
// .slice(0, -1)
|
||||
// .map((msg) => `${msg.role}: ${getMainTextContent(msg)}`)
|
||||
// .join('\n')
|
||||
// : ''
|
||||
// const question = getMainTextContent(lastUserMessage) || ''
|
||||
|
||||
// // 使用模板替换变量
|
||||
// const formattedPrompt = prompt.replace('{chat_history}', chatHistory).replace('{question}', question)
|
||||
|
||||
// 获取模型和provider信息
|
||||
const model = assistant.model || getDefaultModel()
|
||||
const provider = getProviderByModel(model)
|
||||
|
||||
if (!provider || isEmpty(provider.apiKey)) {
|
||||
console.error('Provider not found or missing API key')
|
||||
return getFallbackResult()
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await context?.executor?.generateObject(model.id, { schema, prompt })
|
||||
console.log('result', context)
|
||||
const parsedResult = result?.object as SearchIntentResult
|
||||
|
||||
// 根据需求过滤结果
|
||||
return {
|
||||
websearch: needWebExtract ? parsedResult?.websearch : undefined,
|
||||
knowledge: needKnowledgeExtract ? parsedResult?.knowledge : undefined
|
||||
}
|
||||
} catch (e: any) {
|
||||
console.error('analyze search intent error', e)
|
||||
return getFallbackResult()
|
||||
}
|
||||
|
||||
function getFallbackResult(): SearchIntentResult {
|
||||
const fallbackContent = getMainTextContent(lastUserMessage)
|
||||
return {
|
||||
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
||||
knowledge: shouldKnowledgeSearch
|
||||
? {
|
||||
question: [fallbackContent || 'search'],
|
||||
rewrite: fallbackContent || 'search'
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 🧠 记忆存储函数 - 基于注释代码中的 processConversationMemory
|
||||
*/
|
||||
async function storeConversationMemory(messages: Message[], assistant: Assistant): Promise<void> {
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
|
||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||
console.log('Memory storage is disabled')
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
|
||||
// 转换消息为记忆处理器期望的格式
|
||||
const conversationMessages = messages
|
||||
.filter((msg) => msg.role === 'user' || msg.role === 'assistant')
|
||||
.map((msg) => ({
|
||||
role: msg.role as 'user' | 'assistant',
|
||||
content: getMainTextContent(msg) || ''
|
||||
}))
|
||||
.filter((msg) => msg.content.trim().length > 0)
|
||||
|
||||
if (conversationMessages.length < 2) {
|
||||
console.log('Need at least a user message and assistant response for memory processing')
|
||||
return
|
||||
}
|
||||
|
||||
const currentUserId = selectCurrentUserId(store.getState())
|
||||
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(
|
||||
memoryConfig,
|
||||
assistant.id,
|
||||
currentUserId,
|
||||
lastUserMessage?.id
|
||||
)
|
||||
|
||||
console.log('Processing conversation memory...', { messageCount: conversationMessages.length })
|
||||
|
||||
// 后台处理对话记忆(不阻塞 UI)
|
||||
const memoryProcessor = new MemoryProcessor()
|
||||
memoryProcessor
|
||||
.processConversation(conversationMessages, processorConfig)
|
||||
.then((result) => {
|
||||
console.log('Memory processing completed:', result)
|
||||
if (result.facts?.length > 0) {
|
||||
console.log('Extracted facts from conversation:', result.facts)
|
||||
console.log('Memory operations performed:', result.operations)
|
||||
} else {
|
||||
console.log('No facts extracted from conversation')
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Background memory processing failed:', error)
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('Error in conversation memory processing:', error)
|
||||
// 不抛出错误,避免影响主流程
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 🎯 搜索编排插件
|
||||
*/
|
||||
export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
// 存储意图分析结果
|
||||
const intentAnalysisResults: { [requestId: string]: SearchIntentResult } = {}
|
||||
const userMessages: { [requestId: string]: Message } = {}
|
||||
console.log('searchOrchestrationPlugin', assistant)
|
||||
return definePlugin({
|
||||
name: 'search-orchestration',
|
||||
enforce: 'pre', // 确保在其他插件之前执行
|
||||
|
||||
/**
|
||||
* 🔍 Step 1: 意图识别阶段
|
||||
*/
|
||||
onRequestStart: async (context: AiRequestContext) => {
|
||||
console.log('🧠 [SearchOrchestration] Starting intent analysis...', context.requestId)
|
||||
|
||||
try {
|
||||
// 从参数中提取信息
|
||||
const messages = context.originalParams.messages
|
||||
|
||||
if (!messages || messages.length === 0) {
|
||||
console.log('🧠 [SearchOrchestration] No messages found, skipping analysis')
|
||||
return
|
||||
}
|
||||
|
||||
const lastUserMessage = messages[messages.length - 1]
|
||||
const lastAssistantMessage = messages.length >= 2 ? messages[messages.length - 2] : undefined
|
||||
|
||||
// 存储用户消息用于后续记忆存储
|
||||
userMessages[context.requestId] = lastUserMessage
|
||||
|
||||
// 判断是否需要各种搜索
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
|
||||
const shouldWebSearch = !!assistant.webSearchProviderId
|
||||
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
|
||||
const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
|
||||
|
||||
console.log('🧠 [SearchOrchestration] Search capabilities:', {
|
||||
shouldWebSearch,
|
||||
shouldKnowledgeSearch,
|
||||
shouldMemorySearch
|
||||
})
|
||||
|
||||
// 执行意图分析
|
||||
if (shouldWebSearch || shouldKnowledgeSearch) {
|
||||
const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, {
|
||||
shouldWebSearch,
|
||||
shouldKnowledgeSearch,
|
||||
shouldMemorySearch,
|
||||
lastAnswer: lastAssistantMessage,
|
||||
context
|
||||
})
|
||||
|
||||
if (analysisResult) {
|
||||
intentAnalysisResults[context.requestId] = analysisResult
|
||||
console.log('🧠 [SearchOrchestration] Intent analysis completed:', analysisResult)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('🧠 [SearchOrchestration] Intent analysis failed:', error)
|
||||
// 不抛出错误,让流程继续
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 🔧 Step 2: 工具配置阶段
|
||||
*/
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
console.log('🔧 [SearchOrchestration] Configuring tools based on intent...', context.requestId)
|
||||
|
||||
try {
|
||||
const analysisResult = intentAnalysisResults[context.requestId]
|
||||
console.log('analysisResult', analysisResult)
|
||||
if (!analysisResult || !assistant) {
|
||||
console.log('🔧 [SearchOrchestration] No analysis result or assistant, skipping tool configuration')
|
||||
return params
|
||||
}
|
||||
|
||||
// 确保 tools 对象存在
|
||||
if (!params.tools) {
|
||||
params.tools = {}
|
||||
}
|
||||
|
||||
// 🌐 网络搜索工具配置
|
||||
if (analysisResult.websearch && assistant.webSearchProviderId) {
|
||||
const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed'
|
||||
|
||||
if (needsSearch) {
|
||||
console.log('🌐 [SearchOrchestration] Adding web search tool')
|
||||
params.tools['builtin_web_search'] = webSearchTool(assistant.webSearchProviderId)
|
||||
}
|
||||
}
|
||||
|
||||
// 📚 知识库搜索工具配置
|
||||
if (analysisResult.knowledge) {
|
||||
const needsKnowledgeSearch =
|
||||
analysisResult.knowledge.question && analysisResult.knowledge.question[0] !== 'not_needed'
|
||||
|
||||
if (needsKnowledgeSearch) {
|
||||
console.log('📚 [SearchOrchestration] Adding knowledge search tool')
|
||||
// TODO: 添加知识库搜索工具
|
||||
// params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant.knowledge_bases)
|
||||
}
|
||||
}
|
||||
|
||||
// 🧠 记忆搜索工具配置
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
if (globalMemoryEnabled && assistant.enableMemory) {
|
||||
console.log('🧠 [SearchOrchestration] Adding memory search tool')
|
||||
params.tools['builtin_memory_search'] = memorySearchTool()
|
||||
}
|
||||
|
||||
console.log('🔧 [SearchOrchestration] Tools configured:', Object.keys(params.tools))
|
||||
return params
|
||||
} catch (error) {
|
||||
console.error('🔧 [SearchOrchestration] Tool configuration failed:', error)
|
||||
return params
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 💾 Step 3: 记忆存储阶段
|
||||
*/
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
onRequestEnd: async (context: AiRequestContext, _result: any) => {
|
||||
console.log('💾 [SearchOrchestration] Starting memory storage...', context.requestId)
|
||||
|
||||
try {
|
||||
const assistant = context.originalParams.assistant as Assistant
|
||||
const messages = context.originalParams.messages as Message[]
|
||||
|
||||
if (messages && assistant) {
|
||||
await storeConversationMemory(messages, assistant)
|
||||
}
|
||||
|
||||
// 清理缓存
|
||||
delete intentAnalysisResults[context.requestId]
|
||||
delete userMessages[context.requestId]
|
||||
} catch (error) {
|
||||
console.error('💾 [SearchOrchestration] Memory storage failed:', error)
|
||||
// 不抛出错误,避免影响主流程
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export default searchOrchestrationPlugin
|
||||
@ -50,7 +50,7 @@ import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { webSearchTool } from './tools/WebSearchTool'
|
||||
// import { webSearchTool } from './tools/WebSearchTool'
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { setupToolsConfig } from './utils/mcp'
|
||||
import { buildProviderOptions } from './utils/options'
|
||||
@ -298,9 +298,9 @@ export async function buildStreamTextParams(
|
||||
enableToolUse: enableTools
|
||||
})
|
||||
|
||||
if (webSearchProviderId) {
|
||||
tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
||||
}
|
||||
// if (webSearchProviderId) {
|
||||
// tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
||||
// }
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
const providerOptions = buildProviderOptions(assistant, model, provider, {
|
||||
|
||||
@ -16,7 +16,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import store from '@renderer/store'
|
||||
import { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
|
||||
import { type Chunk } from '@renderer/types/chunk'
|
||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import { SdkModel } from '@renderer/types/sdk'
|
||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||
@ -432,14 +432,16 @@ export async function fetchChatCompletion({
|
||||
enableReasoning: capabilities.enableReasoning,
|
||||
enableTool: assistant.settings?.toolUseMode === 'prompt',
|
||||
enableWebSearch: capabilities.enableWebSearch,
|
||||
mcpTools
|
||||
mcpTools,
|
||||
assistant
|
||||
}
|
||||
// --- Call AI Completions ---
|
||||
// onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
await AI.completions(modelId, aiSdkParams, middlewareConfig)
|
||||
// if (enableWebSearch) {
|
||||
// if (capabilities.enableWebSearch) {
|
||||
// onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
|
||||
// }
|
||||
// --- Call AI Completions ---
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
await AI.completions(modelId, aiSdkParams, middlewareConfig)
|
||||
|
||||
// await AI.completions(
|
||||
// {
|
||||
// callType: 'chat',
|
||||
|
||||
28
yarn.lock
28
yarn.lock
@ -148,6 +148,18 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/openai-compatible@npm:1.0.0-beta.8":
|
||||
version: 1.0.0-beta.8
|
||||
resolution: "@ai-sdk/openai-compatible@npm:1.0.0-beta.8"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0-beta.1"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.0-beta.5"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/047f044bf0da9608e09073957916373bd39760ec00f498ba0c4a597ec70ba9eb4ef31f06b21b363b3c1ba775f64fcc46d41b60a171e0e99250824817ecb19ba8
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/openai@npm:2.0.0-beta.9":
|
||||
version: 2.0.0-beta.9
|
||||
resolution: "@ai-sdk/openai@npm:2.0.0-beta.9"
|
||||
@ -188,6 +200,20 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider-utils@npm:3.0.0-beta.5":
|
||||
version: 3.0.0-beta.5
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.0-beta.5"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0-beta.1"
|
||||
"@standard-schema/spec": "npm:^1.0.0"
|
||||
eventsource-parser: "npm:^3.0.3"
|
||||
zod-to-json-schema: "npm:^3.24.1"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/229a53672accc5d9d986da2e18f619dbcfaf64ab269c8cc9e955480c4428d2a87255330c587453d01eb66ac297bb6975f91c24a93f87dd4b84f6428cb60d4211
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider@npm:2.0.0-beta.1":
|
||||
version: 2.0.0-beta.1
|
||||
resolution: "@ai-sdk/provider@npm:2.0.0-beta.1"
|
||||
@ -1251,7 +1277,7 @@ __metadata:
|
||||
"@ai-sdk/deepseek": "npm:1.0.0-beta.6"
|
||||
"@ai-sdk/google": "npm:2.0.0-beta.11"
|
||||
"@ai-sdk/openai": "npm:2.0.0-beta.9"
|
||||
"@ai-sdk/openai-compatible": "npm:1.0.0-beta.6"
|
||||
"@ai-sdk/openai-compatible": "npm:1.0.0-beta.8"
|
||||
"@ai-sdk/provider": "npm:2.0.0-beta.1"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.0-beta.3"
|
||||
"@ai-sdk/xai": "npm:2.0.0-beta.8"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user