mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat: add XAI provider options and enhance web search plugin
- Introduced `createXaiOptions` function for XAI provider configuration. - Added `XaiProviderOptions` type and validation schema in `xai.ts`. - Updated `ProviderOptionsMap` to include XAI options. - Enhanced `webSearchPlugin` to support XAI-specific search parameters. - Refactored helper functions to integrate new XAI options into provider configurations.
This commit is contained in:
parent
56c5e5a80f
commit
4573e3f48f
@ -62,3 +62,10 @@ export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
|
||||
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) {
|
||||
return createProviderOptions('openrouter', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建XAI供应商选项的便捷函数
|
||||
*/
|
||||
export function createXaiOptions(options: ExtractProviderOptions<'xai'>) {
|
||||
return createProviderOptions('xai', options)
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
|
||||
|
||||
import { type OpenRouterProviderOptions } from './openrouter'
|
||||
import { type XaiProviderOptions } from './xai'
|
||||
|
||||
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
|
||||
|
||||
@ -15,6 +16,7 @@ export type ProviderOptionsMap = {
|
||||
anthropic: AnthropicProviderOptions
|
||||
google: GoogleGenerativeAIProviderOptions
|
||||
openrouter: OpenRouterProviderOptions
|
||||
xai: XaiProviderOptions
|
||||
}
|
||||
|
||||
// 工具类型,用于从ProviderOptionsMap中提取特定供应商的选项类型
|
||||
|
||||
86
packages/aiCore/src/core/options/xai.ts
Normal file
86
packages/aiCore/src/core/options/xai.ts
Normal file
@ -0,0 +1,86 @@
|
||||
// copy from @ai-sdk/xai/xai-chat-options.ts
|
||||
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
|
||||
|
||||
import { z } from 'zod'
|
||||
|
||||
const webSourceSchema = z.object({
|
||||
type: z.literal('web'),
|
||||
country: z.string().length(2).optional(),
|
||||
excludedWebsites: z.array(z.string()).max(5).optional(),
|
||||
allowedWebsites: z.array(z.string()).max(5).optional(),
|
||||
safeSearch: z.boolean().optional()
|
||||
})
|
||||
|
||||
const xSourceSchema = z.object({
|
||||
type: z.literal('x'),
|
||||
xHandles: z.array(z.string()).optional()
|
||||
})
|
||||
|
||||
const newsSourceSchema = z.object({
|
||||
type: z.literal('news'),
|
||||
country: z.string().length(2).optional(),
|
||||
excludedWebsites: z.array(z.string()).max(5).optional(),
|
||||
safeSearch: z.boolean().optional()
|
||||
})
|
||||
|
||||
const rssSourceSchema = z.object({
|
||||
type: z.literal('rss'),
|
||||
links: z.array(z.string().url()).max(1) // currently only supports one RSS link
|
||||
})
|
||||
|
||||
const searchSourceSchema = z.discriminatedUnion('type', [
|
||||
webSourceSchema,
|
||||
xSourceSchema,
|
||||
newsSourceSchema,
|
||||
rssSourceSchema
|
||||
])
|
||||
|
||||
export const xaiProviderOptions = z.object({
|
||||
/**
|
||||
* reasoning effort for reasoning models
|
||||
* only supported by grok-3-mini and grok-3-mini-fast models
|
||||
*/
|
||||
reasoningEffort: z.enum(['low', 'high']).optional(),
|
||||
|
||||
searchParameters: z
|
||||
.object({
|
||||
/**
|
||||
* search mode preference
|
||||
* - "off": disables search completely
|
||||
* - "auto": model decides whether to search (default)
|
||||
* - "on": always enables search
|
||||
*/
|
||||
mode: z.enum(['off', 'auto', 'on']),
|
||||
|
||||
/**
|
||||
* whether to return citations in the response
|
||||
* defaults to true
|
||||
*/
|
||||
returnCitations: z.boolean().optional(),
|
||||
|
||||
/**
|
||||
* start date for search data (ISO8601 format: YYYY-MM-DD)
|
||||
*/
|
||||
fromDate: z.string().optional(),
|
||||
|
||||
/**
|
||||
* end date for search data (ISO8601 format: YYYY-MM-DD)
|
||||
*/
|
||||
toDate: z.string().optional(),
|
||||
|
||||
/**
|
||||
* maximum number of search results to consider
|
||||
* defaults to 20
|
||||
*/
|
||||
maxSearchResults: z.number().min(1).max(50).optional(),
|
||||
|
||||
/**
|
||||
* data sources to search from
|
||||
* defaults to ["web", "x"] if not specified
|
||||
*/
|
||||
sources: z.array(searchSourceSchema).optional()
|
||||
})
|
||||
.optional()
|
||||
})
|
||||
|
||||
export type XaiProviderOptions = z.infer<typeof xaiProviderOptions>
|
||||
@ -1,132 +1,78 @@
|
||||
/**
|
||||
* 网络搜索助手函数
|
||||
* 提取各个 ApiClient 中的网络搜索逻辑,提供统一的适配器
|
||||
*/
|
||||
import type { OpenAIProvider } from '@ai-sdk/openai'
|
||||
import type { anthropic } from '@ai-sdk/anthropic'
|
||||
import type { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { ProviderId } from '../../../../types'
|
||||
|
||||
// 派生自 OpenAI SDK 的标准工具入参类型
|
||||
type WebSearchPreviewParams = Parameters<OpenAIProvider['tools']['webSearchPreview']>[0]
|
||||
|
||||
// 使用交叉类型合并,并为 extra 添加注释
|
||||
export type WebSearchConfig = WebSearchPreviewParams & {
|
||||
/**
|
||||
* 扩展字段,用于提供给开发者自定义参数的能力
|
||||
* 这些参数将被合并到对应 provider 的 providerOptions 中
|
||||
*/
|
||||
extra?: Record<string, any>
|
||||
}
|
||||
import { ProviderOptionsMap } from '../../../options/types'
|
||||
|
||||
/**
|
||||
* 适配 OpenAI 网络搜索
|
||||
* 基于 Vercel AI SDK 的 web_search_preview 工具
|
||||
* 从 AI SDK 的工具函数中提取参数类型,以确保类型安全。
|
||||
*/
|
||||
export function adaptOpenAIWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
|
||||
const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
|
||||
const { extra, ...stdParams } = config
|
||||
type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
|
||||
type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
|
||||
|
||||
const webSearchTool = {
|
||||
type: 'web_search_preview',
|
||||
...stdParams
|
||||
}
|
||||
|
||||
// 假设 params.tools 是一个数组或 undefined
|
||||
const existingTools = Array.isArray(params.tools) ? params.tools : []
|
||||
|
||||
// 将 extra 参数添加到 providerOptions 中
|
||||
const providerOptions = {
|
||||
...params.providerOptions,
|
||||
openai: {
|
||||
...params.providerOptions?.openai,
|
||||
...(extra || {})
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
...params,
|
||||
tools: [...existingTools, webSearchTool],
|
||||
providerOptions
|
||||
/**
|
||||
* XAI 特有的搜索参数
|
||||
* @internal
|
||||
*/
|
||||
interface XaiProviderOptions {
|
||||
searchParameters?: {
|
||||
sources?: any[]
|
||||
safeSearch?: boolean
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件初始化时接收的完整配置对象
|
||||
*
|
||||
* 适配 Gemini 网络搜索
|
||||
* 将 googleSearch 工具放入 providerOptions.google.tools
|
||||
* 其结构与 ProviderOptions 保持一致,方便上游统一管理配置
|
||||
*/
|
||||
// export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
|
||||
// const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
|
||||
// const googleSearchTool = { googleSearch: {} }
|
||||
|
||||
// const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : []
|
||||
|
||||
// return {
|
||||
// ...params,
|
||||
// providerOptions: {
|
||||
// ...params.providerOptions,
|
||||
// google: {
|
||||
// ...params.providerOptions?.google,
|
||||
// useSearchGrounding: true,
|
||||
// // tools: [...existingTools, googleSearchTool],
|
||||
// ...(config.extra || {})
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
export interface WebSearchPluginConfig {
|
||||
openai?: OpenAISearchConfig
|
||||
anthropic?: AnthropicSearchConfig
|
||||
xai?: ProviderOptionsMap['xai']['searchParameters']
|
||||
google?: Pick<ProviderOptionsMap['google'], 'useSearchGrounding' | 'dynamicRetrievalConfig'>
|
||||
'google-vertex'?: Pick<ProviderOptionsMap['google'], 'useSearchGrounding' | 'dynamicRetrievalConfig'>
|
||||
}
|
||||
|
||||
/**
|
||||
* 适配 Anthropic 网络搜索
|
||||
* 将 web_search_20250305 工具放入 providerOptions.anthropic.tools
|
||||
* 插件的默认配置
|
||||
*/
|
||||
export function adaptAnthropicWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
|
||||
const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
|
||||
const webSearchTool = {
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
max_uses: 5 // 默认值,可以通过 extra 覆盖
|
||||
}
|
||||
|
||||
const existingTools = Array.isArray(params.providerOptions?.anthropic?.tools)
|
||||
? params.providerOptions.anthropic.tools
|
||||
: []
|
||||
|
||||
return {
|
||||
...params,
|
||||
providerOptions: {
|
||||
...params.providerOptions,
|
||||
anthropic: {
|
||||
...params.providerOptions?.anthropic,
|
||||
tools: [...existingTools, webSearchTool],
|
||||
...(config.extra || {})
|
||||
}
|
||||
}
|
||||
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
||||
google: {
|
||||
useSearchGrounding: true
|
||||
},
|
||||
'google-vertex': {
|
||||
useSearchGrounding: true
|
||||
},
|
||||
openai: {},
|
||||
xai: {
|
||||
mode: 'on',
|
||||
returnCitations: true,
|
||||
maxSearchResults: 5
|
||||
},
|
||||
anthropic: {
|
||||
maxUses: 5
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用网络搜索适配器
|
||||
* 根据 providerId 选择对应的适配函数
|
||||
* 根据配置构建 Google 的 providerOptions
|
||||
*/
|
||||
export function adaptWebSearchForProvider(
|
||||
params: any,
|
||||
providerId: ProviderId,
|
||||
webSearchConfig: WebSearchConfig | boolean
|
||||
): any {
|
||||
switch (providerId) {
|
||||
case 'openai':
|
||||
return adaptOpenAIWebSearch(params, webSearchConfig)
|
||||
|
||||
// google的需要通过插件,在创建model的时候传入参数
|
||||
// case 'google':
|
||||
// case 'google-vertex':
|
||||
// return adaptGeminiWebSearch(params, webSearchConfig)
|
||||
|
||||
case 'anthropic':
|
||||
return adaptAnthropicWebSearch(params, webSearchConfig)
|
||||
|
||||
default:
|
||||
// 不支持的 provider,保持原样
|
||||
return params
|
||||
}
|
||||
export const getGoogleProviderOptions = (providerOptions: any) => {
|
||||
if (!providerOptions) providerOptions = {}
|
||||
if (!providerOptions.google) providerOptions.google = {}
|
||||
providerOptions.google.useSearchGrounding = true
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据配置构建 XAI 的 providerOptions
|
||||
*/
|
||||
export const getXaiProviderOptions = (providerOptions: any, config?: XaiProviderOptions['searchParameters']) => {
|
||||
if (!providerOptions) providerOptions = {}
|
||||
if (!providerOptions.xai) providerOptions.xai = {}
|
||||
providerOptions.xai.searchParameters = {
|
||||
mode: 'on',
|
||||
...(config ?? {})
|
||||
}
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
@ -2,66 +2,71 @@
|
||||
* Web Search Plugin
|
||||
* 提供统一的网络搜索能力,支持多个 AI Provider
|
||||
*/
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { createGoogleOptions, createXaiOptions, mergeProviderOptions } from '../../../options'
|
||||
import { definePlugin } from '../../'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { adaptWebSearchForProvider, type WebSearchConfig } from './helper'
|
||||
import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
|
||||
|
||||
/**
|
||||
* 网络搜索插件
|
||||
*
|
||||
* 此插件会检查 params.providerOptions.[providerId].webSearch 来激活。
|
||||
* options.ts 文件负责将高层级的设置(如 assistant.enableWebSearch)
|
||||
* 转换为 providerOptions 中的 webSearch: { enabled: true } 配置。
|
||||
* @param config - 在插件初始化时传入的静态配置
|
||||
*/
|
||||
export const webSearchPlugin = () =>
|
||||
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
|
||||
definePlugin({
|
||||
name: 'webSearch',
|
||||
enforce: 'pre',
|
||||
|
||||
// configureModel: async (modelConfig: any, context: AiRequestContext) => {
|
||||
// if (context.providerId === 'google') {
|
||||
// return {
|
||||
// ...modelConfig
|
||||
// }
|
||||
// }
|
||||
// return null
|
||||
// },
|
||||
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
const { providerId } = context
|
||||
|
||||
// 从 providerOptions 中提取 webSearch 配置
|
||||
const webSearchConfig = params.providerOptions?.[providerId]?.webSearch
|
||||
switch (providerId) {
|
||||
case 'openai': {
|
||||
if (config.openai) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 检查是否启用了网络搜索 (enabled: false 可用于显式禁用)
|
||||
if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) {
|
||||
return params
|
||||
}
|
||||
console.log('webSearchConfig', webSearchConfig)
|
||||
// // 检查当前 provider 是否支持网络搜索
|
||||
// if (!isWebSearchSupported(providerId)) {
|
||||
// // 对于不支持的 provider,只记录警告,不修改参数
|
||||
// console.warn(
|
||||
// `[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.`
|
||||
// )
|
||||
// return params
|
||||
// }
|
||||
case 'anthropic': {
|
||||
if (config.anthropic) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 使用适配器函数处理网络搜索
|
||||
const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean)
|
||||
// 清理原始的 webSearch 配置
|
||||
if (adaptedParams.providerOptions?.[providerId]) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const { webSearch, ...rest } = adaptedParams.providerOptions[providerId]
|
||||
adaptedParams.providerOptions[providerId] = rest
|
||||
case 'google':
|
||||
case 'google-vertex': {
|
||||
// @ts-ignore - providerId is a string that can be used to index config
|
||||
if (config[providerId]) {
|
||||
const searchOptions = createGoogleOptions({ useSearchGrounding: true })
|
||||
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'xai': {
|
||||
if (config.xai) {
|
||||
const searchOptions = createXaiOptions({
|
||||
searchParameters: { ...config.xai, mode: 'on' }
|
||||
})
|
||||
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return adaptedParams
|
||||
|
||||
return params
|
||||
}
|
||||
})
|
||||
|
||||
// 导出类型定义供开发者使用
|
||||
export type { WebSearchConfig } from './helper'
|
||||
export type { WebSearchPluginConfig } from './helper'
|
||||
|
||||
// 默认导出
|
||||
export default webSearchPlugin
|
||||
|
||||
@ -17,7 +17,7 @@ import {
|
||||
type ProviderSettingsMap,
|
||||
StreamTextParams
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
|
||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
@ -143,7 +143,7 @@ export default class ModernAiProvider {
|
||||
const plugins: AiPlugin[] = []
|
||||
// 1. 总是添加通用插件
|
||||
// plugins.push(textPlugin)
|
||||
// plugins.push(webSearchPlugin)
|
||||
plugins.push(webSearchPlugin())
|
||||
|
||||
// 2. 推理模型时添加推理插件
|
||||
if (middlewareConfig.enableReasoning) {
|
||||
|
||||
@ -42,7 +42,7 @@ import { defaultTimeout } from '@shared/config/constant'
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { setupToolsConfig } from './utils/mcp'
|
||||
import { buildProviderOptions } from './utils/options'
|
||||
import { getWebSearchTools } from './utils/websearch'
|
||||
// import { getWebSearchTools } from './utils/websearch'
|
||||
|
||||
/**
|
||||
* 获取温度参数
|
||||
@ -279,17 +279,17 @@ export async function buildStreamTextParams(
|
||||
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
|
||||
|
||||
// 构建系统提示
|
||||
let { tools } = setupToolsConfig({
|
||||
const { tools } = setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: enableTools
|
||||
})
|
||||
|
||||
// Add web search tools if enabled
|
||||
if (enableWebSearch) {
|
||||
const webSearchTools = getWebSearchTools(model)
|
||||
tools = { ...tools, ...webSearchTools }
|
||||
}
|
||||
// if (enableWebSearch) {
|
||||
// const webSearchTools = getWebSearchTools(model)
|
||||
// tools = { ...tools, ...webSearchTools }
|
||||
// }
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
const providerOptions = buildProviderOptions(assistant, model, {
|
||||
|
||||
@ -1,37 +1,37 @@
|
||||
import { isWebSearchModel } from '@renderer/config/models'
|
||||
import { Model } from '@renderer/types'
|
||||
// import {} from '@cherrystudio/ai-core'
|
||||
// import { isWebSearchModel } from '@renderer/config/models'
|
||||
// import { Model } from '@renderer/types'
|
||||
// // import {} from '@cherrystudio/ai-core'
|
||||
|
||||
// The tool name for Gemini search can be arbitrary, but let's use a descriptive one.
|
||||
const GEMINI_SEARCH_TOOL_NAME = 'google_search'
|
||||
// // The tool name for Gemini search can be arbitrary, but let's use a descriptive one.
|
||||
// const GEMINI_SEARCH_TOOL_NAME = 'google_search'
|
||||
|
||||
export function getWebSearchTools(model: Model): Record<string, any> {
|
||||
if (!isWebSearchModel(model)) {
|
||||
return {}
|
||||
}
|
||||
// export function getWebSearchTools(model: Model): Record<string, any> {
|
||||
// if (!isWebSearchModel(model)) {
|
||||
// return {}
|
||||
// }
|
||||
|
||||
// Use provider from model if available, otherwise fallback to parsing model id.
|
||||
const provider = model.provider || model.id.split('/')[0]
|
||||
// // Use provider from model if available, otherwise fallback to parsing model id.
|
||||
// const provider = model.provider || model.id.split('/')[0]
|
||||
|
||||
switch (provider) {
|
||||
case 'anthropic':
|
||||
return {
|
||||
web_search: {
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
max_uses: 5
|
||||
}
|
||||
}
|
||||
case 'google':
|
||||
case 'gemini':
|
||||
return {
|
||||
[GEMINI_SEARCH_TOOL_NAME]: {
|
||||
googleSearch: {}
|
||||
}
|
||||
}
|
||||
default:
|
||||
// For OpenAI and others, web search is often a parameter, not a tool.
|
||||
// The logic is handled in `buildProviderOptions`.
|
||||
return {}
|
||||
}
|
||||
}
|
||||
// switch (provider) {
|
||||
// case 'anthropic':
|
||||
// return {
|
||||
// web_search: {
|
||||
// type: 'web_search_20250305',
|
||||
// name: 'web_search',
|
||||
// max_uses: 5
|
||||
// }
|
||||
// }
|
||||
// case 'google':
|
||||
// case 'gemini':
|
||||
// return {
|
||||
// [GEMINI_SEARCH_TOOL_NAME]: {
|
||||
// googleSearch: {}
|
||||
// }
|
||||
// }
|
||||
// default:
|
||||
// // For OpenAI and others, web search is often a parameter, not a tool.
|
||||
// // The logic is handled in `buildProviderOptions`.
|
||||
// return {}
|
||||
// }
|
||||
// }
|
||||
|
||||
Loading…
Reference in New Issue
Block a user