From 4573e3f48feb7d5ab21d687fbff9537964583cdb Mon Sep 17 00:00:00 2001 From: lizhixuan Date: Mon, 7 Jul 2025 23:28:49 +0800 Subject: [PATCH] feat: add XAI provider options and enhance web search plugin - Introduced `createXaiOptions` function for XAI provider configuration. - Added `XaiProviderOptions` type and validation schema in `xai.ts`. - Updated `ProviderOptionsMap` to include XAI options. - Enhanced `webSearchPlugin` to support XAI-specific search parameters. - Refactored helper functions to integrate new XAI options into provider configurations. --- packages/aiCore/src/core/options/factory.ts | 7 + packages/aiCore/src/core/options/types.ts | 2 + packages/aiCore/src/core/options/xai.ts | 86 +++++++++ .../built-in/webSearchPlugin/helper.ts | 170 ++++++------------ .../plugins/built-in/webSearchPlugin/index.ts | 81 +++++---- src/renderer/src/aiCore/index_new.ts | 4 +- .../src/aiCore/transformParameters.ts | 12 +- src/renderer/src/aiCore/utils/websearch.ts | 66 +++---- 8 files changed, 237 insertions(+), 191 deletions(-) create mode 100644 packages/aiCore/src/core/options/xai.ts diff --git a/packages/aiCore/src/core/options/factory.ts b/packages/aiCore/src/core/options/factory.ts index 45eea32b98..4350e9241b 100644 --- a/packages/aiCore/src/core/options/factory.ts +++ b/packages/aiCore/src/core/options/factory.ts @@ -62,3 +62,10 @@ export function createGoogleOptions(options: ExtractProviderOptions<'google'>) { export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) { return createProviderOptions('openrouter', options) } + +/** + * 创建XAI供应商选项的便捷函数 + */ +export function createXaiOptions(options: ExtractProviderOptions<'xai'>) { + return createProviderOptions('xai', options) +} diff --git a/packages/aiCore/src/core/options/types.ts b/packages/aiCore/src/core/options/types.ts index 5e78323ce5..724dc30698 100644 --- a/packages/aiCore/src/core/options/types.ts +++ b/packages/aiCore/src/core/options/types.ts @@ -4,6 +4,7 @@ import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai' import { type SharedV2ProviderMetadata } from '@ai-sdk/provider' import { type OpenRouterProviderOptions } from './openrouter' +import { type XaiProviderOptions } from './xai' export type ProviderOptions = SharedV2ProviderMetadata[T] @@ -15,6 +16,7 @@ export type ProviderOptionsMap = { anthropic: AnthropicProviderOptions google: GoogleGenerativeAIProviderOptions openrouter: OpenRouterProviderOptions + xai: XaiProviderOptions } // 工具类型,用于从ProviderOptionsMap中提取特定供应商的选项类型 diff --git a/packages/aiCore/src/core/options/xai.ts b/packages/aiCore/src/core/options/xai.ts new file mode 100644 index 0000000000..13e3e7dd6b --- /dev/null +++ b/packages/aiCore/src/core/options/xai.ts @@ -0,0 +1,86 @@ +// copy from @ai-sdk/xai/xai-chat-options.ts +// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件 + +import { z } from 'zod' + +const webSourceSchema = z.object({ + type: z.literal('web'), + country: z.string().length(2).optional(), + excludedWebsites: z.array(z.string()).max(5).optional(), + allowedWebsites: z.array(z.string()).max(5).optional(), + safeSearch: z.boolean().optional() +}) + +const xSourceSchema = z.object({ + type: z.literal('x'), + xHandles: z.array(z.string()).optional() +}) + +const newsSourceSchema = z.object({ + type: z.literal('news'), + country: z.string().length(2).optional(), + excludedWebsites: z.array(z.string()).max(5).optional(), + safeSearch: z.boolean().optional() +}) + +const rssSourceSchema = z.object({ + type: z.literal('rss'), + links: z.array(z.string().url()).max(1) // currently only supports one RSS link +}) + +const searchSourceSchema = z.discriminatedUnion('type', [ + webSourceSchema, + xSourceSchema, + newsSourceSchema, + rssSourceSchema +]) + +export const xaiProviderOptions = z.object({ + /** + * reasoning effort for reasoning models + * only supported by grok-3-mini and grok-3-mini-fast models + */ + reasoningEffort: z.enum(['low', 'high']).optional(), + + searchParameters: z + .object({ + /** + * search mode preference + * - "off": disables search completely + * - "auto": model decides whether to search (default) + * - "on": always enables search + */ + mode: z.enum(['off', 'auto', 'on']), + + /** + * whether to return citations in the response + * defaults to true + */ + returnCitations: z.boolean().optional(), + + /** + * start date for search data (ISO8601 format: YYYY-MM-DD) + */ + fromDate: z.string().optional(), + + /** + * end date for search data (ISO8601 format: YYYY-MM-DD) + */ + toDate: z.string().optional(), + + /** + * maximum number of search results to consider + * defaults to 20 + */ + maxSearchResults: z.number().min(1).max(50).optional(), + + /** + * data sources to search from + * defaults to ["web", "x"] if not specified + */ + sources: z.array(searchSourceSchema).optional() + }) + .optional() +}) + +export type XaiProviderOptions = z.infer diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts index 2f14d648a2..a06a51449a 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts @@ -1,132 +1,78 @@ -/** - * 网络搜索助手函数 - * 提取各个 ApiClient 中的网络搜索逻辑,提供统一的适配器 - */ -import type { OpenAIProvider } from '@ai-sdk/openai' +import type { anthropic } from '@ai-sdk/anthropic' +import type { openai } from '@ai-sdk/openai' -import { ProviderId } from '../../../../types' - -// 派生自 OpenAI SDK 的标准工具入参类型 -type WebSearchPreviewParams = Parameters[0] - -// 使用交叉类型合并,并为 extra 添加注释 -export type WebSearchConfig = WebSearchPreviewParams & { - /** - * 扩展字段,用于提供给开发者自定义参数的能力 - * 这些参数将被合并到对应 provider 的 providerOptions 中 - */ - extra?: Record -} +import { ProviderOptionsMap } from '../../../options/types' /** - * 适配 OpenAI 网络搜索 - * 基于 Vercel AI SDK 的 web_search_preview 工具 + * 从 AI SDK 的工具函数中提取参数类型,以确保类型安全。 */ -export function adaptOpenAIWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any { - const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig - const { extra, ...stdParams } = config +type OpenAISearchConfig = Parameters[0] +type AnthropicSearchConfig = Parameters[0] - const webSearchTool = { - type: 'web_search_preview', - ...stdParams - } - - // 假设 params.tools 是一个数组或 undefined - const existingTools = Array.isArray(params.tools) ? params.tools : [] - - // 将 extra 参数添加到 providerOptions 中 - const providerOptions = { - ...params.providerOptions, - openai: { - ...params.providerOptions?.openai, - ...(extra || {}) - } - } - - return { - ...params, - tools: [...existingTools, webSearchTool], - providerOptions +/** + * XAI 特有的搜索参数 + * @internal + */ +interface XaiProviderOptions { + searchParameters?: { + sources?: any[] + safeSearch?: boolean } } /** + * 插件初始化时接收的完整配置对象 * - * 适配 Gemini 网络搜索 - * 将 googleSearch 工具放入 providerOptions.google.tools + * 其结构与 ProviderOptions 保持一致,方便上游统一管理配置 */ -// export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any { -// const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig -// const googleSearchTool = { googleSearch: {} } - -// const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : [] - -// return { -// ...params, -// providerOptions: { -// ...params.providerOptions, -// google: { -// ...params.providerOptions?.google, -// useSearchGrounding: true, -// // tools: [...existingTools, googleSearchTool], -// ...(config.extra || {}) -// } -// } -// } -// } +export interface WebSearchPluginConfig { + openai?: OpenAISearchConfig + anthropic?: AnthropicSearchConfig + xai?: ProviderOptionsMap['xai']['searchParameters'] + google?: Pick + 'google-vertex'?: Pick +} /** - * 适配 Anthropic 网络搜索 - * 将 web_search_20250305 工具放入 providerOptions.anthropic.tools + * 插件的默认配置 */ -export function adaptAnthropicWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any { - const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig - const webSearchTool = { - type: 'web_search_20250305', - name: 'web_search', - max_uses: 5 // 默认值,可以通过 extra 覆盖 - } - - const existingTools = Array.isArray(params.providerOptions?.anthropic?.tools) - ? params.providerOptions.anthropic.tools - : [] - - return { - ...params, - providerOptions: { - ...params.providerOptions, - anthropic: { - ...params.providerOptions?.anthropic, - tools: [...existingTools, webSearchTool], - ...(config.extra || {}) - } - } +export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = { + google: { + useSearchGrounding: true + }, + 'google-vertex': { + useSearchGrounding: true + }, + openai: {}, + xai: { + mode: 'on', + returnCitations: true, + maxSearchResults: 5 + }, + anthropic: { + maxUses: 5 } } /** - * 通用网络搜索适配器 - * 根据 providerId 选择对应的适配函数 + * 根据配置构建 Google 的 providerOptions */ -export function adaptWebSearchForProvider( - params: any, - providerId: ProviderId, - webSearchConfig: WebSearchConfig | boolean -): any { - switch (providerId) { - case 'openai': - return adaptOpenAIWebSearch(params, webSearchConfig) - - // google的需要通过插件,在创建model的时候传入参数 - // case 'google': - // case 'google-vertex': - // return adaptGeminiWebSearch(params, webSearchConfig) - - case 'anthropic': - return adaptAnthropicWebSearch(params, webSearchConfig) - - default: - // 不支持的 provider,保持原样 - return params - } +export const getGoogleProviderOptions = (providerOptions: any) => { + if (!providerOptions) providerOptions = {} + if (!providerOptions.google) providerOptions.google = {} + providerOptions.google.useSearchGrounding = true + return providerOptions +} + +/** + * 根据配置构建 XAI 的 providerOptions + */ +export const getXaiProviderOptions = (providerOptions: any, config?: XaiProviderOptions['searchParameters']) => { + if (!providerOptions) providerOptions = {} + if (!providerOptions.xai) providerOptions.xai = {} + providerOptions.xai.searchParameters = { + mode: 'on', + ...(config ?? {}) + } + return providerOptions } diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts index ed4c700ab1..d84320cbe0 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts @@ -2,66 +2,71 @@ * Web Search Plugin * 提供统一的网络搜索能力,支持多个 AI Provider */ +import { anthropic } from '@ai-sdk/anthropic' +import { openai } from '@ai-sdk/openai' +import { createGoogleOptions, createXaiOptions, mergeProviderOptions } from '../../../options' import { definePlugin } from '../../' import type { AiRequestContext } from '../../types' -import { adaptWebSearchForProvider, type WebSearchConfig } from './helper' +import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper' /** * 网络搜索插件 * - * 此插件会检查 params.providerOptions.[providerId].webSearch 来激活。 - * options.ts 文件负责将高层级的设置(如 assistant.enableWebSearch) - * 转换为 providerOptions 中的 webSearch: { enabled: true } 配置。 + * @param config - 在插件初始化时传入的静态配置 */ -export const webSearchPlugin = () => +export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) => definePlugin({ name: 'webSearch', enforce: 'pre', - // configureModel: async (modelConfig: any, context: AiRequestContext) => { - // if (context.providerId === 'google') { - // return { - // ...modelConfig - // } - // } - // return null - // }, - transformParams: async (params: any, context: AiRequestContext) => { const { providerId } = context - // 从 providerOptions 中提取 webSearch 配置 - const webSearchConfig = params.providerOptions?.[providerId]?.webSearch + switch (providerId) { + case 'openai': { + if (config.openai) { + if (!params.tools) params.tools = {} + params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai) + } + break + } - // 检查是否启用了网络搜索 (enabled: false 可用于显式禁用) - if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) { - return params - } - console.log('webSearchConfig', webSearchConfig) - // // 检查当前 provider 是否支持网络搜索 - // if (!isWebSearchSupported(providerId)) { - // // 对于不支持的 provider,只记录警告,不修改参数 - // console.warn( - // `[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.` - // ) - // return params - // } + case 'anthropic': { + if (config.anthropic) { + if (!params.tools) params.tools = {} + params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic) + } + break + } - // 使用适配器函数处理网络搜索 - const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean) - // 清理原始的 webSearch 配置 - if (adaptedParams.providerOptions?.[providerId]) { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { webSearch, ...rest } = adaptedParams.providerOptions[providerId] - adaptedParams.providerOptions[providerId] = rest + case 'google': + case 'google-vertex': { + // @ts-ignore - providerId is a string that can be used to index config + if (config[providerId]) { + const searchOptions = createGoogleOptions({ useSearchGrounding: true }) + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) + } + break + } + + case 'xai': { + if (config.xai) { + const searchOptions = createXaiOptions({ + searchParameters: { ...config.xai, mode: 'on' } + }) + params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions) + } + break + } } - return adaptedParams + + return params } }) // 导出类型定义供开发者使用 -export type { WebSearchConfig } from './helper' +export type { WebSearchPluginConfig } from './helper' // 默认导出 export default webSearchPlugin diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 43e1c4ba82..3f6cc77e18 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -17,7 +17,7 @@ import { type ProviderSettingsMap, StreamTextParams } from '@cherrystudio/ai-core' -import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/core/plugins/built-in' +import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/core/plugins/built-in' import { isDedicatedImageGenerationModel } from '@renderer/config/models' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import type { GenerateImageParams, Model, Provider } from '@renderer/types' @@ -143,7 +143,7 @@ export default class ModernAiProvider { const plugins: AiPlugin[] = [] // 1. 总是添加通用插件 // plugins.push(textPlugin) - // plugins.push(webSearchPlugin) + plugins.push(webSearchPlugin()) // 2. 推理模型时添加推理插件 if (middlewareConfig.enableReasoning) { diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index 8488315c44..986f628bdf 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -42,7 +42,7 @@ import { defaultTimeout } from '@shared/config/constant' // import { jsonSchemaToZod } from 'json-schema-to-zod' import { setupToolsConfig } from './utils/mcp' import { buildProviderOptions } from './utils/options' -import { getWebSearchTools } from './utils/websearch' +// import { getWebSearchTools } from './utils/websearch' /** * 获取温度参数 @@ -279,17 +279,17 @@ export async function buildStreamTextParams( (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true) // 构建系统提示 - let { tools } = setupToolsConfig({ + const { tools } = setupToolsConfig({ mcpTools, model, enableToolUse: enableTools }) // Add web search tools if enabled - if (enableWebSearch) { - const webSearchTools = getWebSearchTools(model) - tools = { ...tools, ...webSearchTools } - } + // if (enableWebSearch) { + // const webSearchTools = getWebSearchTools(model) + // tools = { ...tools, ...webSearchTools } + // } // 构建真正的 providerOptions const providerOptions = buildProviderOptions(assistant, model, { diff --git a/src/renderer/src/aiCore/utils/websearch.ts b/src/renderer/src/aiCore/utils/websearch.ts index 70da81c1da..a59cbf5fbd 100644 --- a/src/renderer/src/aiCore/utils/websearch.ts +++ b/src/renderer/src/aiCore/utils/websearch.ts @@ -1,37 +1,37 @@ -import { isWebSearchModel } from '@renderer/config/models' -import { Model } from '@renderer/types' -// import {} from '@cherrystudio/ai-core' +// import { isWebSearchModel } from '@renderer/config/models' +// import { Model } from '@renderer/types' +// // import {} from '@cherrystudio/ai-core' -// The tool name for Gemini search can be arbitrary, but let's use a descriptive one. -const GEMINI_SEARCH_TOOL_NAME = 'google_search' +// // The tool name for Gemini search can be arbitrary, but let's use a descriptive one. +// const GEMINI_SEARCH_TOOL_NAME = 'google_search' -export function getWebSearchTools(model: Model): Record { - if (!isWebSearchModel(model)) { - return {} - } +// export function getWebSearchTools(model: Model): Record { +// if (!isWebSearchModel(model)) { +// return {} +// } - // Use provider from model if available, otherwise fallback to parsing model id. - const provider = model.provider || model.id.split('/')[0] +// // Use provider from model if available, otherwise fallback to parsing model id. +// const provider = model.provider || model.id.split('/')[0] - switch (provider) { - case 'anthropic': - return { - web_search: { - type: 'web_search_20250305', - name: 'web_search', - max_uses: 5 - } - } - case 'google': - case 'gemini': - return { - [GEMINI_SEARCH_TOOL_NAME]: { - googleSearch: {} - } - } - default: - // For OpenAI and others, web search is often a parameter, not a tool. - // The logic is handled in `buildProviderOptions`. - return {} - } -} +// switch (provider) { +// case 'anthropic': +// return { +// web_search: { +// type: 'web_search_20250305', +// name: 'web_search', +// max_uses: 5 +// } +// } +// case 'google': +// case 'gemini': +// return { +// [GEMINI_SEARCH_TOOL_NAME]: { +// googleSearch: {} +// } +// } +// default: +// // For OpenAI and others, web search is often a parameter, not a tool. +// // The logic is handled in `buildProviderOptions`. +// return {} +// } +// }