mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-01 17:59:09 +08:00
feat(aiCore): introduce provider configuration enhancements and initialization
- Added a new provider configuration module to handle special provider logic and formatting. - Implemented asynchronous preparation of special provider configurations in the ModernAiProvider class. - Refactored provider initialization logic to support dynamic registration of new AI providers. - Updated utility functions to streamline provider option building and improve compatibility with new provider configurations.
This commit is contained in:
parent
49cd9d6723
commit
efeada281a
@ -22,7 +22,12 @@ import LegacyAiProvider from './legacy/index'
|
|||||||
import { CompletionsResult } from './legacy/middleware/schemas'
|
import { CompletionsResult } from './legacy/middleware/schemas'
|
||||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||||
import { buildPlugins } from './plugins/PluginBuilder'
|
import { buildPlugins } from './plugins/PluginBuilder'
|
||||||
import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor'
|
import {
|
||||||
|
getActualProvider,
|
||||||
|
isModernSdkSupported,
|
||||||
|
prepareSpecialProviderConfig,
|
||||||
|
providerToAiSdkConfig
|
||||||
|
} from './provider/providerConfig'
|
||||||
import type { StreamTextParams } from './types'
|
import type { StreamTextParams } from './types'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ModernAiProvider')
|
const logger = loggerService.withContext('ModernAiProvider')
|
||||||
@ -54,6 +59,10 @@ export default class ModernAiProvider {
|
|||||||
callType: string
|
callType: string
|
||||||
}
|
}
|
||||||
) {
|
) {
|
||||||
|
// 准备特殊配置
|
||||||
|
await prepareSpecialProviderConfig(this.actualProvider, this.config)
|
||||||
|
|
||||||
|
console.log('this.config', this.config)
|
||||||
if (config.topicId && getEnableDeveloperMode()) {
|
if (config.topicId && getEnableDeveloperMode()) {
|
||||||
// TypeScript类型窄化:确保topicId是string类型
|
// TypeScript类型窄化:确保topicId是string类型
|
||||||
const traceConfig = {
|
const traceConfig = {
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { Provider } from '@renderer/types'
|
import { Provider } from '@renderer/types'
|
||||||
|
|
||||||
import { initializeNewProviders } from './providerConfigs'
|
import { initializeNewProviders } from './providerInitialization'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ProviderFactory')
|
const logger = loggerService.withContext('ProviderFactory')
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
|||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import { loggerService } from '@renderer/services/LoggerService'
|
import { loggerService } from '@renderer/services/LoggerService'
|
||||||
|
import store from '@renderer/store'
|
||||||
import type { Model, Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
import { formatApiHost } from '@renderer/utils/api'
|
import { formatApiHost } from '@renderer/utils/api'
|
||||||
import { cloneDeep } from 'lodash'
|
import { cloneDeep } from 'lodash'
|
||||||
@ -97,6 +98,15 @@ export function providerToAiSdkConfig(
|
|||||||
extraOptions.headers = actualProvider.extra_headers
|
extraOptions.headers = actualProvider.extra_headers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copilot
|
||||||
|
if (actualProvider.id === 'copilot') {
|
||||||
|
extraOptions.headers = {
|
||||||
|
...extraOptions.extra_headers,
|
||||||
|
'editor-version': 'vscode/1.97.2',
|
||||||
|
'copilot-vision-request': 'true'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 如果AI SDK支持该provider,使用原生配置
|
// 如果AI SDK支持该provider,使用原生配置
|
||||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||||
@ -134,3 +144,18 @@ export function isModernSdkSupported(provider: Provider): boolean {
|
|||||||
// 如果映射到了支持的provider,则支持现代SDK
|
// 如果映射到了支持的provider,则支持现代SDK
|
||||||
return hasProviderConfig(aiSdkProviderId)
|
return hasProviderConfig(aiSdkProviderId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 准备特殊provider的配置,主要用于异步处理的配置
|
||||||
|
*/
|
||||||
|
export async function prepareSpecialProviderConfig(
|
||||||
|
provider: Provider,
|
||||||
|
config: ReturnType<typeof providerToAiSdkConfig>
|
||||||
|
) {
|
||||||
|
if (provider.id === 'copilot') {
|
||||||
|
const defaultHeaders = store.getState().copilot.defaultHeaders
|
||||||
|
const { token } = await window.api.copilot.getToken(defaultHeaders)
|
||||||
|
config.options.apiKey = token
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
}
|
||||||
@ -68,7 +68,6 @@ export function buildProviderOptions(
|
|||||||
}
|
}
|
||||||
): Record<string, any> {
|
): Record<string, any> {
|
||||||
const providerId = getAiSdkProviderId(actualProvider)
|
const providerId = getAiSdkProviderId(actualProvider)
|
||||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
|
||||||
// 构建 provider 特定的选项
|
// 构建 provider 特定的选项
|
||||||
let providerSpecificOptions: Record<string, any> = {}
|
let providerSpecificOptions: Record<string, any> = {}
|
||||||
|
|
||||||
@ -77,7 +76,7 @@ export function buildProviderOptions(
|
|||||||
case 'openai':
|
case 'openai':
|
||||||
case 'azure':
|
case 'azure':
|
||||||
providerSpecificOptions = {
|
providerSpecificOptions = {
|
||||||
...buildOpenAIProviderOptions(assistant, model, capabilities)
|
...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -103,7 +102,6 @@ export function buildProviderOptions(
|
|||||||
// 合并自定义参数到 provider 特定的选项中
|
// 合并自定义参数到 provider 特定的选项中
|
||||||
providerSpecificOptions = {
|
providerSpecificOptions = {
|
||||||
...providerSpecificOptions,
|
...providerSpecificOptions,
|
||||||
serviceTier: serviceTierSetting,
|
|
||||||
...getCustomParameters(assistant)
|
...getCustomParameters(assistant)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,17 +121,19 @@ function buildOpenAIProviderOptions(
|
|||||||
enableReasoning: boolean
|
enableReasoning: boolean
|
||||||
enableWebSearch: boolean
|
enableWebSearch: boolean
|
||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
}
|
},
|
||||||
|
actualProvider: Provider
|
||||||
): Record<string, any> {
|
): Record<string, any> {
|
||||||
const { enableReasoning } = capabilities
|
const { enableReasoning } = capabilities
|
||||||
let providerOptions: Record<string, any> = {}
|
let providerOptions: Record<string, any> = {}
|
||||||
|
const serviceTierSetting = getServiceTier(model, actualProvider)
|
||||||
// OpenAI 推理参数
|
// OpenAI 推理参数
|
||||||
if (enableReasoning) {
|
if (enableReasoning) {
|
||||||
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
||||||
providerOptions = {
|
providerOptions = {
|
||||||
...providerOptions,
|
...providerOptions,
|
||||||
...reasoningParams
|
...reasoningParams,
|
||||||
|
serviceTier: serviceTierSetting
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user