mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
Feat/provider options and built-in tools (#10068)
* refactor(RuntimeExecutor, PluginEngine): streamline parameter handling and improve type definitions for model operations - Updated parameter handling in RuntimeExecutor methods to use specific types for generate and stream functions. - Refactored PluginEngine methods to enhance type safety and reduce redundancy in model resolution. - Introduced new type definitions for generate and stream parameters in types.ts for better clarity and maintainability. - Adjusted provider options mapping in buildProviderOptions to accommodate new provider types. * feat(googleToolsPlugin): enhance Google tools integration and update dependencies - Updated the `ai` package version to `5.0.38` and `@ai-sdk/gateway` to `1.0.20` in `package.json` and `yarn.lock`. - Introduced `googleToolsPlugin` with improved parameter handling and configuration options for Google tools. - Added support for `urlContext` in middleware configuration and plugin builder. - Refactored web search tool to streamline response handling and citation formatting. - Updated various service methods to include `enableUrlContext` capability. * fix: update tool response handling and clean up unused code - Modified the `processKnowledgeReferences` function to accept the full response object instead of just `knowledgeReferences`. - Cleaned up the `searchOrchestrationPlugin` by removing unnecessary blank lines. - Removed commented-out code in `telemetryPlugin` to improve readability. - Updated `KnowledgeSearchTool` to streamline the execution flow and return results more efficiently. - Adjusted `MessageKnowledgeSearch` components to reflect changes in the data structure returned from the knowledge search tool. - Enhanced `MemorySearchTool` by simplifying error handling and removing redundant code. * refactor: clean up KnowledgeSearchTool and WebSearchTool by removing commented-out code - Removed unnecessary commented-out code in `KnowledgeSearchTool` and `WebSearchTool` to improve code readability and maintainability. - Simplified the `processKnowledgeReferences` function by eliminating the console log statement for cleaner output. * chore: bump version to 1.0.0-alpha.14 in aiCore package.json * chore: update @ai-sdk/google-vertex to version 3.0.25 and add @ai-sdk/anthropic@2.0.15 to dependencies - Bumped the version of `@ai-sdk/google-vertex` in `package.json` and `yarn.lock` to 3.0.25. - Added `@ai-sdk/anthropic` version 2.0.15 to `yarn.lock` with updated dependencies. - Refactored `parameterBuilder.ts` to integrate new tools from `@ai-sdk/google-vertex` for enhanced functionality.
This commit is contained in:
parent
3eee8faad4
commit
e10042a433
@ -91,7 +91,7 @@
|
||||
"@agentic/searxng": "^7.3.3",
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ai-sdk/amazon-bedrock": "^3.0.0",
|
||||
"@ai-sdk/google-vertex": "^3.0.0",
|
||||
"@ai-sdk/google-vertex": "^3.0.25",
|
||||
"@ai-sdk/mistral": "^2.0.0",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
@ -202,7 +202,7 @@
|
||||
"@viz-js/lang-dot": "^1.0.5",
|
||||
"@viz-js/viz": "^3.14.0",
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"ai": "^5.0.29",
|
||||
"ai": "^5.0.38",
|
||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||
"archiver": "^7.0.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@cherrystudio/ai-core",
|
||||
"version": "1.0.0-alpha.13",
|
||||
"version": "1.0.0-alpha.14",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.mjs",
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
import { google } from '@ai-sdk/google'
|
||||
|
||||
import { definePlugin } from '../../'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
|
||||
const toolNameMap = {
|
||||
googleSearch: 'google_search',
|
||||
urlContext: 'url_context',
|
||||
codeExecution: 'code_execution'
|
||||
} as const
|
||||
|
||||
type ToolConfigKey = keyof typeof toolNameMap
|
||||
type ToolConfig = { googleSearch?: boolean; urlContext?: boolean; codeExecution?: boolean }
|
||||
|
||||
export const googleToolsPlugin = (config?: ToolConfig) =>
|
||||
definePlugin({
|
||||
name: 'googleToolsPlugin',
|
||||
transformParams: <T>(params: T, context: AiRequestContext): T => {
|
||||
const { providerId } = context
|
||||
if (providerId === 'google' && config) {
|
||||
if (typeof params === 'object' && params !== null) {
|
||||
const typedParams = params as T & { tools?: Record<string, unknown> }
|
||||
|
||||
if (!typedParams.tools) {
|
||||
typedParams.tools = {}
|
||||
}
|
||||
|
||||
// 使用类型安全的方式遍历配置
|
||||
;(Object.keys(config) as ToolConfigKey[]).forEach((key) => {
|
||||
if (config[key] && key in toolNameMap && key in google.tools) {
|
||||
const toolName = toolNameMap[key]
|
||||
typedParams.tools![toolName] = google.tools[key]({})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
})
|
||||
@ -4,6 +4,7 @@
|
||||
*/
|
||||
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
|
||||
|
||||
export { googleToolsPlugin } from './googleToolsPlugin'
|
||||
export { createLoggingPlugin } from './logging'
|
||||
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
|
||||
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
|
||||
|
||||
@ -4,12 +4,12 @@
|
||||
*/
|
||||
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import {
|
||||
experimental_generateImage as generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
experimental_generateImage as _generateImage,
|
||||
generateObject as _generateObject,
|
||||
generateText as _generateText,
|
||||
LanguageModel,
|
||||
streamObject,
|
||||
streamText
|
||||
streamObject as _streamObject,
|
||||
streamText as _streamText
|
||||
} from 'ai'
|
||||
|
||||
import { globalModelResolver } from '../models'
|
||||
@ -18,7 +18,14 @@ import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||
import { PluginEngine } from './pluginEngine'
|
||||
import { type RuntimeConfig } from './types'
|
||||
import type {
|
||||
generateImageParams,
|
||||
generateObjectParams,
|
||||
generateTextParams,
|
||||
RuntimeConfig,
|
||||
streamObjectParams,
|
||||
streamTextParams
|
||||
} from './types'
|
||||
|
||||
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
public pluginEngine: PluginEngine<T>
|
||||
@ -75,12 +82,12 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
* 流式文本生成
|
||||
*/
|
||||
async streamText(
|
||||
params: Parameters<typeof streamText>[0],
|
||||
params: streamTextParams,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>> {
|
||||
const { model, ...restParams } = params
|
||||
): Promise<ReturnType<typeof _streamText>> {
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -94,19 +101,16 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
|
||||
return this.pluginEngine.executeStreamWithPlugins(
|
||||
'streamText',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams, streamTransforms) => {
|
||||
params,
|
||||
(resolvedModel, transformedParams, streamTransforms) => {
|
||||
const experimental_transform =
|
||||
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
||||
|
||||
const finalParams = {
|
||||
model: resolvedModel,
|
||||
return _streamText({
|
||||
...transformedParams,
|
||||
model: resolvedModel,
|
||||
experimental_transform
|
||||
} as Parameters<typeof streamText>[0]
|
||||
|
||||
return await streamText(finalParams)
|
||||
})
|
||||
}
|
||||
)
|
||||
}
|
||||
@ -117,12 +121,12 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
* 生成文本
|
||||
*/
|
||||
async generateText(
|
||||
params: Parameters<typeof generateText>[0],
|
||||
params: generateTextParams,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>> {
|
||||
const { model, ...restParams } = params
|
||||
): Promise<ReturnType<typeof _generateText>> {
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -134,12 +138,10 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
return this.pluginEngine.executeWithPlugins<Parameters<typeof _generateText>[0], ReturnType<typeof _generateText>>(
|
||||
'generateText',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
generateText({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateText>[0])
|
||||
params,
|
||||
(resolvedModel, transformedParams) => _generateText({ ...transformedParams, model: resolvedModel })
|
||||
)
|
||||
}
|
||||
|
||||
@ -147,12 +149,12 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
* 生成结构化对象
|
||||
*/
|
||||
async generateObject(
|
||||
params: Parameters<typeof generateObject>[0],
|
||||
params: generateObjectParams,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>> {
|
||||
const { model, ...restParams } = params
|
||||
): Promise<ReturnType<typeof _generateObject>> {
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -164,25 +166,23 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
return this.pluginEngine.executeWithPlugins<generateObjectParams, ReturnType<typeof _generateObject>>(
|
||||
'generateObject',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
generateObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateObject>[0])
|
||||
params,
|
||||
async (resolvedModel, transformedParams) => _generateObject({ ...transformedParams, model: resolvedModel })
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式生成结构化对象
|
||||
*/
|
||||
async streamObject(
|
||||
params: Parameters<typeof streamObject>[0],
|
||||
streamObject(
|
||||
params: streamObjectParams,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>> {
|
||||
const { model, ...restParams } = params
|
||||
): Promise<ReturnType<typeof _streamObject>> {
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -194,23 +194,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'streamObject',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) =>
|
||||
streamObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof streamObject>[0])
|
||||
return this.pluginEngine.executeStreamWithPlugins('streamObject', params, (resolvedModel, transformedParams) =>
|
||||
_streamObject({ ...transformedParams, model: resolvedModel })
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成图像
|
||||
*/
|
||||
async generateImage(
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'> & { model: string | ImageModelV2 }
|
||||
): Promise<ReturnType<typeof generateImage>> {
|
||||
generateImage(params: generateImageParams): Promise<ReturnType<typeof _generateImage>> {
|
||||
try {
|
||||
const { model, ...restParams } = params
|
||||
const { model } = params
|
||||
|
||||
// 根据 model 类型决定插件配置
|
||||
if (typeof model === 'string') {
|
||||
@ -219,13 +213,8 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||
}
|
||||
|
||||
return await this.pluginEngine.executeImageWithPlugins(
|
||||
'generateImage',
|
||||
model,
|
||||
restParams,
|
||||
async (resolvedModel, transformedParams) => {
|
||||
return await generateImage({ model: resolvedModel, ...transformedParams })
|
||||
}
|
||||
return this.pluginEngine.executeImageWithPlugins('generateImage', params, (resolvedModel, transformedParams) =>
|
||||
_generateImage({ ...transformedParams, model: resolvedModel })
|
||||
)
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
/* eslint-disable @eslint-react/naming-convention/context-name */
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { LanguageModel } from 'ai'
|
||||
import { experimental_generateImage, generateObject, generateText, LanguageModel, streamObject, streamText } from 'ai'
|
||||
|
||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
@ -62,17 +62,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 执行带插件的操作(非流式)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeWithPlugins<TParams, TResult>(
|
||||
async executeWithPlugins<
|
||||
TParams extends Parameters<typeof generateText | typeof generateObject>[0],
|
||||
TResult extends ReturnType<typeof generateText | typeof generateObject>
|
||||
>(
|
||||
methodName: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => TResult,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
const { model } = params
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
@ -89,7 +91,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeWithPlugins(methodName, model, newParams, executor, context)
|
||||
const result = await this.executeWithPlugins(methodName, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
@ -138,17 +140,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 执行带插件的图像生成操作
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeImageWithPlugins<TParams, TResult>(
|
||||
async executeImageWithPlugins<
|
||||
TParams extends Omit<Parameters<typeof experimental_generateImage>[0], 'model'> & { model: string | ImageModelV2 },
|
||||
TResult extends ReturnType<typeof experimental_generateImage>
|
||||
>(
|
||||
methodName: string,
|
||||
model: ImageModelV2 | string,
|
||||
params: TParams,
|
||||
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
|
||||
executor: (model: ImageModelV2, transformedParams: TParams) => TResult,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: ImageModelV2 | undefined
|
||||
let modelId: string
|
||||
|
||||
const { model } = params
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
@ -165,7 +169,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context)
|
||||
const result = await this.executeImageWithPlugins(methodName, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
@ -214,17 +218,19 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
* 执行流式调用的通用逻辑(支持流转换器)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeStreamWithPlugins<TParams, TResult>(
|
||||
async executeStreamWithPlugins<
|
||||
TParams extends Parameters<typeof streamText | typeof streamObject>[0],
|
||||
TResult extends ReturnType<typeof streamText | typeof streamObject>
|
||||
>(
|
||||
methodName: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => TResult,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
const { model } = params
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
@ -241,7 +247,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context)
|
||||
const result = await this.executeStreamWithPlugins(methodName, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
/**
|
||||
* Runtime 层类型定义
|
||||
*/
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { experimental_generateImage, generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
@ -13,3 +16,11 @@ export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
||||
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
|
||||
export type generateImageParams = Omit<Parameters<typeof experimental_generateImage>[0], 'model'> & {
|
||||
model: string | ImageModelV2
|
||||
}
|
||||
export type generateObjectParams = Parameters<typeof generateObject>[0]
|
||||
export type generateTextParams = Parameters<typeof generateText>[0]
|
||||
export type streamObjectParams = Parameters<typeof streamObject>[0]
|
||||
export type streamTextParams = Parameters<typeof streamText>[0]
|
||||
|
||||
@ -281,7 +281,7 @@ export class ToolCallChunkHandler {
|
||||
// 工具特定的后处理
|
||||
switch (toolResponse.tool.name) {
|
||||
case 'builtin_knowledge_search': {
|
||||
processKnowledgeReferences(toolResponse.response?.knowledgeReferences, this.onChunk)
|
||||
processKnowledgeReferences(toolResponse.response, this.onChunk)
|
||||
break
|
||||
}
|
||||
// 未来可以在这里添加其他工具的后处理逻辑
|
||||
|
||||
@ -22,6 +22,7 @@ export interface AiSdkMiddlewareConfig {
|
||||
isImageGenerationEndpoint: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
enableUrlContext: boolean
|
||||
mcpTools?: MCPTool[]
|
||||
uiMessages?: Message[]
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { AiPlugin } from '@cherrystudio/ai-core'
|
||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { createPromptToolUsePlugin, googleToolsPlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { loggerService } from '@logger'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { Assistant } from '@renderer/types'
|
||||
@ -70,9 +70,10 @@ export function buildPlugins(
|
||||
)
|
||||
}
|
||||
|
||||
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
// plugins.push(createNativeToolUsePlugin())
|
||||
// }
|
||||
if (middlewareConfig.enableUrlContext) {
|
||||
plugins.push(googleToolsPlugin({ urlContext: true }))
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
'Final plugin list:',
|
||||
plugins.map((p) => p.name)
|
||||
|
||||
@ -267,7 +267,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
|
||||
const shouldWebSearch = !!assistant.webSearchProviderId
|
||||
const shouldKnowledgeSearch = hasKnowledgeBase && knowledgeRecognition === 'on'
|
||||
const shouldMemorySearch = globalMemoryEnabled && assistant.enableMemory
|
||||
|
||||
@ -58,91 +58,6 @@ class AdapterTracer {
|
||||
})
|
||||
}
|
||||
|
||||
// startSpan(name: string, options?: any, context?: any): Span {
|
||||
// // 如果提供了父 SpanContext 且未显式传入 context,则使用父上下文
|
||||
// const contextToUse = context ?? this.cachedParentContext ?? otelContext.active()
|
||||
|
||||
// const span = this.originalTracer.startSpan(name, options, contextToUse)
|
||||
|
||||
// // 标记父子关系,便于在转换阶段兜底重建层级
|
||||
// try {
|
||||
// if (this.parentSpanContext) {
|
||||
// span.setAttribute('trace.parentSpanId', this.parentSpanContext.spanId)
|
||||
// span.setAttribute('trace.parentTraceId', this.parentSpanContext.traceId)
|
||||
// }
|
||||
// if (this.topicId) {
|
||||
// span.setAttribute('trace.topicId', this.topicId)
|
||||
// }
|
||||
// } catch (e) {
|
||||
// logger.debug('Failed to set trace parent attributes', e as Error)
|
||||
// }
|
||||
|
||||
// logger.info('AI SDK span created via AdapterTracer', {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// parentTraceId: this.parentSpanContext?.traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName,
|
||||
// traceIdMatches: this.parentSpanContext ? span.spanContext().traceId === this.parentSpanContext.traceId : undefined
|
||||
// })
|
||||
|
||||
// // 包装 span 的 end 方法,在结束时进行数据转换
|
||||
// const originalEnd = span.end.bind(span)
|
||||
// span.end = (endTime?: any) => {
|
||||
// logger.info('AI SDK span.end() called - about to convert span', {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName
|
||||
// })
|
||||
|
||||
// // 调用原始 end 方法
|
||||
// originalEnd(endTime)
|
||||
|
||||
// // 转换并保存 span 数据
|
||||
// try {
|
||||
// logger.info('Converting AI SDK span to SpanEntity', {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName
|
||||
// })
|
||||
// logger.info('spanspanspanspanspanspan', span)
|
||||
// const spanEntity = AiSdkSpanAdapter.convertToSpanEntity({
|
||||
// span,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName
|
||||
// })
|
||||
|
||||
// // 保存转换后的数据
|
||||
// window.api.trace.saveEntity(spanEntity)
|
||||
|
||||
// logger.info('AI SDK span converted and saved successfully', {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName,
|
||||
// hasUsage: !!spanEntity.usage,
|
||||
// usage: spanEntity.usage
|
||||
// })
|
||||
// } catch (error) {
|
||||
// logger.error('Failed to convert AI SDK span', error as Error, {
|
||||
// spanName: name,
|
||||
// spanId: span.spanContext().spanId,
|
||||
// traceId: span.spanContext().traceId,
|
||||
// topicId: this.topicId,
|
||||
// modelName: this.modelName
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// return span
|
||||
// }
|
||||
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, fn: F): ReturnType<F>
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, fn: F): ReturnType<F>
|
||||
startActiveSpan<F extends (span: Span) => any>(name: string, options: any, context: any, fn: F): ReturnType<F>
|
||||
|
||||
@ -3,6 +3,8 @@
|
||||
* 构建AI SDK的流式和非流式参数
|
||||
*/
|
||||
|
||||
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
|
||||
import { vertex } from '@ai-sdk/google-vertex/edge'
|
||||
import { loggerService } from '@logger'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
@ -19,6 +21,7 @@ import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import type { ModelMessage } from 'ai'
|
||||
import { stepCountIs } from 'ai'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { setupToolsConfig } from '../utils/mcp'
|
||||
import { buildProviderOptions } from '../utils/options'
|
||||
import { getAnthropicThinkingBudget } from '../utils/reasoning'
|
||||
@ -56,6 +59,7 @@ export async function buildStreamTextParams(
|
||||
const { mcpTools } = options
|
||||
|
||||
const model = assistant.model || getDefaultModel()
|
||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||
|
||||
let { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
@ -80,7 +84,7 @@ export async function buildStreamTextParams(
|
||||
|
||||
const enableGenerateImage = !!(isGenerateImageModel(model) && assistant.enableGenerateImage)
|
||||
|
||||
const tools = setupToolsConfig(mcpTools)
|
||||
let tools = setupToolsConfig(mcpTools)
|
||||
|
||||
// if (webSearchProviderId) {
|
||||
// tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
||||
@ -103,6 +107,26 @@ export async function buildStreamTextParams(
|
||||
maxTokens -= getAnthropicThinkingBudget(assistant, model)
|
||||
}
|
||||
|
||||
// google-vertex | google-vertex-anthropic
|
||||
if (enableWebSearch) {
|
||||
if (!tools) {
|
||||
tools = {}
|
||||
}
|
||||
if (aiSdkProviderId === 'google-vertex') {
|
||||
tools.google_search = vertex.tools.googleSearch({})
|
||||
} else if (aiSdkProviderId === 'google-vertex-anthropic') {
|
||||
tools.web_search = vertexAnthropic.tools.webSearch_20250305({})
|
||||
}
|
||||
}
|
||||
|
||||
// google-vertex
|
||||
if (enableUrlContext && aiSdkProviderId === 'google-vertex') {
|
||||
if (!tools) {
|
||||
tools = {}
|
||||
}
|
||||
tools.url_context = vertex.tools.urlContext({})
|
||||
}
|
||||
|
||||
// 构建基础参数
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
@ -112,10 +136,12 @@ export async function buildStreamTextParams(
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers: options.requestOptions?.headers,
|
||||
providerOptions,
|
||||
tools,
|
||||
stopWhen: stepCountIs(10),
|
||||
maxRetries: 0
|
||||
}
|
||||
if (tools) {
|
||||
params.tools = tools
|
||||
}
|
||||
if (assistant.prompt) {
|
||||
params.system = assistant.prompt
|
||||
}
|
||||
|
||||
@ -23,8 +23,6 @@ export const knowledgeSearchTool = (
|
||||
Pre-extracted search queries: "${extractedKeywords.question.join(', ')}"
|
||||
Rewritten query: "${extractedKeywords.rewrite}"
|
||||
|
||||
This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.
|
||||
|
||||
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
|
||||
|
||||
inputSchema: z.object({
|
||||
@ -35,99 +33,102 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
}),
|
||||
|
||||
execute: async ({ additionalContext }) => {
|
||||
try {
|
||||
// 获取助手的知识库配置
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
// try {
|
||||
// 获取助手的知识库配置
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
|
||||
// 检查是否有知识库
|
||||
if (!hasKnowledgeBase) {
|
||||
return {
|
||||
summary: 'No knowledge base configured for this assistant.',
|
||||
knowledgeReferences: [],
|
||||
instructions: ''
|
||||
// 检查是否有知识库
|
||||
if (!hasKnowledgeBase) {
|
||||
return []
|
||||
}
|
||||
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
let finalRewrite = extractedKeywords.rewrite
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
finalRewrite = cleanContext
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return []
|
||||
}
|
||||
|
||||
// 构建搜索条件
|
||||
let searchCriteria: { question: string[]; rewrite: string }
|
||||
|
||||
if (knowledgeRecognition === 'off') {
|
||||
// 直接模式:使用用户消息内容
|
||||
const directContent = userMessage || finalQueries[0] || 'search'
|
||||
searchCriteria = {
|
||||
question: [directContent],
|
||||
rewrite: directContent
|
||||
}
|
||||
} else {
|
||||
// 自动模式:使用意图识别的结果
|
||||
searchCriteria = {
|
||||
question: finalQueries,
|
||||
rewrite: finalRewrite
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 ExtractResults 对象
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: undefined,
|
||||
knowledge: searchCriteria
|
||||
}
|
||||
|
||||
// 执行知识库搜索
|
||||
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId)
|
||||
const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({
|
||||
id: ref.id,
|
||||
content: ref.content,
|
||||
sourceUrl: ref.sourceUrl,
|
||||
type: ref.type,
|
||||
file: ref.file,
|
||||
metadata: ref.metadata
|
||||
}))
|
||||
|
||||
// TODO 在工具函数中添加搜索缓存机制
|
||||
// const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}`
|
||||
|
||||
// 返回结果
|
||||
return knowledgeReferencesData
|
||||
},
|
||||
toModelOutput: (results) => {
|
||||
let summary = 'No search needed based on the query analysis.'
|
||||
if (results.length > 0) {
|
||||
summary = `Found ${results.length} relevant sources. Use [number] format to cite specific information.`
|
||||
}
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(results, null, 2)}\n\`\`\``
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the knowledge references, please answer the user's question with proper citations."
|
||||
).replace('{references}', referenceContent)
|
||||
|
||||
return {
|
||||
type: 'content',
|
||||
value: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: summary
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: fullInstructions
|
||||
}
|
||||
}
|
||||
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
let finalRewrite = extractedKeywords.rewrite
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
finalRewrite = cleanContext
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return {
|
||||
summary: 'No search needed based on the query analysis.',
|
||||
knowledgeReferences: [],
|
||||
instructions: ''
|
||||
}
|
||||
}
|
||||
|
||||
// 构建搜索条件
|
||||
let searchCriteria: { question: string[]; rewrite: string }
|
||||
|
||||
if (knowledgeRecognition === 'off') {
|
||||
// 直接模式:使用用户消息内容
|
||||
const directContent = userMessage || finalQueries[0] || 'search'
|
||||
searchCriteria = {
|
||||
question: [directContent],
|
||||
rewrite: directContent
|
||||
}
|
||||
} else {
|
||||
// 自动模式:使用意图识别的结果
|
||||
searchCriteria = {
|
||||
question: finalQueries,
|
||||
rewrite: finalRewrite
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 ExtractResults 对象
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: undefined,
|
||||
knowledge: searchCriteria
|
||||
}
|
||||
|
||||
// 执行知识库搜索
|
||||
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds, topicId)
|
||||
const knowledgeReferencesData = knowledgeReferences.map((ref: KnowledgeReference) => ({
|
||||
id: ref.id,
|
||||
content: ref.content,
|
||||
sourceUrl: ref.sourceUrl,
|
||||
type: ref.type,
|
||||
file: ref.file,
|
||||
metadata: ref.metadata
|
||||
}))
|
||||
|
||||
// const referenceContent = `\`\`\`json\n${JSON.stringify(knowledgeReferencesData, null, 2)}\n\`\`\``
|
||||
// TODO 在工具函数中添加搜索缓存机制
|
||||
// const searchCacheKey = `${topicId}-${JSON.stringify(finalQueries)}`
|
||||
// 可以在插件层面管理已搜索的查询,避免重复搜索
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the knowledge references, please answer the user's question with proper citations."
|
||||
).replace('{references}', 'knowledgeReferences:')
|
||||
|
||||
// 返回结果
|
||||
return {
|
||||
summary: `Found ${knowledgeReferencesData.length} relevant sources. Use [number] format to cite specific information.`,
|
||||
knowledgeReferences: knowledgeReferencesData,
|
||||
instructions: fullInstructions
|
||||
}
|
||||
} catch (error) {
|
||||
// 返回空对象而不是抛出错误,避免中断对话流程
|
||||
return {
|
||||
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
knowledgeReferences: [],
|
||||
instructions: ''
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import store from '@renderer/store'
|
||||
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
@ -19,123 +18,29 @@ export const memorySearchTool = () => {
|
||||
limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
|
||||
}),
|
||||
execute: async ({ query, limit = 5 }) => {
|
||||
try {
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
if (!globalMemoryEnabled) {
|
||||
return []
|
||||
}
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||
return []
|
||||
}
|
||||
|
||||
const currentUserId = selectCurrentUserId(store.getState())
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, 'default', currentUserId)
|
||||
|
||||
const memoryProcessor = new MemoryProcessor()
|
||||
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
|
||||
|
||||
if (relevantMemories?.length > 0) {
|
||||
return relevantMemories
|
||||
}
|
||||
return []
|
||||
} catch (error) {
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
if (!globalMemoryEnabled) {
|
||||
return []
|
||||
}
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||
return []
|
||||
}
|
||||
|
||||
const currentUserId = selectCurrentUserId(store.getState())
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, 'default', currentUserId)
|
||||
|
||||
const memoryProcessor = new MemoryProcessor()
|
||||
const relevantMemories = await memoryProcessor.searchRelevantMemories(query, processorConfig, limit)
|
||||
|
||||
if (relevantMemories?.length > 0) {
|
||||
return relevantMemories
|
||||
}
|
||||
return []
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 方案4: 为第二个工具也使用类型断言
|
||||
type MessageRole = 'user' | 'assistant' | 'system'
|
||||
type MessageType = {
|
||||
content: string
|
||||
role: MessageRole
|
||||
}
|
||||
type MemorySearchWithExtractionInput = {
|
||||
userMessage: MessageType
|
||||
lastAnswer?: MessageType
|
||||
}
|
||||
|
||||
/**
|
||||
* 🧠 智能记忆搜索工具(带上下文提取)
|
||||
* 从用户消息和对话历史中自动提取关键词进行记忆搜索
|
||||
*/
|
||||
export const memorySearchToolWithExtraction = (assistant: Assistant) => {
|
||||
return tool({
|
||||
name: 'memory_search_with_extraction',
|
||||
description: 'Search memories with automatic keyword extraction from conversation context',
|
||||
inputSchema: z.object({
|
||||
userMessage: z.object({
|
||||
content: z.string().describe('The main content of the user message'),
|
||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
}),
|
||||
lastAnswer: z
|
||||
.object({
|
||||
content: z.string().describe('The main content of the last assistant response'),
|
||||
role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
})
|
||||
.optional()
|
||||
}) satisfies z.ZodSchema<MemorySearchWithExtractionInput>,
|
||||
execute: async ({ userMessage }) => {
|
||||
try {
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||
return {
|
||||
extractedKeywords: 'Memory search disabled',
|
||||
searchResults: []
|
||||
}
|
||||
}
|
||||
|
||||
const memoryConfig = selectMemoryConfig(store.getState())
|
||||
if (!memoryConfig.llmApiClient || !memoryConfig.embedderApiClient) {
|
||||
return {
|
||||
extractedKeywords: 'Memory models not configured',
|
||||
searchResults: []
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 使用用户消息内容作为搜索关键词
|
||||
const content = userMessage.content
|
||||
|
||||
if (!content) {
|
||||
return {
|
||||
extractedKeywords: 'No content to search',
|
||||
searchResults: []
|
||||
}
|
||||
}
|
||||
|
||||
const currentUserId = selectCurrentUserId(store.getState())
|
||||
const processorConfig = MemoryProcessor.getProcessorConfig(memoryConfig, assistant.id, currentUserId)
|
||||
|
||||
const memoryProcessor = new MemoryProcessor()
|
||||
const relevantMemories = await memoryProcessor.searchRelevantMemories(
|
||||
content,
|
||||
processorConfig,
|
||||
5 // Limit to top 5 most relevant memories
|
||||
)
|
||||
|
||||
if (relevantMemories?.length > 0) {
|
||||
return {
|
||||
extractedKeywords: content,
|
||||
searchResults: relevantMemories
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
extractedKeywords: content,
|
||||
searchResults: []
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
extractedKeywords: 'Search failed',
|
||||
searchResults: []
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
export type MemorySearchToolInput = InferToolInput<ReturnType<typeof memorySearchTool>>
|
||||
export type MemorySearchToolOutput = InferToolOutput<ReturnType<typeof memorySearchTool>>
|
||||
export type MemorySearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof memorySearchToolWithExtraction>>
|
||||
|
||||
@ -30,8 +30,6 @@ Relevant links: ${extractedKeywords.links.join(', ')}`
|
||||
: ''
|
||||
}
|
||||
|
||||
This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.
|
||||
|
||||
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
|
||||
|
||||
inputSchema: z.object({
|
||||
@ -58,40 +56,27 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
}
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return {
|
||||
summary: 'No search needed based on the query analysis.',
|
||||
searchResults,
|
||||
sources: '',
|
||||
instructions: ''
|
||||
}
|
||||
return searchResults
|
||||
}
|
||||
|
||||
try {
|
||||
// 构建 ExtractResults 结构用于 processWebsearch
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: {
|
||||
question: finalQueries,
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
} catch (error) {
|
||||
return {
|
||||
summary: `Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
sources: [],
|
||||
instructions: ''
|
||||
// 构建 ExtractResults 结构用于 processWebsearch
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: {
|
||||
question: finalQueries,
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
if (searchResults.results.length === 0) {
|
||||
return {
|
||||
summary: 'No search results found for the given query.',
|
||||
sources: [],
|
||||
instructions: ''
|
||||
}
|
||||
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
|
||||
return searchResults
|
||||
},
|
||||
toModelOutput: (results) => {
|
||||
let summary = 'No search needed based on the query analysis.'
|
||||
if (results.query && results.results.length > 0) {
|
||||
summary = `Found ${results.results.length} relevant sources. Use [number] format to cite specific information.`
|
||||
}
|
||||
|
||||
const results = searchResults.results
|
||||
const citationData = results.map((result, index) => ({
|
||||
const citationData = results.results.map((result, index) => ({
|
||||
number: index + 1,
|
||||
title: result.title,
|
||||
content: result.content,
|
||||
@ -99,18 +84,27 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
}))
|
||||
|
||||
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
|
||||
// const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||
|
||||
// 构建完整的引用指导文本
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
"Based on the search results, please answer the user's question with proper citations."
|
||||
).replace('{references}', 'searchResults:')
|
||||
|
||||
).replace('{references}', referenceContent)
|
||||
return {
|
||||
summary: `Found ${citationData.length} relevant sources. Use [number] format to cite specific information.`,
|
||||
searchResults,
|
||||
instructions: fullInstructions
|
||||
type: 'content',
|
||||
value: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'This tool searches for relevant information and formats results for easy citation. The returned sources should be cited using [1], [2], etc. format in your response.'
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: summary
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: fullInstructions
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@ -120,6 +120,9 @@ export function buildProviderOptions(
|
||||
case 'google-vertex':
|
||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
case 'google-vertex-anthropic':
|
||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
default:
|
||||
// 对于其他 provider,使用通用的构建逻辑
|
||||
providerSpecificOptions = {
|
||||
@ -137,10 +140,16 @@ export function buildProviderOptions(
|
||||
...providerSpecificOptions,
|
||||
...getCustomParameters(assistant)
|
||||
}
|
||||
// vertex需要映射到google或anthropic
|
||||
const rawProviderKey =
|
||||
{
|
||||
'google-vertex': 'google',
|
||||
'google-vertex-anthropic': 'anthropic'
|
||||
}[rawProviderId] || rawProviderId
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
||||
return {
|
||||
[rawProviderId]: providerSpecificOptions
|
||||
[rawProviderKey]: providerSpecificOptions
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ export function MessageKnowledgeSearchToolTitle({ toolResponse }: { toolResponse
|
||||
) : (
|
||||
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
||||
<FileSearch size={16} style={{ color: 'unset' }} />
|
||||
{i18n.t('message.websearch.fetch_complete', { count: toolOutput.knowledgeReferences.length ?? 0 })}
|
||||
{i18n.t('message.websearch.fetch_complete', { count: toolOutput.length ?? 0 })}
|
||||
</MessageWebSearchToolTitleTextWrapper>
|
||||
)
|
||||
}
|
||||
@ -33,7 +33,7 @@ export function MessageKnowledgeSearchToolBody({ toolResponse }: { toolResponse:
|
||||
|
||||
return toolResponse.status === 'done' ? (
|
||||
<MessageWebSearchToolBodyUlWrapper>
|
||||
{toolOutput.knowledgeReferences.map((result) => (
|
||||
{toolOutput.map((result) => (
|
||||
<li key={result.id}>
|
||||
<span>{result.id}</span>
|
||||
<span>{result.content}</span>
|
||||
|
||||
@ -26,7 +26,7 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
|
||||
<MessageWebSearchToolTitleTextWrapper type="secondary">
|
||||
<Search size={16} style={{ color: 'unset' }} />
|
||||
{t('message.websearch.fetch_complete', {
|
||||
count: toolOutput?.searchResults?.results?.length ?? 0
|
||||
count: toolOutput?.results?.length ?? 0
|
||||
})}
|
||||
</MessageWebSearchToolTitleTextWrapper>
|
||||
)
|
||||
@ -36,7 +36,7 @@ export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: MCPT
|
||||
// const toolOutput = toolResponse.response as WebSearchToolOutput
|
||||
|
||||
// return toolResponse.status === 'done'
|
||||
// ? toolOutput?.searchResults?.map((result, index) => (
|
||||
// ? toolOutput?.results?.map((result, index) => (
|
||||
// <MessageWebSearchToolBodyUlWrapper key={result?.query ?? '' + index}>
|
||||
// {result.results.map((item, index) => (
|
||||
// <li key={item.url + index}>
|
||||
|
||||
@ -134,6 +134,7 @@ export async function fetchChatCompletion({
|
||||
isImageGenerationEndpoint: isDedicatedImageGenerationModel(assistant.model || getDefaultModel()),
|
||||
enableWebSearch: capabilities.enableWebSearch,
|
||||
enableGenerateImage: capabilities.enableGenerateImage,
|
||||
enableUrlContext: capabilities.enableUrlContext,
|
||||
mcpTools,
|
||||
uiMessages
|
||||
}
|
||||
@ -222,6 +223,7 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
|
||||
isImageGenerationEndpoint: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false,
|
||||
enableUrlContext: false,
|
||||
mcpTools: []
|
||||
}
|
||||
try {
|
||||
@ -308,7 +310,8 @@ export async function fetchGenerate({
|
||||
isSupportedToolUse: false,
|
||||
isImageGenerationEndpoint: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
enableGenerateImage: false,
|
||||
enableUrlContext: false
|
||||
}
|
||||
|
||||
try {
|
||||
@ -420,6 +423,7 @@ export async function checkApi(provider: Provider, model: Model, timeout = 15000
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false,
|
||||
isPromptToolUse: false,
|
||||
enableUrlContext: false,
|
||||
assistant,
|
||||
callType: 'check',
|
||||
onChunk: (chunk: Chunk) => {
|
||||
|
||||
@ -83,13 +83,12 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
}
|
||||
}
|
||||
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
|
||||
|
||||
// Handle citation block creation for web search results
|
||||
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response?.searchResults) {
|
||||
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response) {
|
||||
const citationBlock = createCitationBlock(
|
||||
assistantMsgId,
|
||||
{
|
||||
response: { results: toolResponse.response.searchResults, source: WebSearchSource.WEBSEARCH }
|
||||
response: { results: toolResponse.response, source: WebSearchSource.WEBSEARCH }
|
||||
},
|
||||
{
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
@ -98,10 +97,10 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
citationBlockId = citationBlock.id
|
||||
blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
}
|
||||
if (toolResponse.tool.name === 'builtin_knowledge_search' && toolResponse.response?.knowledgeReferences) {
|
||||
if (toolResponse.tool.name === 'builtin_knowledge_search' && toolResponse.response) {
|
||||
const citationBlock = createCitationBlock(
|
||||
assistantMsgId,
|
||||
{ knowledge: toolResponse.response.knowledgeReferences },
|
||||
{ knowledge: toolResponse.response },
|
||||
{
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
|
||||
79
yarn.lock
79
yarn.lock
@ -90,6 +90,18 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/anthropic@npm:2.0.15":
|
||||
version: 2.0.15
|
||||
resolution: "@ai-sdk/anthropic@npm:2.0.15"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.8"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/9597b32be8b83dab67b23f162ca66cde385213fb1665f54091d59430789becf73e2b4fcd2be66ceb13020409f59cd8f9da7dae23adf183bc9eb7ce94f55bde96
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/anthropic@npm:2.0.4":
|
||||
version: 2.0.4
|
||||
resolution: "@ai-sdk/anthropic@npm:2.0.4"
|
||||
@ -140,46 +152,34 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/gateway@npm:1.0.15":
|
||||
version: 1.0.15
|
||||
resolution: "@ai-sdk/gateway@npm:1.0.15"
|
||||
"@ai-sdk/gateway@npm:1.0.20":
|
||||
version: 1.0.20
|
||||
resolution: "@ai-sdk/gateway@npm:1.0.20"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.7"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.8"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/cdd09f119d6618f00c363a27f51dc466a8a64f57f01bcdd127030a804825bd143b0fef2dbdb7802530865d474f4b9d55855670fecd7f2e6c615a5d9ac9fd6e3b
|
||||
checksum: 10c0/c25e98aab2513f783b2b552245b027e5a73b209d974e25bbfae0e69b67fd3468bba0bf57085ca3d7259b4dc8881e7f40fca769f698f0b1eb028a849f587ad09c
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/google-vertex@npm:^3.0.0":
|
||||
version: 3.0.9
|
||||
resolution: "@ai-sdk/google-vertex@npm:3.0.9"
|
||||
"@ai-sdk/google-vertex@npm:^3.0.25":
|
||||
version: 3.0.25
|
||||
resolution: "@ai-sdk/google-vertex@npm:3.0.25"
|
||||
dependencies:
|
||||
"@ai-sdk/anthropic": "npm:2.0.4"
|
||||
"@ai-sdk/google": "npm:2.0.6"
|
||||
"@ai-sdk/anthropic": "npm:2.0.15"
|
||||
"@ai-sdk/google": "npm:2.0.13"
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.3"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.8"
|
||||
google-auth-library: "npm:^9.15.0"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/c6584b877f9e20a10dd7d92fc4cb1b4a9838510aa89734cf1ff2faa74ba820b976d3359d4eadcb6035c8911973300efb157931fa0d1105abc8db36f94544cc88
|
||||
checksum: 10c0/ed67a439fc4a446aa7353d258c61497198aecdf0de55500d2abbea86109bbf1ff4570fffdfcf58508db1c887a2095a71322777634f76326a45e259d28ef0b801
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/google@npm:2.0.6":
|
||||
version: 2.0.6
|
||||
resolution: "@ai-sdk/google@npm:2.0.6"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.3"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/ad54dd4168df62851646bec3ac2e5cf9e39f3def3e9017579aef5c8e8ecdf57c150c67a80cad4d092c3df69cd8539bc1792adb6c311ed095f8261673b7812e98
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/google@npm:^2.0.13":
|
||||
"@ai-sdk/google@npm:2.0.13, @ai-sdk/google@npm:^2.0.13":
|
||||
version: 2.0.13
|
||||
resolution: "@ai-sdk/google@npm:2.0.13"
|
||||
dependencies:
|
||||
@ -267,19 +267,6 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider-utils@npm:3.0.7":
|
||||
version: 3.0.7
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.7"
|
||||
dependencies:
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@standard-schema/spec": "npm:^1.0.0"
|
||||
eventsource-parser: "npm:^3.0.5"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/7e709289f9e514a6ba56a9b19764eb124ea1bd36d4b3b3e455a1c05353674c152839a4d3cd061af7a4cc36106bd15859a2346e54d4ed0a861feec3b2c4c21513
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@ai-sdk/provider-utils@npm:3.0.8":
|
||||
version: 3.0.8
|
||||
resolution: "@ai-sdk/provider-utils@npm:3.0.8"
|
||||
@ -12829,7 +12816,7 @@ __metadata:
|
||||
"@agentic/searxng": "npm:^7.3.3"
|
||||
"@agentic/tavily": "npm:^7.3.3"
|
||||
"@ai-sdk/amazon-bedrock": "npm:^3.0.0"
|
||||
"@ai-sdk/google-vertex": "npm:^3.0.0"
|
||||
"@ai-sdk/google-vertex": "npm:^3.0.25"
|
||||
"@ai-sdk/mistral": "npm:^2.0.0"
|
||||
"@ant-design/v5-patch-for-react-19": "npm:^1.0.3"
|
||||
"@anthropic-ai/sdk": "npm:^0.41.0"
|
||||
@ -12944,7 +12931,7 @@ __metadata:
|
||||
"@viz-js/lang-dot": "npm:^1.0.5"
|
||||
"@viz-js/viz": "npm:^3.14.0"
|
||||
"@xyflow/react": "npm:^12.4.4"
|
||||
ai: "npm:^5.0.29"
|
||||
ai: "npm:^5.0.38"
|
||||
antd: "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch"
|
||||
archiver: "npm:^7.0.1"
|
||||
async-mutex: "npm:^0.5.0"
|
||||
@ -13198,17 +13185,17 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"ai@npm:^5.0.29":
|
||||
version: 5.0.29
|
||||
resolution: "ai@npm:5.0.29"
|
||||
"ai@npm:^5.0.38":
|
||||
version: 5.0.38
|
||||
resolution: "ai@npm:5.0.38"
|
||||
dependencies:
|
||||
"@ai-sdk/gateway": "npm:1.0.15"
|
||||
"@ai-sdk/gateway": "npm:1.0.20"
|
||||
"@ai-sdk/provider": "npm:2.0.0"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.7"
|
||||
"@ai-sdk/provider-utils": "npm:3.0.8"
|
||||
"@opentelemetry/api": "npm:1.9.0"
|
||||
peerDependencies:
|
||||
zod: ^3.25.76 || ^4
|
||||
checksum: 10c0/526cd2fd59b35b19d902665e3dc1ba5a09f2bb1377295d642fb8a33e13a890874e4dd4b49a787de7f31f4ec6b07257be8514efac08f993081daeb430cf2f60ba
|
||||
checksum: 10c0/9ea7a76ae5609574e9edb2f9541e2fe9cf0e7296547c5e9ae30ec000206c967b4c07fbb03b85f9027493f6877e15f6bfbe454faa793fca860826acf306982fc5
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user