mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
refactor: enhance provider settings and update web search plugin configuration
- Updated providerSettings to allow optional 'mode' parameter for various providers, enhancing flexibility in model configuration. - Refactored web search plugin to integrate Google search capabilities and streamline provider options handling. - Removed deprecated code and improved type definitions for better clarity and maintainability. - Added console logging for debugging purposes in the provider configuration process.
This commit is contained in:
parent
650650a68f
commit
e7d5626055
@ -33,7 +33,7 @@ export async function createBaseModel<T extends ProviderId>({
|
||||
}: {
|
||||
providerId: T
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[T]
|
||||
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
|
||||
extraModelConfig?: any
|
||||
// middlewares?: LanguageModelV1Middleware[]
|
||||
}): Promise<LanguageModelV2>
|
||||
@ -47,7 +47,7 @@ export async function createBaseModel({
|
||||
}: {
|
||||
providerId: string
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap['openai-compatible']
|
||||
providerSettings: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' }
|
||||
extraModelConfig?: any
|
||||
// middlewares?: LanguageModelV1Middleware[]
|
||||
}): Promise<LanguageModelV2>
|
||||
@ -61,7 +61,7 @@ export async function createBaseModel({
|
||||
}: {
|
||||
providerId: string
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[ProviderId]
|
||||
providerSettings: ProviderSettingsMap[ProviderId] & { mode?: 'chat' | 'responses' }
|
||||
// middlewares?: LanguageModelV1Middleware[]
|
||||
extraModelConfig?: any
|
||||
}): Promise<LanguageModelV2> {
|
||||
@ -86,6 +86,7 @@ export async function createBaseModel({
|
||||
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${effectiveProviderId}"`
|
||||
)
|
||||
}
|
||||
// TODO: 对openai 的 providerSettings.mode参数是否要删除,目前看没毛病
|
||||
// 创建provider实例
|
||||
let provider = creatorFunction(providerSettings)
|
||||
|
||||
@ -93,7 +94,7 @@ export async function createBaseModel({
|
||||
if (providerConfig.id === 'openai') {
|
||||
if (
|
||||
'mode' in providerSettings &&
|
||||
providerSettings.mode === 'response' &&
|
||||
providerSettings.mode === 'responses' &&
|
||||
!isOpenAIChatCompletionOnlyModel(modelId)
|
||||
) {
|
||||
provider = provider.responses
|
||||
|
||||
@ -8,7 +8,7 @@ import type { ProviderId, ProviderSettingsMap } from '../../types'
|
||||
export interface ModelConfig<T extends ProviderId = ProviderId> {
|
||||
providerId: T
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[T] & { mode: 'chat' | 'responses' }
|
||||
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
// 额外模型参数
|
||||
extraModelConfig?: Record<string, any>
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { ProviderOptionsMap } from '../../../options/types'
|
||||
@ -8,16 +9,7 @@ import { ProviderOptionsMap } from '../../../options/types'
|
||||
*/
|
||||
type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
|
||||
type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
|
||||
/**
|
||||
* XAI 特有的搜索参数
|
||||
* @internal
|
||||
*/
|
||||
interface XaiProviderOptions {
|
||||
searchParameters?: {
|
||||
sources?: any[]
|
||||
safeSearch?: boolean
|
||||
}
|
||||
}
|
||||
type GoogleSearchConfig = Parameters<typeof google.tools.googleSearch>[0]
|
||||
|
||||
/**
|
||||
* 插件初始化时接收的完整配置对象
|
||||
@ -28,20 +20,16 @@ export interface WebSearchPluginConfig {
|
||||
openai?: OpenAISearchConfig
|
||||
anthropic?: AnthropicSearchConfig
|
||||
xai?: ProviderOptionsMap['xai']['searchParameters']
|
||||
google?: Pick<ProviderOptionsMap['google'], 'useSearchGrounding' | 'dynamicRetrievalConfig'>
|
||||
'google-vertex'?: Pick<ProviderOptionsMap['google'], 'useSearchGrounding' | 'dynamicRetrievalConfig'>
|
||||
google?: GoogleSearchConfig
|
||||
'google-vertex'?: GoogleSearchConfig
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件的默认配置
|
||||
*/
|
||||
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
||||
google: {
|
||||
useSearchGrounding: true
|
||||
},
|
||||
'google-vertex': {
|
||||
useSearchGrounding: true
|
||||
},
|
||||
google: {},
|
||||
'google-vertex': {},
|
||||
openai: {},
|
||||
xai: {
|
||||
mode: 'on',
|
||||
@ -53,37 +41,3 @@ export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
||||
maxUses: 5
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据配置构建 Google 的 providerOptions
|
||||
*/
|
||||
export const getGoogleProviderOptions = (providerOptions: any) => {
|
||||
if (!providerOptions) providerOptions = {}
|
||||
if (!providerOptions.google) providerOptions.google = {}
|
||||
providerOptions.google.useSearchGrounding = true
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据配置构建 XAI 的 providerOptions
|
||||
*/
|
||||
export const getXaiProviderOptions = (providerOptions: any, config?: XaiProviderOptions['searchParameters']) => {
|
||||
if (!providerOptions) providerOptions = {}
|
||||
if (!providerOptions.xai) providerOptions.xai = {}
|
||||
providerOptions.xai.searchParameters = {
|
||||
mode: 'on',
|
||||
...(config ?? {})
|
||||
}
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
export type AnthropicSearchInput = {
|
||||
query: string
|
||||
}
|
||||
export type AnthropicSearchOutput = {
|
||||
url: string
|
||||
title: string
|
||||
pageAge: string | null
|
||||
encryptedContent: string
|
||||
type: string
|
||||
}[]
|
||||
|
||||
@ -3,9 +3,10 @@
|
||||
* 提供统一的网络搜索能力,支持多个 AI Provider
|
||||
*/
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { createGoogleOptions, createXaiOptions, mergeProviderOptions } from '../../../options'
|
||||
import { createXaiOptions, mergeProviderOptions } from '../../../options'
|
||||
import { definePlugin } from '../../'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
|
||||
@ -22,15 +23,12 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
||||
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
const { providerId } = context
|
||||
// console.log('providerId', providerId)
|
||||
// const modelToProviderId = getModelToProviderId(modelId)
|
||||
// console.log('modelToProviderId', modelToProviderId)
|
||||
console.log('providerId', providerId)
|
||||
switch (providerId) {
|
||||
case 'openai': {
|
||||
if (config.openai) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
|
||||
// console.log('params.tools', params.tools)
|
||||
}
|
||||
break
|
||||
}
|
||||
@ -45,11 +43,8 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
||||
|
||||
case 'google':
|
||||
case 'google-vertex': {
|
||||
// @ts-ignore - providerId is a string that can be used to index config
|
||||
if (config[providerId]) {
|
||||
const searchOptions = createGoogleOptions({ useSearchGrounding: true })
|
||||
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
|
||||
}
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search = google.tools.googleSearch(config.google || {})
|
||||
break
|
||||
}
|
||||
|
||||
@ -62,23 +57,14 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
||||
}
|
||||
break
|
||||
}
|
||||
// default: {
|
||||
// if (!params.providerOptions) params.providerOptions = {}
|
||||
// params.providerOptions['aihubmix'] = {
|
||||
// web_search: anthropic.tools.webSearch_20250305()
|
||||
// }
|
||||
// console.log('params.providerOptions', params.providerOptions)
|
||||
// break
|
||||
// }
|
||||
}
|
||||
// console.log('params', params)
|
||||
|
||||
return params
|
||||
}
|
||||
})
|
||||
|
||||
// 导出类型定义供开发者使用
|
||||
export type { AnthropicSearchInput, AnthropicSearchOutput, WebSearchPluginConfig } from './helper'
|
||||
export type { WebSearchPluginConfig } from './helper'
|
||||
|
||||
// 默认导出
|
||||
export default webSearchPlugin
|
||||
|
||||
@ -27,7 +27,7 @@ import { RuntimeExecutor } from './executor'
|
||||
*/
|
||||
export function createExecutor<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T> {
|
||||
return RuntimeExecutor.create(providerId, options, plugins)
|
||||
@ -37,7 +37,7 @@ export function createExecutor<T extends ProviderId>(
|
||||
* 创建OpenAI Compatible执行器
|
||||
*/
|
||||
export function createOpenAICompatibleExecutor(
|
||||
options: ProviderSettingsMap['openai-compatible'],
|
||||
options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' },
|
||||
plugins: AiPlugin[] = []
|
||||
): RuntimeExecutor<'openai-compatible'> {
|
||||
return RuntimeExecutor.createOpenAICompatible(options, plugins)
|
||||
@ -50,7 +50,7 @@ export function createOpenAICompatibleExecutor(
|
||||
*/
|
||||
export async function streamText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['streamText']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
@ -65,7 +65,7 @@ export async function streamText<T extends ProviderId>(
|
||||
*/
|
||||
export async function generateText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateText']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
@ -80,7 +80,7 @@ export async function generateText<T extends ProviderId>(
|
||||
*/
|
||||
export async function generateObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
@ -95,7 +95,7 @@ export async function generateObject<T extends ProviderId>(
|
||||
*/
|
||||
export async function streamObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
|
||||
@ -10,7 +10,7 @@ import { type AiPlugin } from '../plugins'
|
||||
*/
|
||||
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
||||
providerId: T
|
||||
providerSettings: ModelConfig<T>['providerSettings']
|
||||
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
|
||||
|
||||
@ -133,11 +133,11 @@ export class AiSdkToChunkAdapter {
|
||||
|
||||
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
|
||||
|
||||
case 'tool-input-start':
|
||||
case 'tool-input-delta':
|
||||
case 'tool-input-end':
|
||||
this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
break
|
||||
// case 'tool-input-start':
|
||||
// case 'tool-input-delta':
|
||||
// case 'tool-input-end':
|
||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
// break
|
||||
|
||||
// case 'tool-input-delta':
|
||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
|
||||
@ -46,7 +46,6 @@ function getActualProvider(model: Model): Provider {
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
actualProvider = createAihubmixProvider(model, actualProvider)
|
||||
console.log('actualProvider', actualProvider)
|
||||
}
|
||||
if (actualProvider.type === 'gemini') {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
|
||||
@ -67,13 +66,14 @@ function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
// console.log('aiSdkProviderId', aiSdkProviderId)
|
||||
// 如果provider是openai,则使用strict模式并且默认responses api
|
||||
const actualProviderType = actualProvider.type
|
||||
const actualProviderId = actualProvider.id
|
||||
const openaiResponseOptions =
|
||||
actualProviderType === 'openai-response'
|
||||
// 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses
|
||||
actualProviderId === 'openai'
|
||||
? {
|
||||
mode: 'response'
|
||||
mode: 'responses'
|
||||
}
|
||||
: actualProviderType === 'openai'
|
||||
: aiSdkProviderId === 'openai'
|
||||
? {
|
||||
mode: 'chat'
|
||||
}
|
||||
@ -86,10 +86,9 @@ function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
aiSdkProviderId,
|
||||
{
|
||||
baseURL: actualProvider.apiHost,
|
||||
apiKey: actualProvider.apiKey,
|
||||
headers: actualProvider.extra_headers
|
||||
apiKey: actualProvider.apiKey
|
||||
},
|
||||
openaiResponseOptions
|
||||
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
|
||||
)
|
||||
|
||||
return {
|
||||
@ -224,7 +223,9 @@ export default class ModernAiProvider {
|
||||
// try {
|
||||
// 根据条件构建插件数组
|
||||
const plugins = this.buildPlugins(middlewareConfig)
|
||||
|
||||
console.log('this.config.providerId', this.config.providerId)
|
||||
console.log('this.config.options', this.config.options)
|
||||
console.log('plugins', plugins)
|
||||
// 用构建好的插件数组创建executor
|
||||
const executor = createExecutor(this.config.providerId, this.config.options, plugins)
|
||||
|
||||
|
||||
@ -26,6 +26,9 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
|
||||
if (AiCore.isSupported(provider.id)) {
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
if (AiCore.isSupported(provider.type)) {
|
||||
return provider.type as ProviderId
|
||||
}
|
||||
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
|
||||
@ -35,7 +35,12 @@ export function buildProviderOptions(
|
||||
switch (providerId) {
|
||||
case 'openai':
|
||||
case 'azure':
|
||||
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
|
||||
providerSpecificOptions = {
|
||||
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
||||
// 函数内有对于真实provider.id的判断,应该不会影响原生provider
|
||||
...buildGenericProviderOptions(assistant, model, capabilities)
|
||||
}
|
||||
|
||||
break
|
||||
|
||||
case 'anthropic':
|
||||
|
||||
Loading…
Reference in New Issue
Block a user