feat: add XAI provider options and enhance web search plugin

- Introduced `createXaiOptions` function for XAI provider configuration.
- Added `XaiProviderOptions` type and validation schema in `xai.ts`.
- Updated `ProviderOptionsMap` to include XAI options.
- Enhanced `webSearchPlugin` to support XAI-specific search parameters.
- Refactored helper functions to integrate new XAI options into provider configurations.
This commit is contained in:
lizhixuan 2025-07-07 23:28:49 +08:00
parent 56c5e5a80f
commit 4573e3f48f
8 changed files with 237 additions and 191 deletions

View File

@ -62,3 +62,10 @@ export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) { export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) {
return createProviderOptions('openrouter', options) return createProviderOptions('openrouter', options)
} }
/**
* XAI供应商选项的便捷函数
*/
export function createXaiOptions(options: ExtractProviderOptions<'xai'>) {
return createProviderOptions('xai', options)
}

View File

@ -4,6 +4,7 @@ import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider' import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
import { type OpenRouterProviderOptions } from './openrouter' import { type OpenRouterProviderOptions } from './openrouter'
import { type XaiProviderOptions } from './xai'
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T] export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
@ -15,6 +16,7 @@ export type ProviderOptionsMap = {
anthropic: AnthropicProviderOptions anthropic: AnthropicProviderOptions
google: GoogleGenerativeAIProviderOptions google: GoogleGenerativeAIProviderOptions
openrouter: OpenRouterProviderOptions openrouter: OpenRouterProviderOptions
xai: XaiProviderOptions
} }
// 工具类型用于从ProviderOptionsMap中提取特定供应商的选项类型 // 工具类型用于从ProviderOptionsMap中提取特定供应商的选项类型

View File

@ -0,0 +1,86 @@
// copy from @ai-sdk/xai/xai-chat-options.ts
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
import { z } from 'zod'
const webSourceSchema = z.object({
type: z.literal('web'),
country: z.string().length(2).optional(),
excludedWebsites: z.array(z.string()).max(5).optional(),
allowedWebsites: z.array(z.string()).max(5).optional(),
safeSearch: z.boolean().optional()
})
const xSourceSchema = z.object({
type: z.literal('x'),
xHandles: z.array(z.string()).optional()
})
const newsSourceSchema = z.object({
type: z.literal('news'),
country: z.string().length(2).optional(),
excludedWebsites: z.array(z.string()).max(5).optional(),
safeSearch: z.boolean().optional()
})
const rssSourceSchema = z.object({
type: z.literal('rss'),
links: z.array(z.string().url()).max(1) // currently only supports one RSS link
})
const searchSourceSchema = z.discriminatedUnion('type', [
webSourceSchema,
xSourceSchema,
newsSourceSchema,
rssSourceSchema
])
export const xaiProviderOptions = z.object({
/**
* reasoning effort for reasoning models
* only supported by grok-3-mini and grok-3-mini-fast models
*/
reasoningEffort: z.enum(['low', 'high']).optional(),
searchParameters: z
.object({
/**
* search mode preference
* - "off": disables search completely
* - "auto": model decides whether to search (default)
* - "on": always enables search
*/
mode: z.enum(['off', 'auto', 'on']),
/**
* whether to return citations in the response
* defaults to true
*/
returnCitations: z.boolean().optional(),
/**
* start date for search data (ISO8601 format: YYYY-MM-DD)
*/
fromDate: z.string().optional(),
/**
* end date for search data (ISO8601 format: YYYY-MM-DD)
*/
toDate: z.string().optional(),
/**
* maximum number of search results to consider
* defaults to 20
*/
maxSearchResults: z.number().min(1).max(50).optional(),
/**
* data sources to search from
* defaults to ["web", "x"] if not specified
*/
sources: z.array(searchSourceSchema).optional()
})
.optional()
})
export type XaiProviderOptions = z.infer<typeof xaiProviderOptions>

View File

@ -1,132 +1,78 @@
/** import type { anthropic } from '@ai-sdk/anthropic'
* import type { openai } from '@ai-sdk/openai'
* ApiClient
*/
import type { OpenAIProvider } from '@ai-sdk/openai'
import { ProviderId } from '../../../../types' import { ProviderOptionsMap } from '../../../options/types'
// 派生自 OpenAI SDK 的标准工具入参类型
type WebSearchPreviewParams = Parameters<OpenAIProvider['tools']['webSearchPreview']>[0]
// 使用交叉类型合并,并为 extra 添加注释
export type WebSearchConfig = WebSearchPreviewParams & {
/**
*
* provider providerOptions
*/
extra?: Record<string, any>
}
/** /**
* OpenAI * AI SDK
* Vercel AI SDK web_search_preview
*/ */
export function adaptOpenAIWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any { type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
const { extra, ...stdParams } = config
const webSearchTool = { /**
type: 'web_search_preview', * XAI
...stdParams * @internal
} */
interface XaiProviderOptions {
// 假设 params.tools 是一个数组或 undefined searchParameters?: {
const existingTools = Array.isArray(params.tools) ? params.tools : [] sources?: any[]
safeSearch?: boolean
// 将 extra 参数添加到 providerOptions 中
const providerOptions = {
...params.providerOptions,
openai: {
...params.providerOptions?.openai,
...(extra || {})
}
}
return {
...params,
tools: [...existingTools, webSearchTool],
providerOptions
} }
} }
/** /**
*
* *
* Gemini * ProviderOptions 便
* googleSearch providerOptions.google.tools
*/ */
// export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any { export interface WebSearchPluginConfig {
// const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig openai?: OpenAISearchConfig
// const googleSearchTool = { googleSearch: {} } anthropic?: AnthropicSearchConfig
xai?: ProviderOptionsMap['xai']['searchParameters']
// const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : [] google?: Pick<ProviderOptionsMap['google'], 'useSearchGrounding' | 'dynamicRetrievalConfig'>
'google-vertex'?: Pick<ProviderOptionsMap['google'], 'useSearchGrounding' | 'dynamicRetrievalConfig'>
// return { }
// ...params,
// providerOptions: {
// ...params.providerOptions,
// google: {
// ...params.providerOptions?.google,
// useSearchGrounding: true,
// // tools: [...existingTools, googleSearchTool],
// ...(config.extra || {})
// }
// }
// }
// }
/** /**
* Anthropic *
* web_search_20250305 providerOptions.anthropic.tools
*/ */
export function adaptAnthropicWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any { export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig google: {
const webSearchTool = { useSearchGrounding: true
type: 'web_search_20250305', },
name: 'web_search', 'google-vertex': {
max_uses: 5 // 默认值,可以通过 extra 覆盖 useSearchGrounding: true
} },
openai: {},
const existingTools = Array.isArray(params.providerOptions?.anthropic?.tools) xai: {
? params.providerOptions.anthropic.tools mode: 'on',
: [] returnCitations: true,
maxSearchResults: 5
return { },
...params, anthropic: {
providerOptions: { maxUses: 5
...params.providerOptions,
anthropic: {
...params.providerOptions?.anthropic,
tools: [...existingTools, webSearchTool],
...(config.extra || {})
}
}
} }
} }
/** /**
* * Google providerOptions
* providerId
*/ */
export function adaptWebSearchForProvider( export const getGoogleProviderOptions = (providerOptions: any) => {
params: any, if (!providerOptions) providerOptions = {}
providerId: ProviderId, if (!providerOptions.google) providerOptions.google = {}
webSearchConfig: WebSearchConfig | boolean providerOptions.google.useSearchGrounding = true
): any { return providerOptions
switch (providerId) { }
case 'openai':
return adaptOpenAIWebSearch(params, webSearchConfig) /**
* XAI providerOptions
// google的需要通过插件在创建model的时候传入参数 */
// case 'google': export const getXaiProviderOptions = (providerOptions: any, config?: XaiProviderOptions['searchParameters']) => {
// case 'google-vertex': if (!providerOptions) providerOptions = {}
// return adaptGeminiWebSearch(params, webSearchConfig) if (!providerOptions.xai) providerOptions.xai = {}
providerOptions.xai.searchParameters = {
case 'anthropic': mode: 'on',
return adaptAnthropicWebSearch(params, webSearchConfig) ...(config ?? {})
}
default: return providerOptions
// 不支持的 provider保持原样
return params
}
} }

View File

@ -2,66 +2,71 @@
* Web Search Plugin * Web Search Plugin
* AI Provider * AI Provider
*/ */
import { anthropic } from '@ai-sdk/anthropic'
import { openai } from '@ai-sdk/openai'
import { createGoogleOptions, createXaiOptions, mergeProviderOptions } from '../../../options'
import { definePlugin } from '../../' import { definePlugin } from '../../'
import type { AiRequestContext } from '../../types' import type { AiRequestContext } from '../../types'
import { adaptWebSearchForProvider, type WebSearchConfig } from './helper' import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
/** /**
* *
* *
* params.providerOptions.[providerId].webSearch * @param config -
* options.ts assistant.enableWebSearch
* providerOptions webSearch: { enabled: true }
*/ */
export const webSearchPlugin = () => export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
definePlugin({ definePlugin({
name: 'webSearch', name: 'webSearch',
enforce: 'pre', enforce: 'pre',
// configureModel: async (modelConfig: any, context: AiRequestContext) => {
// if (context.providerId === 'google') {
// return {
// ...modelConfig
// }
// }
// return null
// },
transformParams: async (params: any, context: AiRequestContext) => { transformParams: async (params: any, context: AiRequestContext) => {
const { providerId } = context const { providerId } = context
// 从 providerOptions 中提取 webSearch 配置 switch (providerId) {
const webSearchConfig = params.providerOptions?.[providerId]?.webSearch case 'openai': {
if (config.openai) {
if (!params.tools) params.tools = {}
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
}
break
}
// 检查是否启用了网络搜索 (enabled: false 可用于显式禁用) case 'anthropic': {
if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) { if (config.anthropic) {
return params if (!params.tools) params.tools = {}
} params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
console.log('webSearchConfig', webSearchConfig) }
// // 检查当前 provider 是否支持网络搜索 break
// if (!isWebSearchSupported(providerId)) { }
// // 对于不支持的 provider只记录警告不修改参数
// console.warn(
// `[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.`
// )
// return params
// }
// 使用适配器函数处理网络搜索 case 'google':
const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean) case 'google-vertex': {
// 清理原始的 webSearch 配置 // @ts-ignore - providerId is a string that can be used to index config
if (adaptedParams.providerOptions?.[providerId]) { if (config[providerId]) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars const searchOptions = createGoogleOptions({ useSearchGrounding: true })
const { webSearch, ...rest } = adaptedParams.providerOptions[providerId] params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
adaptedParams.providerOptions[providerId] = rest }
break
}
case 'xai': {
if (config.xai) {
const searchOptions = createXaiOptions({
searchParameters: { ...config.xai, mode: 'on' }
})
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
}
break
}
} }
return adaptedParams
return params
} }
}) })
// 导出类型定义供开发者使用 // 导出类型定义供开发者使用
export type { WebSearchConfig } from './helper' export type { WebSearchPluginConfig } from './helper'
// 默认导出 // 默认导出
export default webSearchPlugin export default webSearchPlugin

View File

@ -17,7 +17,7 @@ import {
type ProviderSettingsMap, type ProviderSettingsMap,
StreamTextParams StreamTextParams
} from '@cherrystudio/ai-core' } from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/core/plugins/built-in' import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
import { isDedicatedImageGenerationModel } from '@renderer/config/models' import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import type { GenerateImageParams, Model, Provider } from '@renderer/types' import type { GenerateImageParams, Model, Provider } from '@renderer/types'
@ -143,7 +143,7 @@ export default class ModernAiProvider {
const plugins: AiPlugin[] = [] const plugins: AiPlugin[] = []
// 1. 总是添加通用插件 // 1. 总是添加通用插件
// plugins.push(textPlugin) // plugins.push(textPlugin)
// plugins.push(webSearchPlugin) plugins.push(webSearchPlugin())
// 2. 推理模型时添加推理插件 // 2. 推理模型时添加推理插件
if (middlewareConfig.enableReasoning) { if (middlewareConfig.enableReasoning) {

View File

@ -42,7 +42,7 @@ import { defaultTimeout } from '@shared/config/constant'
// import { jsonSchemaToZod } from 'json-schema-to-zod' // import { jsonSchemaToZod } from 'json-schema-to-zod'
import { setupToolsConfig } from './utils/mcp' import { setupToolsConfig } from './utils/mcp'
import { buildProviderOptions } from './utils/options' import { buildProviderOptions } from './utils/options'
import { getWebSearchTools } from './utils/websearch' // import { getWebSearchTools } from './utils/websearch'
/** /**
* *
@ -279,17 +279,17 @@ export async function buildStreamTextParams(
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true) (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
// 构建系统提示 // 构建系统提示
let { tools } = setupToolsConfig({ const { tools } = setupToolsConfig({
mcpTools, mcpTools,
model, model,
enableToolUse: enableTools enableToolUse: enableTools
}) })
// Add web search tools if enabled // Add web search tools if enabled
if (enableWebSearch) { // if (enableWebSearch) {
const webSearchTools = getWebSearchTools(model) // const webSearchTools = getWebSearchTools(model)
tools = { ...tools, ...webSearchTools } // tools = { ...tools, ...webSearchTools }
} // }
// 构建真正的 providerOptions // 构建真正的 providerOptions
const providerOptions = buildProviderOptions(assistant, model, { const providerOptions = buildProviderOptions(assistant, model, {

View File

@ -1,37 +1,37 @@
import { isWebSearchModel } from '@renderer/config/models' // import { isWebSearchModel } from '@renderer/config/models'
import { Model } from '@renderer/types' // import { Model } from '@renderer/types'
// import {} from '@cherrystudio/ai-core' // // import {} from '@cherrystudio/ai-core'
// The tool name for Gemini search can be arbitrary, but let's use a descriptive one. // // The tool name for Gemini search can be arbitrary, but let's use a descriptive one.
const GEMINI_SEARCH_TOOL_NAME = 'google_search' // const GEMINI_SEARCH_TOOL_NAME = 'google_search'
export function getWebSearchTools(model: Model): Record<string, any> { // export function getWebSearchTools(model: Model): Record<string, any> {
if (!isWebSearchModel(model)) { // if (!isWebSearchModel(model)) {
return {} // return {}
} // }
// Use provider from model if available, otherwise fallback to parsing model id. // // Use provider from model if available, otherwise fallback to parsing model id.
const provider = model.provider || model.id.split('/')[0] // const provider = model.provider || model.id.split('/')[0]
switch (provider) { // switch (provider) {
case 'anthropic': // case 'anthropic':
return { // return {
web_search: { // web_search: {
type: 'web_search_20250305', // type: 'web_search_20250305',
name: 'web_search', // name: 'web_search',
max_uses: 5 // max_uses: 5
} // }
} // }
case 'google': // case 'google':
case 'gemini': // case 'gemini':
return { // return {
[GEMINI_SEARCH_TOOL_NAME]: { // [GEMINI_SEARCH_TOOL_NAME]: {
googleSearch: {} // googleSearch: {}
} // }
} // }
default: // default:
// For OpenAI and others, web search is often a parameter, not a tool. // // For OpenAI and others, web search is often a parameter, not a tool.
// The logic is handled in `buildProviderOptions`. // // The logic is handled in `buildProviderOptions`.
return {} // return {}
} // }
} // }