mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat(aiCore): introduce provider configuration enhancements and initialization
- Added a new provider configuration module to handle special provider logic and formatting. - Implemented asynchronous preparation of special provider configurations in the ModernAiProvider class. - Refactored provider initialization logic to support dynamic registration of new AI providers. - Updated utility functions to streamline provider option building and improve compatibility with new provider configurations.
This commit is contained in:
parent
49cd9d6723
commit
efeada281a
@ -22,7 +22,12 @@ import LegacyAiProvider from './legacy/index'
|
||||
import { CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor'
|
||||
import {
|
||||
getActualProvider,
|
||||
isModernSdkSupported,
|
||||
prepareSpecialProviderConfig,
|
||||
providerToAiSdkConfig
|
||||
} from './provider/providerConfig'
|
||||
import type { StreamTextParams } from './types'
|
||||
|
||||
const logger = loggerService.withContext('ModernAiProvider')
|
||||
@ -54,6 +59,10 @@ export default class ModernAiProvider {
|
||||
callType: string
|
||||
}
|
||||
) {
|
||||
// 准备特殊配置
|
||||
await prepareSpecialProviderConfig(this.actualProvider, this.config)
|
||||
|
||||
console.log('this.config', this.config)
|
||||
if (config.topicId && getEnableDeveloperMode()) {
|
||||
// TypeScript类型窄化:确保topicId是string类型
|
||||
const traceConfig = {
|
||||
|
||||
@ -2,7 +2,7 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr
|
||||
import { loggerService } from '@logger'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { initializeNewProviders } from './providerConfigs'
|
||||
import { initializeNewProviders } from './providerInitialization'
|
||||
|
||||
const logger = loggerService.withContext('ProviderFactory')
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { loggerService } from '@renderer/services/LoggerService'
|
||||
import store from '@renderer/store'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep } from 'lodash'
|
||||
@ -97,6 +98,15 @@ export function providerToAiSdkConfig(
|
||||
extraOptions.headers = actualProvider.extra_headers
|
||||
}
|
||||
|
||||
// copilot
|
||||
if (actualProvider.id === 'copilot') {
|
||||
extraOptions.headers = {
|
||||
...extraOptions.extra_headers,
|
||||
'editor-version': 'vscode/1.97.2',
|
||||
'copilot-vision-request': 'true'
|
||||
}
|
||||
}
|
||||
|
||||
// 如果AI SDK支持该provider,使用原生配置
|
||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||
@ -134,3 +144,18 @@ export function isModernSdkSupported(provider: Provider): boolean {
|
||||
// 如果映射到了支持的provider,则支持现代SDK
|
||||
return hasProviderConfig(aiSdkProviderId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 准备特殊provider的配置,主要用于异步处理的配置
|
||||
*/
|
||||
export async function prepareSpecialProviderConfig(
|
||||
provider: Provider,
|
||||
config: ReturnType<typeof providerToAiSdkConfig>
|
||||
) {
|
||||
if (provider.id === 'copilot') {
|
||||
const defaultHeaders = store.getState().copilot.defaultHeaders
|
||||
const { token } = await window.api.copilot.getToken(defaultHeaders)
|
||||
config.options.apiKey = token
|
||||
}
|
||||
return config
|
||||
}
|
||||
@ -68,7 +68,6 @@ export function buildProviderOptions(
|
||||
}
|
||||
): Record<string, any> {
|
||||
const providerId = getAiSdkProviderId(actualProvider)
|
||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
|
||||
@ -77,7 +76,7 @@ export function buildProviderOptions(
|
||||
case 'openai':
|
||||
case 'azure':
|
||||
providerSpecificOptions = {
|
||||
...buildOpenAIProviderOptions(assistant, model, capabilities)
|
||||
...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider)
|
||||
}
|
||||
break
|
||||
|
||||
@ -103,7 +102,6 @@ export function buildProviderOptions(
|
||||
// 合并自定义参数到 provider 特定的选项中
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
serviceTier: serviceTierSetting,
|
||||
...getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
@ -123,17 +121,19 @@ function buildOpenAIProviderOptions(
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
},
|
||||
actualProvider: Provider
|
||||
): Record<string, any> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
||||
// OpenAI 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
...reasoningParams,
|
||||
serviceTier: serviceTierSetting
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user