mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-31 08:29:07 +08:00
refactor: aisdk config (#11402)
* refactor: improve model filtering with todo for robust conversion * refactor(aiCore): add AiSdkConfig type and update provider config handling - Introduce new AiSdkConfig type in aiCoreTypes for better type safety - Update provider factory and config to use AiSdkConfig consistently - Simplify getAiSdkProviderId return type to string - Add config validation in ModernAiProvider * refactor(aiCore): move ai core types to dedicated module Consolidate AI core type definitions into a dedicated module under aiCore/types. This improves code organization by keeping related types together and removes circular dependencies between modules. The change includes: - Moving AiSdkConfig to aiCore/types - Updating all imports to reference the new location - Removing duplicate type definitions * refactor(provider): add return type to createAiSdkProvider function
This commit is contained in:
parent
49903a1567
commit
fa361126b8
@ -32,6 +32,7 @@ import {
|
||||
prepareSpecialProviderConfig,
|
||||
providerToAiSdkConfig
|
||||
} from './provider/providerConfig'
|
||||
import type { AiSdkConfig } from './types'
|
||||
|
||||
const logger = loggerService.withContext('ModernAiProvider')
|
||||
|
||||
@ -44,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
|
||||
|
||||
export default class ModernAiProvider {
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private config?: ReturnType<typeof providerToAiSdkConfig>
|
||||
private config?: AiSdkConfig
|
||||
private actualProvider: Provider
|
||||
private model?: Model
|
||||
private localProvider: Awaited<AiSdkProvider> | null = null
|
||||
@ -89,6 +90,11 @@ export default class ModernAiProvider {
|
||||
// 每次请求时重新生成配置以确保API key轮换生效
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
||||
logger.debug('Generated provider config for completions', this.config)
|
||||
|
||||
// 检查 config 是否存在
|
||||
if (!this.config) {
|
||||
throw new Error('Provider config is undefined; cannot proceed with completions')
|
||||
}
|
||||
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
|
||||
providerConfig.isImageGenerationEndpoint = true
|
||||
}
|
||||
@ -463,6 +469,11 @@ export default class ModernAiProvider {
|
||||
// 如果支持新的 AI SDK,使用现代化实现
|
||||
if (isModernSdkSupported(this.actualProvider)) {
|
||||
try {
|
||||
// 确保 config 已定义
|
||||
if (!this.config) {
|
||||
throw new Error('Provider config is undefined; cannot proceed with generateImage')
|
||||
}
|
||||
|
||||
// 确保本地provider已创建
|
||||
if (!this.localProvider) {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
|
||||
@ -4,6 +4,7 @@ import { loggerService } from '@logger'
|
||||
import type { Provider } from '@renderer/types'
|
||||
import type { Provider as AiSdkProvider } from 'ai'
|
||||
|
||||
import type { AiSdkConfig } from '../types'
|
||||
import { initializeNewProviders } from './providerInitialization'
|
||||
|
||||
const logger = loggerService.withContext('ProviderFactory')
|
||||
@ -55,7 +56,7 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
|
||||
* 获取AI SDK Provider ID
|
||||
* 简化版:减少重复逻辑,利用通用解析函数
|
||||
*/
|
||||
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
|
||||
export function getAiSdkProviderId(provider: Provider): string {
|
||||
// 1. 尝试解析provider.id
|
||||
const resolvedFromId = tryResolveProviderId(provider.id)
|
||||
if (resolvedFromId) {
|
||||
@ -73,11 +74,11 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
|
||||
if (provider.apiHost.includes('api.openai.com')) {
|
||||
return 'openai-chat'
|
||||
}
|
||||
// 3. 最后的fallback(通常会成为openai-compatible)
|
||||
return provider.id as ProviderId
|
||||
// 3. 最后的fallback(使用provider本身的id)
|
||||
return provider.id
|
||||
}
|
||||
|
||||
export async function createAiSdkProvider(config) {
|
||||
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
|
||||
let localProvider: Awaited<AiSdkProvider> | null = null
|
||||
try {
|
||||
if (config.providerId === 'openai' && config.options?.mode === 'chat') {
|
||||
|
||||
@ -1,10 +1,4 @@
|
||||
import {
|
||||
formatPrivateKey,
|
||||
hasProviderConfig,
|
||||
ProviderConfigFactory,
|
||||
type ProviderId,
|
||||
type ProviderSettingsMap
|
||||
} from '@cherrystudio/ai-core/provider'
|
||||
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
|
||||
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
||||
import {
|
||||
getAwsBedrockAccessKeyId,
|
||||
@ -29,6 +23,7 @@ import {
|
||||
} from '@renderer/utils/provider'
|
||||
import { cloneDeep } from 'lodash'
|
||||
|
||||
import type { AiSdkConfig } from '../types'
|
||||
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
|
||||
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
||||
import { getAiSdkProviderId } from './factory'
|
||||
@ -132,13 +127,7 @@ export function getActualProvider(model: Model): Provider {
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
* 简化版:利用新的别名映射系统
|
||||
*/
|
||||
export function providerToAiSdkConfig(
|
||||
actualProvider: Provider,
|
||||
model: Model
|
||||
): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
|
||||
// 构建基础配置
|
||||
@ -238,7 +227,7 @@ export function providerToAiSdkConfig(
|
||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||
return {
|
||||
providerId: aiSdkProviderId as ProviderId,
|
||||
providerId: aiSdkProviderId,
|
||||
options
|
||||
}
|
||||
}
|
||||
|
||||
15
src/renderer/src/aiCore/types/index.ts
Normal file
15
src/renderer/src/aiCore/types/index.ts
Normal file
@ -0,0 +1,15 @@
|
||||
/**
|
||||
* This type definition file is only for renderer.
|
||||
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
|
||||
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
|
||||
* (ai-core package is set as browser-enviroment-only)
|
||||
*
|
||||
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
|
||||
*/
|
||||
|
||||
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
|
||||
|
||||
export type AiSdkConfig = {
|
||||
providerId: string
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
}
|
||||
@ -183,6 +183,7 @@ const PopupContainer: React.FC<Props> = ({ providerId, resolve }) => {
|
||||
setLoadingModels(true)
|
||||
try {
|
||||
const models = await fetchModels(provider)
|
||||
// TODO: More robust conversion
|
||||
const filteredModels = models
|
||||
.map((model) => ({
|
||||
// @ts-ignore modelId
|
||||
|
||||
Loading…
Reference in New Issue
Block a user