mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 14:59:27 +08:00
refactor: aisdk config (#11402)
* refactor: improve model filtering with todo for robust conversion * refactor(aiCore): add AiSdkConfig type and update provider config handling - Introduce new AiSdkConfig type in aiCoreTypes for better type safety - Update provider factory and config to use AiSdkConfig consistently - Simplify getAiSdkProviderId return type to string - Add config validation in ModernAiProvider * refactor(aiCore): move ai core types to dedicated module Consolidate AI core type definitions into a dedicated module under aiCore/types. This improves code organization by keeping related types together and removes circular dependencies between modules. The change includes: - Moving AiSdkConfig to aiCore/types - Updating all imports to reference the new location - Removing duplicate type definitions * refactor(provider): add return type to createAiSdkProvider function
This commit is contained in:
parent
49903a1567
commit
fa361126b8
@ -32,6 +32,7 @@ import {
|
|||||||
prepareSpecialProviderConfig,
|
prepareSpecialProviderConfig,
|
||||||
providerToAiSdkConfig
|
providerToAiSdkConfig
|
||||||
} from './provider/providerConfig'
|
} from './provider/providerConfig'
|
||||||
|
import type { AiSdkConfig } from './types'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ModernAiProvider')
|
const logger = loggerService.withContext('ModernAiProvider')
|
||||||
|
|
||||||
@ -44,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
|
|||||||
|
|
||||||
export default class ModernAiProvider {
|
export default class ModernAiProvider {
|
||||||
private legacyProvider: LegacyAiProvider
|
private legacyProvider: LegacyAiProvider
|
||||||
private config?: ReturnType<typeof providerToAiSdkConfig>
|
private config?: AiSdkConfig
|
||||||
private actualProvider: Provider
|
private actualProvider: Provider
|
||||||
private model?: Model
|
private model?: Model
|
||||||
private localProvider: Awaited<AiSdkProvider> | null = null
|
private localProvider: Awaited<AiSdkProvider> | null = null
|
||||||
@ -89,6 +90,11 @@ export default class ModernAiProvider {
|
|||||||
// 每次请求时重新生成配置以确保API key轮换生效
|
// 每次请求时重新生成配置以确保API key轮换生效
|
||||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
||||||
logger.debug('Generated provider config for completions', this.config)
|
logger.debug('Generated provider config for completions', this.config)
|
||||||
|
|
||||||
|
// 检查 config 是否存在
|
||||||
|
if (!this.config) {
|
||||||
|
throw new Error('Provider config is undefined; cannot proceed with completions')
|
||||||
|
}
|
||||||
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
|
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
|
||||||
providerConfig.isImageGenerationEndpoint = true
|
providerConfig.isImageGenerationEndpoint = true
|
||||||
}
|
}
|
||||||
@ -463,6 +469,11 @@ export default class ModernAiProvider {
|
|||||||
// 如果支持新的 AI SDK,使用现代化实现
|
// 如果支持新的 AI SDK,使用现代化实现
|
||||||
if (isModernSdkSupported(this.actualProvider)) {
|
if (isModernSdkSupported(this.actualProvider)) {
|
||||||
try {
|
try {
|
||||||
|
// 确保 config 已定义
|
||||||
|
if (!this.config) {
|
||||||
|
throw new Error('Provider config is undefined; cannot proceed with generateImage')
|
||||||
|
}
|
||||||
|
|
||||||
// 确保本地provider已创建
|
// 确保本地provider已创建
|
||||||
if (!this.localProvider) {
|
if (!this.localProvider) {
|
||||||
this.localProvider = await createAiSdkProvider(this.config)
|
this.localProvider = await createAiSdkProvider(this.config)
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import { loggerService } from '@logger'
|
|||||||
import type { Provider } from '@renderer/types'
|
import type { Provider } from '@renderer/types'
|
||||||
import type { Provider as AiSdkProvider } from 'ai'
|
import type { Provider as AiSdkProvider } from 'ai'
|
||||||
|
|
||||||
|
import type { AiSdkConfig } from '../types'
|
||||||
import { initializeNewProviders } from './providerInitialization'
|
import { initializeNewProviders } from './providerInitialization'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ProviderFactory')
|
const logger = loggerService.withContext('ProviderFactory')
|
||||||
@ -55,7 +56,7 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
|
|||||||
* 获取AI SDK Provider ID
|
* 获取AI SDK Provider ID
|
||||||
* 简化版:减少重复逻辑,利用通用解析函数
|
* 简化版:减少重复逻辑,利用通用解析函数
|
||||||
*/
|
*/
|
||||||
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
|
export function getAiSdkProviderId(provider: Provider): string {
|
||||||
// 1. 尝试解析provider.id
|
// 1. 尝试解析provider.id
|
||||||
const resolvedFromId = tryResolveProviderId(provider.id)
|
const resolvedFromId = tryResolveProviderId(provider.id)
|
||||||
if (resolvedFromId) {
|
if (resolvedFromId) {
|
||||||
@ -73,11 +74,11 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
|
|||||||
if (provider.apiHost.includes('api.openai.com')) {
|
if (provider.apiHost.includes('api.openai.com')) {
|
||||||
return 'openai-chat'
|
return 'openai-chat'
|
||||||
}
|
}
|
||||||
// 3. 最后的fallback(通常会成为openai-compatible)
|
// 3. 最后的fallback(使用provider本身的id)
|
||||||
return provider.id as ProviderId
|
return provider.id
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function createAiSdkProvider(config) {
|
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
|
||||||
let localProvider: Awaited<AiSdkProvider> | null = null
|
let localProvider: Awaited<AiSdkProvider> | null = null
|
||||||
try {
|
try {
|
||||||
if (config.providerId === 'openai' && config.options?.mode === 'chat') {
|
if (config.providerId === 'openai' && config.options?.mode === 'chat') {
|
||||||
|
|||||||
@ -1,10 +1,4 @@
|
|||||||
import {
|
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
|
||||||
formatPrivateKey,
|
|
||||||
hasProviderConfig,
|
|
||||||
ProviderConfigFactory,
|
|
||||||
type ProviderId,
|
|
||||||
type ProviderSettingsMap
|
|
||||||
} from '@cherrystudio/ai-core/provider'
|
|
||||||
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
||||||
import {
|
import {
|
||||||
getAwsBedrockAccessKeyId,
|
getAwsBedrockAccessKeyId,
|
||||||
@ -29,6 +23,7 @@ import {
|
|||||||
} from '@renderer/utils/provider'
|
} from '@renderer/utils/provider'
|
||||||
import { cloneDeep } from 'lodash'
|
import { cloneDeep } from 'lodash'
|
||||||
|
|
||||||
|
import type { AiSdkConfig } from '../types'
|
||||||
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
|
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
|
||||||
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
||||||
import { getAiSdkProviderId } from './factory'
|
import { getAiSdkProviderId } from './factory'
|
||||||
@ -132,13 +127,7 @@ export function getActualProvider(model: Model): Provider {
|
|||||||
* 将 Provider 配置转换为新 AI SDK 格式
|
* 将 Provider 配置转换为新 AI SDK 格式
|
||||||
* 简化版:利用新的别名映射系统
|
* 简化版:利用新的别名映射系统
|
||||||
*/
|
*/
|
||||||
export function providerToAiSdkConfig(
|
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
|
||||||
actualProvider: Provider,
|
|
||||||
model: Model
|
|
||||||
): {
|
|
||||||
providerId: ProviderId | 'openai-compatible'
|
|
||||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
|
||||||
} {
|
|
||||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||||
|
|
||||||
// 构建基础配置
|
// 构建基础配置
|
||||||
@ -238,7 +227,7 @@ export function providerToAiSdkConfig(
|
|||||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||||
return {
|
return {
|
||||||
providerId: aiSdkProviderId as ProviderId,
|
providerId: aiSdkProviderId,
|
||||||
options
|
options
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
15
src/renderer/src/aiCore/types/index.ts
Normal file
15
src/renderer/src/aiCore/types/index.ts
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
/**
|
||||||
|
* This type definition file is only for renderer.
|
||||||
|
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
|
||||||
|
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
|
||||||
|
* (ai-core package is set as browser-enviroment-only)
|
||||||
|
*
|
||||||
|
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
|
||||||
|
|
||||||
|
export type AiSdkConfig = {
|
||||||
|
providerId: string
|
||||||
|
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||||
|
}
|
||||||
@ -183,6 +183,7 @@ const PopupContainer: React.FC<Props> = ({ providerId, resolve }) => {
|
|||||||
setLoadingModels(true)
|
setLoadingModels(true)
|
||||||
try {
|
try {
|
||||||
const models = await fetchModels(provider)
|
const models = await fetchModels(provider)
|
||||||
|
// TODO: More robust conversion
|
||||||
const filteredModels = models
|
const filteredModels = models
|
||||||
.map((model) => ({
|
.map((model) => ({
|
||||||
// @ts-ignore modelId
|
// @ts-ignore modelId
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user