feat: add XAI provider options and enhance web search plugin

- Introduced `createXaiOptions` function for XAI provider configuration.
- Added `XaiProviderOptions` type and validation schema in `xai.ts`.
- Updated `ProviderOptionsMap` to include XAI options.
- Enhanced `webSearchPlugin` to support XAI-specific search parameters.
- Refactored helper functions to integrate new XAI options into provider configurations.
This commit is contained in:
lizhixuan 2025-07-07 23:28:49 +08:00
parent 56c5e5a80f
commit 4573e3f48f
8 changed files with 237 additions and 191 deletions

View File

@ -62,3 +62,10 @@ export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) {
return createProviderOptions('openrouter', options)
}
/**
* XAI供应商选项的便捷函数
*/
export function createXaiOptions(options: ExtractProviderOptions<'xai'>) {
return createProviderOptions('xai', options)
}

View File

@ -4,6 +4,7 @@ import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
import { type OpenRouterProviderOptions } from './openrouter'
import { type XaiProviderOptions } from './xai'
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
@ -15,6 +16,7 @@ export type ProviderOptionsMap = {
anthropic: AnthropicProviderOptions
google: GoogleGenerativeAIProviderOptions
openrouter: OpenRouterProviderOptions
xai: XaiProviderOptions
}
// 工具类型用于从ProviderOptionsMap中提取特定供应商的选项类型

View File

@ -0,0 +1,86 @@
// copy from @ai-sdk/xai/xai-chat-options.ts
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
import { z } from 'zod'
const webSourceSchema = z.object({
type: z.literal('web'),
country: z.string().length(2).optional(),
excludedWebsites: z.array(z.string()).max(5).optional(),
allowedWebsites: z.array(z.string()).max(5).optional(),
safeSearch: z.boolean().optional()
})
const xSourceSchema = z.object({
type: z.literal('x'),
xHandles: z.array(z.string()).optional()
})
const newsSourceSchema = z.object({
type: z.literal('news'),
country: z.string().length(2).optional(),
excludedWebsites: z.array(z.string()).max(5).optional(),
safeSearch: z.boolean().optional()
})
const rssSourceSchema = z.object({
type: z.literal('rss'),
links: z.array(z.string().url()).max(1) // currently only supports one RSS link
})
const searchSourceSchema = z.discriminatedUnion('type', [
webSourceSchema,
xSourceSchema,
newsSourceSchema,
rssSourceSchema
])
export const xaiProviderOptions = z.object({
/**
* reasoning effort for reasoning models
* only supported by grok-3-mini and grok-3-mini-fast models
*/
reasoningEffort: z.enum(['low', 'high']).optional(),
searchParameters: z
.object({
/**
* search mode preference
* - "off": disables search completely
* - "auto": model decides whether to search (default)
* - "on": always enables search
*/
mode: z.enum(['off', 'auto', 'on']),
/**
* whether to return citations in the response
* defaults to true
*/
returnCitations: z.boolean().optional(),
/**
* start date for search data (ISO8601 format: YYYY-MM-DD)
*/
fromDate: z.string().optional(),
/**
* end date for search data (ISO8601 format: YYYY-MM-DD)
*/
toDate: z.string().optional(),
/**
* maximum number of search results to consider
* defaults to 20
*/
maxSearchResults: z.number().min(1).max(50).optional(),
/**
* data sources to search from
* defaults to ["web", "x"] if not specified
*/
sources: z.array(searchSourceSchema).optional()
})
.optional()
})
export type XaiProviderOptions = z.infer<typeof xaiProviderOptions>

View File

@ -1,132 +1,78 @@
/**
*
* ApiClient
*/
import type { OpenAIProvider } from '@ai-sdk/openai'
import type { anthropic } from '@ai-sdk/anthropic'
import type { openai } from '@ai-sdk/openai'
import { ProviderId } from '../../../../types'
// 派生自 OpenAI SDK 的标准工具入参类型
type WebSearchPreviewParams = Parameters<OpenAIProvider['tools']['webSearchPreview']>[0]
// 使用交叉类型合并,并为 extra 添加注释
export type WebSearchConfig = WebSearchPreviewParams & {
/**
*
* provider providerOptions
*/
extra?: Record<string, any>
}
import { ProviderOptionsMap } from '../../../options/types'
/**
* OpenAI
* Vercel AI SDK web_search_preview
* AI SDK
*/
export function adaptOpenAIWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
const { extra, ...stdParams } = config
type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
const webSearchTool = {
type: 'web_search_preview',
...stdParams
}
// 假设 params.tools 是一个数组或 undefined
const existingTools = Array.isArray(params.tools) ? params.tools : []
// 将 extra 参数添加到 providerOptions 中
const providerOptions = {
...params.providerOptions,
openai: {
...params.providerOptions?.openai,
...(extra || {})
}
}
return {
...params,
tools: [...existingTools, webSearchTool],
providerOptions
/**
* XAI
* @internal
*/
interface XaiProviderOptions {
searchParameters?: {
sources?: any[]
safeSearch?: boolean
}
}
/**
*
*
* Gemini
* googleSearch providerOptions.google.tools
* ProviderOptions 便
*/
// export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
// const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
// const googleSearchTool = { googleSearch: {} }
// const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : []
// return {
// ...params,
// providerOptions: {
// ...params.providerOptions,
// google: {
// ...params.providerOptions?.google,
// useSearchGrounding: true,
// // tools: [...existingTools, googleSearchTool],
// ...(config.extra || {})
// }
// }
// }
// }
export interface WebSearchPluginConfig {
openai?: OpenAISearchConfig
anthropic?: AnthropicSearchConfig
xai?: ProviderOptionsMap['xai']['searchParameters']
google?: Pick<ProviderOptionsMap['google'], 'useSearchGrounding' | 'dynamicRetrievalConfig'>
'google-vertex'?: Pick<ProviderOptionsMap['google'], 'useSearchGrounding' | 'dynamicRetrievalConfig'>
}
/**
* Anthropic
* web_search_20250305 providerOptions.anthropic.tools
*
*/
export function adaptAnthropicWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
const webSearchTool = {
type: 'web_search_20250305',
name: 'web_search',
max_uses: 5 // 默认值,可以通过 extra 覆盖
}
const existingTools = Array.isArray(params.providerOptions?.anthropic?.tools)
? params.providerOptions.anthropic.tools
: []
return {
...params,
providerOptions: {
...params.providerOptions,
anthropic: {
...params.providerOptions?.anthropic,
tools: [...existingTools, webSearchTool],
...(config.extra || {})
}
}
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
google: {
useSearchGrounding: true
},
'google-vertex': {
useSearchGrounding: true
},
openai: {},
xai: {
mode: 'on',
returnCitations: true,
maxSearchResults: 5
},
anthropic: {
maxUses: 5
}
}
/**
*
* providerId
* Google providerOptions
*/
export function adaptWebSearchForProvider(
params: any,
providerId: ProviderId,
webSearchConfig: WebSearchConfig | boolean
): any {
switch (providerId) {
case 'openai':
return adaptOpenAIWebSearch(params, webSearchConfig)
// google的需要通过插件在创建model的时候传入参数
// case 'google':
// case 'google-vertex':
// return adaptGeminiWebSearch(params, webSearchConfig)
case 'anthropic':
return adaptAnthropicWebSearch(params, webSearchConfig)
default:
// 不支持的 provider保持原样
return params
}
export const getGoogleProviderOptions = (providerOptions: any) => {
if (!providerOptions) providerOptions = {}
if (!providerOptions.google) providerOptions.google = {}
providerOptions.google.useSearchGrounding = true
return providerOptions
}
/**
* XAI providerOptions
*/
export const getXaiProviderOptions = (providerOptions: any, config?: XaiProviderOptions['searchParameters']) => {
if (!providerOptions) providerOptions = {}
if (!providerOptions.xai) providerOptions.xai = {}
providerOptions.xai.searchParameters = {
mode: 'on',
...(config ?? {})
}
return providerOptions
}

View File

@ -2,66 +2,71 @@
* Web Search Plugin
* AI Provider
*/
import { anthropic } from '@ai-sdk/anthropic'
import { openai } from '@ai-sdk/openai'
import { createGoogleOptions, createXaiOptions, mergeProviderOptions } from '../../../options'
import { definePlugin } from '../../'
import type { AiRequestContext } from '../../types'
import { adaptWebSearchForProvider, type WebSearchConfig } from './helper'
import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
/**
*
*
* params.providerOptions.[providerId].webSearch
* options.ts assistant.enableWebSearch
* providerOptions webSearch: { enabled: true }
* @param config -
*/
export const webSearchPlugin = () =>
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
definePlugin({
name: 'webSearch',
enforce: 'pre',
// configureModel: async (modelConfig: any, context: AiRequestContext) => {
// if (context.providerId === 'google') {
// return {
// ...modelConfig
// }
// }
// return null
// },
transformParams: async (params: any, context: AiRequestContext) => {
const { providerId } = context
// 从 providerOptions 中提取 webSearch 配置
const webSearchConfig = params.providerOptions?.[providerId]?.webSearch
switch (providerId) {
case 'openai': {
if (config.openai) {
if (!params.tools) params.tools = {}
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
}
break
}
// 检查是否启用了网络搜索 (enabled: false 可用于显式禁用)
if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) {
return params
}
console.log('webSearchConfig', webSearchConfig)
// // 检查当前 provider 是否支持网络搜索
// if (!isWebSearchSupported(providerId)) {
// // 对于不支持的 provider只记录警告不修改参数
// console.warn(
// `[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.`
// )
// return params
// }
case 'anthropic': {
if (config.anthropic) {
if (!params.tools) params.tools = {}
params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
}
break
}
// 使用适配器函数处理网络搜索
const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean)
// 清理原始的 webSearch 配置
if (adaptedParams.providerOptions?.[providerId]) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { webSearch, ...rest } = adaptedParams.providerOptions[providerId]
adaptedParams.providerOptions[providerId] = rest
case 'google':
case 'google-vertex': {
// @ts-ignore - providerId is a string that can be used to index config
if (config[providerId]) {
const searchOptions = createGoogleOptions({ useSearchGrounding: true })
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
}
break
}
case 'xai': {
if (config.xai) {
const searchOptions = createXaiOptions({
searchParameters: { ...config.xai, mode: 'on' }
})
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
}
break
}
}
return adaptedParams
return params
}
})
// 导出类型定义供开发者使用
export type { WebSearchConfig } from './helper'
export type { WebSearchPluginConfig } from './helper'
// 默认导出
export default webSearchPlugin

View File

@ -17,7 +17,7 @@ import {
type ProviderSettingsMap,
StreamTextParams
} from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
@ -143,7 +143,7 @@ export default class ModernAiProvider {
const plugins: AiPlugin[] = []
// 1. 总是添加通用插件
// plugins.push(textPlugin)
// plugins.push(webSearchPlugin)
plugins.push(webSearchPlugin())
// 2. 推理模型时添加推理插件
if (middlewareConfig.enableReasoning) {

View File

@ -42,7 +42,7 @@ import { defaultTimeout } from '@shared/config/constant'
// import { jsonSchemaToZod } from 'json-schema-to-zod'
import { setupToolsConfig } from './utils/mcp'
import { buildProviderOptions } from './utils/options'
import { getWebSearchTools } from './utils/websearch'
// import { getWebSearchTools } from './utils/websearch'
/**
*
@ -279,17 +279,17 @@ export async function buildStreamTextParams(
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
// 构建系统提示
let { tools } = setupToolsConfig({
const { tools } = setupToolsConfig({
mcpTools,
model,
enableToolUse: enableTools
})
// Add web search tools if enabled
if (enableWebSearch) {
const webSearchTools = getWebSearchTools(model)
tools = { ...tools, ...webSearchTools }
}
// if (enableWebSearch) {
// const webSearchTools = getWebSearchTools(model)
// tools = { ...tools, ...webSearchTools }
// }
// 构建真正的 providerOptions
const providerOptions = buildProviderOptions(assistant, model, {

View File

@ -1,37 +1,37 @@
import { isWebSearchModel } from '@renderer/config/models'
import { Model } from '@renderer/types'
// import {} from '@cherrystudio/ai-core'
// import { isWebSearchModel } from '@renderer/config/models'
// import { Model } from '@renderer/types'
// // import {} from '@cherrystudio/ai-core'
// The tool name for Gemini search can be arbitrary, but let's use a descriptive one.
const GEMINI_SEARCH_TOOL_NAME = 'google_search'
// // The tool name for Gemini search can be arbitrary, but let's use a descriptive one.
// const GEMINI_SEARCH_TOOL_NAME = 'google_search'
export function getWebSearchTools(model: Model): Record<string, any> {
if (!isWebSearchModel(model)) {
return {}
}
// export function getWebSearchTools(model: Model): Record<string, any> {
// if (!isWebSearchModel(model)) {
// return {}
// }
// Use provider from model if available, otherwise fallback to parsing model id.
const provider = model.provider || model.id.split('/')[0]
// // Use provider from model if available, otherwise fallback to parsing model id.
// const provider = model.provider || model.id.split('/')[0]
switch (provider) {
case 'anthropic':
return {
web_search: {
type: 'web_search_20250305',
name: 'web_search',
max_uses: 5
}
}
case 'google':
case 'gemini':
return {
[GEMINI_SEARCH_TOOL_NAME]: {
googleSearch: {}
}
}
default:
// For OpenAI and others, web search is often a parameter, not a tool.
// The logic is handled in `buildProviderOptions`.
return {}
}
}
// switch (provider) {
// case 'anthropic':
// return {
// web_search: {
// type: 'web_search_20250305',
// name: 'web_search',
// max_uses: 5
// }
// }
// case 'google':
// case 'gemini':
// return {
// [GEMINI_SEARCH_TOOL_NAME]: {
// googleSearch: {}
// }
// }
// default:
// // For OpenAI and others, web search is often a parameter, not a tool.
// // The logic is handled in `buildProviderOptions`.
// return {}
// }
// }