mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 13:59:28 +08:00
<type>: <subject>
<body> <footer> 用來簡要描述影響本次變動,概述即可
This commit is contained in:
parent
addd5ffdfa
commit
4b62384fc5
@ -42,6 +42,15 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
createConfigureContextPlugin() {
|
||||||
|
return definePlugin({
|
||||||
|
name: '_internal_configureContext',
|
||||||
|
configureContext: async (context: AiRequestContext) => {
|
||||||
|
context.executor = this
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// === 高阶重载:直接使用模型 ===
|
// === 高阶重载:直接使用模型 ===
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -51,10 +60,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
model: LanguageModel,
|
model: LanguageModel,
|
||||||
params: Omit<Parameters<typeof streamText>[0], 'model'>
|
params: Omit<Parameters<typeof streamText>[0], 'model'>
|
||||||
): Promise<ReturnType<typeof streamText>>
|
): Promise<ReturnType<typeof streamText>>
|
||||||
|
|
||||||
/**
|
|
||||||
* 流式文本生成 - 使用modelId + 可选middleware(灵活用法)
|
|
||||||
*/
|
|
||||||
async streamText(
|
async streamText(
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||||
@ -62,10 +67,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof streamText>>
|
): Promise<ReturnType<typeof streamText>>
|
||||||
|
|
||||||
/**
|
|
||||||
* 流式文本生成 - 内部实现(统一处理重载)
|
|
||||||
*/
|
|
||||||
async streamText(
|
async streamText(
|
||||||
modelOrId: LanguageModel,
|
modelOrId: LanguageModel,
|
||||||
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||||
@ -73,7 +74,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof streamText>> {
|
): Promise<ReturnType<typeof streamText>> {
|
||||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
this.pluginEngine.usePlugins([
|
||||||
|
this.createResolveModelPlugin(options?.middlewares),
|
||||||
|
this.createConfigureContextPlugin()
|
||||||
|
])
|
||||||
|
|
||||||
// 2. 执行插件处理
|
// 2. 执行插件处理
|
||||||
return this.pluginEngine.executeStreamWithPlugins(
|
return this.pluginEngine.executeStreamWithPlugins(
|
||||||
@ -102,10 +106,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
model: LanguageModel,
|
model: LanguageModel,
|
||||||
params: Omit<Parameters<typeof generateText>[0], 'model'>
|
params: Omit<Parameters<typeof generateText>[0], 'model'>
|
||||||
): Promise<ReturnType<typeof generateText>>
|
): Promise<ReturnType<typeof generateText>>
|
||||||
|
|
||||||
/**
|
|
||||||
* 生成文本 - 使用modelId + 可选middleware
|
|
||||||
*/
|
|
||||||
async generateText(
|
async generateText(
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||||
@ -113,10 +113,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof generateText>>
|
): Promise<ReturnType<typeof generateText>>
|
||||||
|
|
||||||
/**
|
|
||||||
* 生成文本 - 内部实现
|
|
||||||
*/
|
|
||||||
async generateText(
|
async generateText(
|
||||||
modelOrId: LanguageModel | string,
|
modelOrId: LanguageModel | string,
|
||||||
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||||
@ -124,7 +120,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof generateText>> {
|
): Promise<ReturnType<typeof generateText>> {
|
||||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
this.pluginEngine.usePlugins([
|
||||||
|
this.createResolveModelPlugin(options?.middlewares),
|
||||||
|
this.createConfigureContextPlugin()
|
||||||
|
])
|
||||||
|
|
||||||
return this.pluginEngine.executeWithPlugins(
|
return this.pluginEngine.executeWithPlugins(
|
||||||
'generateText',
|
'generateText',
|
||||||
@ -143,10 +142,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
model: LanguageModel,
|
model: LanguageModel,
|
||||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>
|
params: Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||||
): Promise<ReturnType<typeof generateObject>>
|
): Promise<ReturnType<typeof generateObject>>
|
||||||
|
|
||||||
/**
|
|
||||||
* 生成结构化对象 - 使用modelId + 可选middleware
|
|
||||||
*/
|
|
||||||
async generateObject(
|
async generateObject(
|
||||||
modelOrId: string,
|
modelOrId: string,
|
||||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||||
@ -154,10 +149,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof generateObject>>
|
): Promise<ReturnType<typeof generateObject>>
|
||||||
|
|
||||||
/**
|
|
||||||
* 生成结构化对象 - 内部实现
|
|
||||||
*/
|
|
||||||
async generateObject(
|
async generateObject(
|
||||||
modelOrId: LanguageModel | string,
|
modelOrId: LanguageModel | string,
|
||||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||||
@ -165,7 +156,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof generateObject>> {
|
): Promise<ReturnType<typeof generateObject>> {
|
||||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
this.pluginEngine.usePlugins([
|
||||||
|
this.createResolveModelPlugin(options?.middlewares),
|
||||||
|
this.createConfigureContextPlugin()
|
||||||
|
])
|
||||||
|
|
||||||
return this.pluginEngine.executeWithPlugins(
|
return this.pluginEngine.executeWithPlugins(
|
||||||
'generateObject',
|
'generateObject',
|
||||||
@ -184,10 +178,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
model: LanguageModel,
|
model: LanguageModel,
|
||||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>
|
params: Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||||
): Promise<ReturnType<typeof streamObject>>
|
): Promise<ReturnType<typeof streamObject>>
|
||||||
|
|
||||||
/**
|
|
||||||
* 流式生成结构化对象 - 使用modelId + 可选middleware
|
|
||||||
*/
|
|
||||||
async streamObject(
|
async streamObject(
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||||
@ -195,10 +185,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof streamObject>>
|
): Promise<ReturnType<typeof streamObject>>
|
||||||
|
|
||||||
/**
|
|
||||||
* 流式生成结构化对象 - 内部实现
|
|
||||||
*/
|
|
||||||
async streamObject(
|
async streamObject(
|
||||||
modelOrId: LanguageModel | string,
|
modelOrId: LanguageModel | string,
|
||||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||||
@ -206,7 +192,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof streamObject>> {
|
): Promise<ReturnType<typeof streamObject>> {
|
||||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
this.pluginEngine.usePlugins([
|
||||||
|
this.createResolveModelPlugin(options?.middlewares),
|
||||||
|
this.createConfigureContextPlugin()
|
||||||
|
])
|
||||||
|
|
||||||
return this.pluginEngine.executeWithPlugins(
|
return this.pluginEngine.executeWithPlugins(
|
||||||
'streamObject',
|
'streamObject',
|
||||||
|
|||||||
@ -14,7 +14,13 @@ import type { ProviderSettingsMap } from './core/providers/types'
|
|||||||
import { createExecutor } from './core/runtime'
|
import { createExecutor } from './core/runtime'
|
||||||
|
|
||||||
// ==================== 主要用户接口 ====================
|
// ==================== 主要用户接口 ====================
|
||||||
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
|
export {
|
||||||
|
createExecutor,
|
||||||
|
createOpenAICompatibleExecutor,
|
||||||
|
generateObject,
|
||||||
|
generateText,
|
||||||
|
streamText
|
||||||
|
} from './core/runtime'
|
||||||
|
|
||||||
// ==================== 高级API ====================
|
// ==================== 高级API ====================
|
||||||
export { createModel } from './core/models'
|
export { createModel } from './core/models'
|
||||||
@ -139,11 +145,6 @@ export const AiCore = {
|
|||||||
return createExecutor(providerId, options, plugins)
|
return createExecutor(providerId, options, plugins)
|
||||||
},
|
},
|
||||||
|
|
||||||
// 创建底层客户端(高级用法)
|
|
||||||
createClient(providerId: ProviderId, options: ProviderSettingsMap[ProviderId], plugins: any[] = []) {
|
|
||||||
return createExecutor(providerId, options, plugins)
|
|
||||||
},
|
|
||||||
|
|
||||||
// 获取支持的providers
|
// 获取支持的providers
|
||||||
getSupportedProviders() {
|
getSupportedProviders() {
|
||||||
return factoryGetSupportedProviders()
|
return factoryGetSupportedProviders()
|
||||||
|
|||||||
@ -153,11 +153,11 @@ export class AiSdkToChunkAdapter {
|
|||||||
break
|
break
|
||||||
|
|
||||||
// === 步骤相关事件 ===
|
// === 步骤相关事件 ===
|
||||||
case 'start':
|
// case 'start':
|
||||||
this.onChunk({
|
// this.onChunk({
|
||||||
type: ChunkType.LLM_RESPONSE_CREATED
|
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||||
})
|
// })
|
||||||
break
|
// break
|
||||||
// TODO: 需要区分接口开始和步骤开始
|
// TODO: 需要区分接口开始和步骤开始
|
||||||
// case 'start-step':
|
// case 'start-step':
|
||||||
// this.onChunk({
|
// this.onChunk({
|
||||||
|
|||||||
@ -21,7 +21,7 @@ import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-cor
|
|||||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||||
import { formatApiHost } from '@renderer/utils/api'
|
import { formatApiHost } from '@renderer/utils/api'
|
||||||
import { cloneDeep } from 'lodash'
|
import { cloneDeep } from 'lodash'
|
||||||
|
|
||||||
@ -30,6 +30,7 @@ import LegacyAiProvider from './index'
|
|||||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
|
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||||
import { CompletionsResult } from './middleware/schemas'
|
import { CompletionsResult } from './middleware/schemas'
|
||||||
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
|
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
|
||||||
|
import { searchOrchestrationPlugin } from './plugins/searchOrchestrationPlugin'
|
||||||
import { createAihubmixProvider } from './provider/aihubmix'
|
import { createAihubmixProvider } from './provider/aihubmix'
|
||||||
import { getAiSdkProviderId } from './provider/factory'
|
import { getAiSdkProviderId } from './provider/factory'
|
||||||
|
|
||||||
@ -162,6 +163,7 @@ export default class ModernAiProvider {
|
|||||||
// 内置了默认搜索参数,如果改的话可以传config进去
|
// 内置了默认搜索参数,如果改的话可以传config进去
|
||||||
plugins.push(webSearchPlugin())
|
plugins.push(webSearchPlugin())
|
||||||
}
|
}
|
||||||
|
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant as Assistant))
|
||||||
|
|
||||||
// 2. 推理模型时添加推理插件
|
// 2. 推理模型时添加推理插件
|
||||||
if (middlewareConfig.enableReasoning) {
|
if (middlewareConfig.enableReasoning) {
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import {
|
|||||||
LanguageModelV2Middleware,
|
LanguageModelV2Middleware,
|
||||||
simulateStreamingMiddleware
|
simulateStreamingMiddleware
|
||||||
} from '@cherrystudio/ai-core'
|
} from '@cherrystudio/ai-core'
|
||||||
import type { BaseTool, Model, Provider } from '@renderer/types'
|
import type { Assistant, BaseTool, Model, Provider } from '@renderer/types'
|
||||||
import type { Chunk } from '@renderer/types/chunk'
|
import type { Chunk } from '@renderer/types/chunk'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -19,6 +19,8 @@ export interface AiSdkMiddlewareConfig {
|
|||||||
enableTool?: boolean
|
enableTool?: boolean
|
||||||
enableWebSearch?: boolean
|
enableWebSearch?: boolean
|
||||||
mcpTools?: BaseTool[]
|
mcpTools?: BaseTool[]
|
||||||
|
// TODO assistant
|
||||||
|
assistant?: Assistant
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
376
src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts
Normal file
376
src/renderer/src/aiCore/plugins/searchOrchestrationPlugin.ts
Normal file
@ -0,0 +1,376 @@
|
|||||||
|
/**
|
||||||
|
* 搜索编排插件
|
||||||
|
*
|
||||||
|
* 功能:
|
||||||
|
* 1. onRequestStart: 智能意图识别 - 分析是否需要网络搜索、知识库搜索、记忆搜索
|
||||||
|
* 2. transformParams: 根据意图分析结果动态添加对应的工具
|
||||||
|
* 3. onRequestEnd: 自动记忆存储
|
||||||
|
*/
|
||||||
|
import type { AiRequestContext } from '@cherrystudio/ai-core'
|
||||||
|
import { definePlugin } from '@cherrystudio/ai-core'
|
||||||
|
import { RuntimeExecutor } from '@cherrystudio/ai-core/core/runtime/executor'
|
||||||
|
// import { generateObject } from '@cherrystudio/ai-core'
|
||||||
|
import {
|
||||||
|
SEARCH_SUMMARY_PROMPT,
|
||||||
|
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
|
||||||
|
SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||||
|
} from '@renderer/config/prompts'
|
||||||
|
import { getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
|
import store from '@renderer/store'
|
||||||
|
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||||
|
import type { Assistant, Message } from '@renderer/types'
|
||||||
|
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||||
|
import { isEmpty } from 'lodash'
|
||||||
|
import { z } from 'zod'
|
||||||
|
|
||||||
|
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||||
|
import { memorySearchTool } from '../tools/MemorySearchTool'
|
||||||
|
import { webSearchTool } from '../tools/WebSearchTool'
|
||||||
|
|
||||||
|
// const getMessageContent = (message: Message) => {
|
||||||
|
// if (typeof message.content === 'string') return message.content
|
||||||
|
// return message.content.reduce((acc, part) => {
|
||||||
|
// if (part.type === 'text') {
|
||||||
|
// return acc + part.text + '\n'
|
||||||
|
// }
|
||||||
|
// return acc
|
||||||
|
// }, '')
|
||||||
|
// }
|
||||||
|
|
||||||
|
// === Schema Definitions ===
|
||||||
|
|
||||||
|
const WebSearchSchema = z.object({
|
||||||
|
question: z
|
||||||
|
.array(z.string())
|
||||||
|
.describe('Search queries for web search. Use "not_needed" if no web search is required.'),
|
||||||
|
links: z.array(z.string()).optional().describe('Specific URLs to search or summarize if mentioned in the query.')
|
||||||
|
})
|
||||||
|
|
||||||
|
const KnowledgeSearchSchema = z.object({
|
||||||
|
question: z
|
||||||
|
.array(z.string())
|
||||||
|
.describe('Search queries for knowledge base. Use "not_needed" if no knowledge search is required.'),
|
||||||
|
rewrite: z
|
||||||
|
.string()
|
||||||
|
.describe('Rewritten query with alternative phrasing while preserving original intent and meaning.')
|
||||||
|
})
|
||||||
|
|
||||||
|
const SearchIntentAnalysisSchema = z.object({
|
||||||
|
websearch: WebSearchSchema.optional().describe('Web search intent analysis results.'),
|
||||||
|
knowledge: KnowledgeSearchSchema.optional().describe('Knowledge base search intent analysis results.')
|
||||||
|
})
|
||||||
|
|
||||||
|
type SearchIntentResult = z.infer<typeof SearchIntentAnalysisSchema>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 🧠 意图分析函数 - 使用结构化输出重构
|
||||||
|
*/
|
||||||
|
async function analyzeSearchIntent(
|
||||||
|
lastUserMessage: Message,
|
||||||
|
assistant: Assistant,
|
||||||
|
options: {
|
||||||
|
shouldWebSearch?: boolean
|
||||||
|
shouldKnowledgeSearch?: boolean
|
||||||
|
shouldMemorySearch?: boolean
|
||||||
|
lastAnswer?: Message
|
||||||
|
context?:
|
||||||
|
| AiRequestContext
|
||||||
|
| {
|
||||||
|
executor: RuntimeExecutor
|
||||||
|
}
|
||||||
|
} = {}
|
||||||
|
): Promise<SearchIntentResult | undefined> {
|
||||||
|
const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer, context } = options
|
||||||
|
|
||||||
|
if (!lastUserMessage) return undefined
|
||||||
|
|
||||||
|
// 根据配置决定是否需要提取
|
||||||
|
const needWebExtract = shouldWebSearch
|
||||||
|
const needKnowledgeExtract = shouldKnowledgeSearch
|
||||||
|
|
||||||
|
if (!needWebExtract && !needKnowledgeExtract) return undefined
|
||||||
|
|
||||||
|
// 选择合适的提示词和schema
|
||||||
|
let prompt: string
|
||||||
|
let schema: z.Schema
|
||||||
|
|
||||||
|
if (needWebExtract && !needKnowledgeExtract) {
|
||||||
|
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||||
|
schema = z.object({ websearch: WebSearchSchema })
|
||||||
|
} else if (!needWebExtract && needKnowledgeExtract) {
|
||||||
|
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
||||||
|
schema = z.object({ knowledge: KnowledgeSearchSchema })
|
||||||
|
} else {
|
||||||
|
prompt = SEARCH_SUMMARY_PROMPT
|
||||||
|
schema = SearchIntentAnalysisSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建消息上下文
|
||||||
|
const messages = lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage]
|
||||||
|
console.log('messagesmessagesmessagesmessagesmessagesmessagesmessages', messages)
|
||||||
|
// 格式化消息为提示词期望的格式
|
||||||
|
// const chatHistory =
|
||||||
|
// messages.length > 1
|
||||||
|
// ? messages
|
||||||
|
// .slice(0, -1)
|
||||||
|
// .map((msg) => `${msg.role}: ${getMainTextContent(msg)}`)
|
||||||
|
// .join('\n')
|
||||||
|
// : ''
|
||||||
|
// const question = getMainTextContent(lastUserMessage) || ''
|
||||||
|
|
||||||
|
// // 使用模板替换变量
|
||||||
|
// const formattedPrompt = prompt.replace('{chat_history}', chatHistory).replace('{question}', question)
|
||||||
|
|
||||||
|
// 获取模型和provider信息
|
||||||
|
const model = assistant.model || getDefaultModel()
|
||||||
|
const provider = getProviderByModel(model)
|
||||||
|
|
||||||
|
if (!provider || isEmpty(provider.apiKey)) {
|
||||||
|
console.error('Provider not found or missing API key')
|
||||||
|
return getFallbackResult()
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await context?.executor?.generateObject(model.id, { schema, prompt })
|
||||||
|
console.log('result', context)
|
||||||
|
const parsedResult = result?.object as SearchIntentResult
|
||||||
|
|
||||||
|
// 根据需求过滤结果
|
||||||
|
return {
|
||||||
|
websearch: needWebExtract ? parsedResult?.websearch : undefined,
|
||||||
|
knowledge: needKnowledgeExtract ? parsedResult?.knowledge : undefined
|
||||||
|
}
|
||||||
|
} catch (e: any) {
|
||||||
|
console.error('analyze search intent error', e)
|
||||||
|
return getFallbackResult()
|
||||||
|
}
|
||||||
|
|
||||||
|
function getFallbackResult(): SearchIntentResult {
|
||||||
|
const fallbackContent = getMainTextContent(lastUserMessage)
|
||||||
|
return {
|
||||||
|
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
||||||
|
knowledge: shouldKnowledgeSearch
|
||||||
|
? {
|
||||||
|
question: [fallbackContent || 'search'],
|
||||||
|
rewrite: fallbackContent || 'search'
|
||||||
|
}
|
||||||
|
: undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 🧠 记忆存储函数 - 基于注释代码中的 processConversationMemory
|
||||||
|
*/
|
||||||
|
async function storeConversationMemory(messages: Message[], assistant: Assistant): Promise<void> {
|
||||||
|
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||||
|
|
||||||
|
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||||
|
console.log('Memory storage is disabled')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const memoryConfig = selectMemoryConfig(store.getState())
|
||||||
|
|
||||||
|
// 转换消息为记忆处理器期望的格式
|
||||||
|
const conversationMessages = messages
|
||||||
|
.filter((msg) => msg.role === 'user' || msg.role === 'assistant')
|
||||||
|
.map((msg) => ({
|
||||||
|
role: msg.role as 'user' | 'assistant',
|
||||||
|
content: getMainTextContent(msg) || ''
|
||||||
|
}))
|
||||||
|
.filter((msg) => msg.content.trim().length > 0)
|
||||||
|
|
||||||
|
if (conversationMessages.length < 2) {
|
||||||
|
console.log('Need at least a user message and assistant response for memory processing')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const currentUserId = selectCurrentUserId(store.getState())
|
||||||
|
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||||
|
|
||||||
|
const processorConfig = MemoryProcessor.getProcessorConfig(
|
||||||
|
memoryConfig,
|
||||||
|
assistant.id,
|
||||||
|
currentUserId,
|
||||||
|
lastUserMessage?.id
|
||||||
|
)
|
||||||
|
|
||||||
|
console.log('Processing conversation memory...', { messageCount: conversationMessages.length })
|
||||||
|
|
||||||
|
// 后台处理对话记忆(不阻塞 UI)
|
||||||
|
const memoryProcessor = new MemoryProcessor()
|
||||||
|
memoryProcessor
|
||||||
|
.processConversation(conversationMessages, processorConfig)
|
||||||
|
.then((result) => {
|
||||||
|
console.log('Memory processing completed:', result)
|
||||||
|
if (result.facts?.length > 0) {
|
||||||
|
console.log('Extracted facts from conversation:', result.facts)
|
||||||
|
console.log('Memory operations performed:', result.operations)
|
||||||
|
} else {
|
||||||
|
console.log('No facts extracted from conversation')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Background memory processing failed:', error)
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error in conversation memory processing:', error)
|
||||||
|
// 不抛出错误,避免影响主流程
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 🎯 搜索编排插件
|
||||||
|
*/
|
||||||
|
export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||||
|
// 存储意图分析结果
|
||||||
|
const intentAnalysisResults: { [requestId: string]: SearchIntentResult } = {}
|
||||||
|
const userMessages: { [requestId: string]: Message } = {}
|
||||||
|
console.log('searchOrchestrationPlugin', assistant)
|
||||||
|
return definePlugin({
|
||||||
|
name: 'search-orchestration',
|
||||||
|
enforce: 'pre', // 确保在其他插件之前执行
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 🔍 Step 1: 意图识别阶段
|
||||||
|
*/
|
||||||
|
onRequestStart: async (context: AiRequestContext) => {
|
||||||
|
console.log('🧠 [SearchOrchestration] Starting intent analysis...', context.requestId)
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 从参数中提取信息
|
||||||
|
const messages = context.originalParams.messages
|
||||||
|
|
||||||
|
if (!messages || messages.length === 0) {
|
||||||
|
console.log('🧠 [SearchOrchestration] No messages found, skipping analysis')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const lastUserMessage = messages[messages.length - 1]
|
||||||
|
const lastAssistantMessage = messages.length >= 2 ? messages[messages.length - 2] : undefined
|
||||||
|
|
||||||
|
// 存储用户消息用于后续记忆存储
|
||||||
|
userMessages[context.requestId] = lastUserMessage
|
||||||
|
|
||||||
|
// 判断是否需要各种搜索
|
||||||
|
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||||
|
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||||
|
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||||
|
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||||
|
|
||||||
|
const shouldWebSearch = !!assistant.webSearchProviderId
|
||||||
|
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
|
||||||
|
const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
|
||||||
|
|
||||||
|
console.log('🧠 [SearchOrchestration] Search capabilities:', {
|
||||||
|
shouldWebSearch,
|
||||||
|
shouldKnowledgeSearch,
|
||||||
|
shouldMemorySearch
|
||||||
|
})
|
||||||
|
|
||||||
|
// 执行意图分析
|
||||||
|
if (shouldWebSearch || shouldKnowledgeSearch) {
|
||||||
|
const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, {
|
||||||
|
shouldWebSearch,
|
||||||
|
shouldKnowledgeSearch,
|
||||||
|
shouldMemorySearch,
|
||||||
|
lastAnswer: lastAssistantMessage,
|
||||||
|
context
|
||||||
|
})
|
||||||
|
|
||||||
|
if (analysisResult) {
|
||||||
|
intentAnalysisResults[context.requestId] = analysisResult
|
||||||
|
console.log('🧠 [SearchOrchestration] Intent analysis completed:', analysisResult)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('🧠 [SearchOrchestration] Intent analysis failed:', error)
|
||||||
|
// 不抛出错误,让流程继续
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 🔧 Step 2: 工具配置阶段
|
||||||
|
*/
|
||||||
|
transformParams: async (params: any, context: AiRequestContext) => {
|
||||||
|
console.log('🔧 [SearchOrchestration] Configuring tools based on intent...', context.requestId)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const analysisResult = intentAnalysisResults[context.requestId]
|
||||||
|
console.log('analysisResult', analysisResult)
|
||||||
|
if (!analysisResult || !assistant) {
|
||||||
|
console.log('🔧 [SearchOrchestration] No analysis result or assistant, skipping tool configuration')
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保 tools 对象存在
|
||||||
|
if (!params.tools) {
|
||||||
|
params.tools = {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 🌐 网络搜索工具配置
|
||||||
|
if (analysisResult.websearch && assistant.webSearchProviderId) {
|
||||||
|
const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed'
|
||||||
|
|
||||||
|
if (needsSearch) {
|
||||||
|
console.log('🌐 [SearchOrchestration] Adding web search tool')
|
||||||
|
params.tools['builtin_web_search'] = webSearchTool(assistant.webSearchProviderId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 📚 知识库搜索工具配置
|
||||||
|
if (analysisResult.knowledge) {
|
||||||
|
const needsKnowledgeSearch =
|
||||||
|
analysisResult.knowledge.question && analysisResult.knowledge.question[0] !== 'not_needed'
|
||||||
|
|
||||||
|
if (needsKnowledgeSearch) {
|
||||||
|
console.log('📚 [SearchOrchestration] Adding knowledge search tool')
|
||||||
|
// TODO: 添加知识库搜索工具
|
||||||
|
// params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant.knowledge_bases)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 🧠 记忆搜索工具配置
|
||||||
|
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||||
|
if (globalMemoryEnabled && assistant.enableMemory) {
|
||||||
|
console.log('🧠 [SearchOrchestration] Adding memory search tool')
|
||||||
|
params.tools['builtin_memory_search'] = memorySearchTool()
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('🔧 [SearchOrchestration] Tools configured:', Object.keys(params.tools))
|
||||||
|
return params
|
||||||
|
} catch (error) {
|
||||||
|
console.error('🔧 [SearchOrchestration] Tool configuration failed:', error)
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 💾 Step 3: 记忆存储阶段
|
||||||
|
*/
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||||
|
onRequestEnd: async (context: AiRequestContext, _result: any) => {
|
||||||
|
console.log('💾 [SearchOrchestration] Starting memory storage...', context.requestId)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const assistant = context.originalParams.assistant as Assistant
|
||||||
|
const messages = context.originalParams.messages as Message[]
|
||||||
|
|
||||||
|
if (messages && assistant) {
|
||||||
|
await storeConversationMemory(messages, assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理缓存
|
||||||
|
delete intentAnalysisResults[context.requestId]
|
||||||
|
delete userMessages[context.requestId]
|
||||||
|
} catch (error) {
|
||||||
|
console.error('💾 [SearchOrchestration] Memory storage failed:', error)
|
||||||
|
// 不抛出错误,避免影响主流程
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export default searchOrchestrationPlugin
|
||||||
@ -50,7 +50,7 @@ import { buildSystemPrompt } from '@renderer/utils/prompt'
|
|||||||
import { defaultTimeout } from '@shared/config/constant'
|
import { defaultTimeout } from '@shared/config/constant'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
|
|
||||||
import { webSearchTool } from './tools/WebSearchTool'
|
// import { webSearchTool } from './tools/WebSearchTool'
|
||||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||||
import { setupToolsConfig } from './utils/mcp'
|
import { setupToolsConfig } from './utils/mcp'
|
||||||
import { buildProviderOptions } from './utils/options'
|
import { buildProviderOptions } from './utils/options'
|
||||||
@ -298,9 +298,9 @@ export async function buildStreamTextParams(
|
|||||||
enableToolUse: enableTools
|
enableToolUse: enableTools
|
||||||
})
|
})
|
||||||
|
|
||||||
if (webSearchProviderId) {
|
// if (webSearchProviderId) {
|
||||||
tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
// tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
||||||
}
|
// }
|
||||||
|
|
||||||
// 构建真正的 providerOptions
|
// 构建真正的 providerOptions
|
||||||
const providerOptions = buildProviderOptions(assistant, model, provider, {
|
const providerOptions = buildProviderOptions(assistant, model, provider, {
|
||||||
|
|||||||
@ -16,7 +16,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
|||||||
import i18n from '@renderer/i18n'
|
import i18n from '@renderer/i18n'
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
|
import { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
|
||||||
import { type Chunk } from '@renderer/types/chunk'
|
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||||
import { Message } from '@renderer/types/newMessage'
|
import { Message } from '@renderer/types/newMessage'
|
||||||
import { SdkModel } from '@renderer/types/sdk'
|
import { SdkModel } from '@renderer/types/sdk'
|
||||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||||
@ -432,14 +432,16 @@ export async function fetchChatCompletion({
|
|||||||
enableReasoning: capabilities.enableReasoning,
|
enableReasoning: capabilities.enableReasoning,
|
||||||
enableTool: assistant.settings?.toolUseMode === 'prompt',
|
enableTool: assistant.settings?.toolUseMode === 'prompt',
|
||||||
enableWebSearch: capabilities.enableWebSearch,
|
enableWebSearch: capabilities.enableWebSearch,
|
||||||
mcpTools
|
mcpTools,
|
||||||
|
assistant
|
||||||
}
|
}
|
||||||
// --- Call AI Completions ---
|
// if (capabilities.enableWebSearch) {
|
||||||
// onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
|
||||||
await AI.completions(modelId, aiSdkParams, middlewareConfig)
|
|
||||||
// if (enableWebSearch) {
|
|
||||||
// onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
|
// onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
|
||||||
// }
|
// }
|
||||||
|
// --- Call AI Completions ---
|
||||||
|
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||||
|
await AI.completions(modelId, aiSdkParams, middlewareConfig)
|
||||||
|
|
||||||
// await AI.completions(
|
// await AI.completions(
|
||||||
// {
|
// {
|
||||||
// callType: 'chat',
|
// callType: 'chat',
|
||||||
|
|||||||
28
yarn.lock
28
yarn.lock
@ -148,6 +148,18 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
|
"@ai-sdk/openai-compatible@npm:1.0.0-beta.8":
|
||||||
|
version: 1.0.0-beta.8
|
||||||
|
resolution: "@ai-sdk/openai-compatible@npm:1.0.0-beta.8"
|
||||||
|
dependencies:
|
||||||
|
"@ai-sdk/provider": "npm:2.0.0-beta.1"
|
||||||
|
"@ai-sdk/provider-utils": "npm:3.0.0-beta.5"
|
||||||
|
peerDependencies:
|
||||||
|
zod: ^3.25.76 || ^4
|
||||||
|
checksum: 10c0/047f044bf0da9608e09073957916373bd39760ec00f498ba0c4a597ec70ba9eb4ef31f06b21b363b3c1ba775f64fcc46d41b60a171e0e99250824817ecb19ba8
|
||||||
|
languageName: node
|
||||||
|
linkType: hard
|
||||||
|
|
||||||
"@ai-sdk/openai@npm:2.0.0-beta.9":
|
"@ai-sdk/openai@npm:2.0.0-beta.9":
|
||||||
version: 2.0.0-beta.9
|
version: 2.0.0-beta.9
|
||||||
resolution: "@ai-sdk/openai@npm:2.0.0-beta.9"
|
resolution: "@ai-sdk/openai@npm:2.0.0-beta.9"
|
||||||
@ -188,6 +200,20 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
|
"@ai-sdk/provider-utils@npm:3.0.0-beta.5":
|
||||||
|
version: 3.0.0-beta.5
|
||||||
|
resolution: "@ai-sdk/provider-utils@npm:3.0.0-beta.5"
|
||||||
|
dependencies:
|
||||||
|
"@ai-sdk/provider": "npm:2.0.0-beta.1"
|
||||||
|
"@standard-schema/spec": "npm:^1.0.0"
|
||||||
|
eventsource-parser: "npm:^3.0.3"
|
||||||
|
zod-to-json-schema: "npm:^3.24.1"
|
||||||
|
peerDependencies:
|
||||||
|
zod: ^3.25.76 || ^4
|
||||||
|
checksum: 10c0/229a53672accc5d9d986da2e18f619dbcfaf64ab269c8cc9e955480c4428d2a87255330c587453d01eb66ac297bb6975f91c24a93f87dd4b84f6428cb60d4211
|
||||||
|
languageName: node
|
||||||
|
linkType: hard
|
||||||
|
|
||||||
"@ai-sdk/provider@npm:2.0.0-beta.1":
|
"@ai-sdk/provider@npm:2.0.0-beta.1":
|
||||||
version: 2.0.0-beta.1
|
version: 2.0.0-beta.1
|
||||||
resolution: "@ai-sdk/provider@npm:2.0.0-beta.1"
|
resolution: "@ai-sdk/provider@npm:2.0.0-beta.1"
|
||||||
@ -1251,7 +1277,7 @@ __metadata:
|
|||||||
"@ai-sdk/deepseek": "npm:1.0.0-beta.6"
|
"@ai-sdk/deepseek": "npm:1.0.0-beta.6"
|
||||||
"@ai-sdk/google": "npm:2.0.0-beta.11"
|
"@ai-sdk/google": "npm:2.0.0-beta.11"
|
||||||
"@ai-sdk/openai": "npm:2.0.0-beta.9"
|
"@ai-sdk/openai": "npm:2.0.0-beta.9"
|
||||||
"@ai-sdk/openai-compatible": "npm:1.0.0-beta.6"
|
"@ai-sdk/openai-compatible": "npm:1.0.0-beta.8"
|
||||||
"@ai-sdk/provider": "npm:2.0.0-beta.1"
|
"@ai-sdk/provider": "npm:2.0.0-beta.1"
|
||||||
"@ai-sdk/provider-utils": "npm:3.0.0-beta.3"
|
"@ai-sdk/provider-utils": "npm:3.0.0-beta.3"
|
||||||
"@ai-sdk/xai": "npm:2.0.0-beta.8"
|
"@ai-sdk/xai": "npm:2.0.0-beta.8"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user