refactor: aisdk config (#11402)

* refactor: improve model filtering with todo for robust conversion

* refactor(aiCore): add AiSdkConfig type and update provider config handling

- Introduce new AiSdkConfig type in aiCoreTypes for better type safety
- Update provider factory and config to use AiSdkConfig consistently
- Simplify getAiSdkProviderId return type to string
- Add config validation in ModernAiProvider

* refactor(aiCore): move ai core types to dedicated module

Consolidate AI core type definitions into a dedicated module under aiCore/types. This improves code organization by keeping related types together and removes circular dependencies between modules. The change includes:
- Moving AiSdkConfig to aiCore/types
- Updating all imports to reference the new location
- Removing duplicate type definitions

* refactor(provider): add return type to createAiSdkProvider function
This commit is contained in:
Phantom 2025-11-23 21:12:57 +08:00 committed by GitHub
parent 49903a1567
commit fa361126b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 37 additions and 20 deletions

View File

@ -32,6 +32,7 @@ import {
prepareSpecialProviderConfig, prepareSpecialProviderConfig,
providerToAiSdkConfig providerToAiSdkConfig
} from './provider/providerConfig' } from './provider/providerConfig'
import type { AiSdkConfig } from './types'
const logger = loggerService.withContext('ModernAiProvider') const logger = loggerService.withContext('ModernAiProvider')
@ -44,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
export default class ModernAiProvider { export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider private legacyProvider: LegacyAiProvider
private config?: ReturnType<typeof providerToAiSdkConfig> private config?: AiSdkConfig
private actualProvider: Provider private actualProvider: Provider
private model?: Model private model?: Model
private localProvider: Awaited<AiSdkProvider> | null = null private localProvider: Awaited<AiSdkProvider> | null = null
@ -89,6 +90,11 @@ export default class ModernAiProvider {
// 每次请求时重新生成配置以确保API key轮换生效 // 每次请求时重新生成配置以确保API key轮换生效
this.config = providerToAiSdkConfig(this.actualProvider, this.model) this.config = providerToAiSdkConfig(this.actualProvider, this.model)
logger.debug('Generated provider config for completions', this.config) logger.debug('Generated provider config for completions', this.config)
// 检查 config 是否存在
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with completions')
}
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) { if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
providerConfig.isImageGenerationEndpoint = true providerConfig.isImageGenerationEndpoint = true
} }
@ -463,6 +469,11 @@ export default class ModernAiProvider {
// 如果支持新的 AI SDK使用现代化实现 // 如果支持新的 AI SDK使用现代化实现
if (isModernSdkSupported(this.actualProvider)) { if (isModernSdkSupported(this.actualProvider)) {
try { try {
// 确保 config 已定义
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with generateImage')
}
// 确保本地provider已创建 // 确保本地provider已创建
if (!this.localProvider) { if (!this.localProvider) {
this.localProvider = await createAiSdkProvider(this.config) this.localProvider = await createAiSdkProvider(this.config)

View File

@ -4,6 +4,7 @@ import { loggerService } from '@logger'
import type { Provider } from '@renderer/types' import type { Provider } from '@renderer/types'
import type { Provider as AiSdkProvider } from 'ai' import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types'
import { initializeNewProviders } from './providerInitialization' import { initializeNewProviders } from './providerInitialization'
const logger = loggerService.withContext('ProviderFactory') const logger = loggerService.withContext('ProviderFactory')
@ -55,7 +56,7 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
* AI SDK Provider ID * AI SDK Provider ID
* *
*/ */
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' { export function getAiSdkProviderId(provider: Provider): string {
// 1. 尝试解析provider.id // 1. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id) const resolvedFromId = tryResolveProviderId(provider.id)
if (resolvedFromId) { if (resolvedFromId) {
@ -73,11 +74,11 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
if (provider.apiHost.includes('api.openai.com')) { if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat' return 'openai-chat'
} }
// 3. 最后的fallback通常会成为openai-compatible // 3. 最后的fallback使用provider本身的id
return provider.id as ProviderId return provider.id
} }
export async function createAiSdkProvider(config) { export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
let localProvider: Awaited<AiSdkProvider> | null = null let localProvider: Awaited<AiSdkProvider> | null = null
try { try {
if (config.providerId === 'openai' && config.options?.mode === 'chat') { if (config.providerId === 'openai' && config.options?.mode === 'chat') {

View File

@ -1,10 +1,4 @@
import { import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
formatPrivateKey,
hasProviderConfig,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap
} from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import { import {
getAwsBedrockAccessKeyId, getAwsBedrockAccessKeyId,
@ -29,6 +23,7 @@ import {
} from '@renderer/utils/provider' } from '@renderer/utils/provider'
import { cloneDeep } from 'lodash' import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config' import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { COPILOT_DEFAULT_HEADERS } from './constants' import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory' import { getAiSdkProviderId } from './factory'
@ -132,13 +127,7 @@ export function getActualProvider(model: Model): Provider {
* Provider AI SDK * Provider AI SDK
* *
*/ */
export function providerToAiSdkConfig( export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
actualProvider: Provider,
model: Model
): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
const aiSdkProviderId = getAiSdkProviderId(actualProvider) const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// 构建基础配置 // 构建基础配置
@ -238,7 +227,7 @@ export function providerToAiSdkConfig(
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions) const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return { return {
providerId: aiSdkProviderId as ProviderId, providerId: aiSdkProviderId,
options options
} }
} }

View File

@ -0,0 +1,15 @@
/**
* This type definition file is only for renderer.
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
* (ai-core package is set as browser-enviroment-only)
*
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
*/
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
export type AiSdkConfig = {
providerId: string
options: ProviderSettingsMap[keyof ProviderSettingsMap]
}

View File

@ -183,6 +183,7 @@ const PopupContainer: React.FC<Props> = ({ providerId, resolve }) => {
setLoadingModels(true) setLoadingModels(true)
try { try {
const models = await fetchModels(provider) const models = await fetchModels(provider)
// TODO: More robust conversion
const filteredModels = models const filteredModels = models
.map((model) => ({ .map((model) => ({
// @ts-ignore modelId // @ts-ignore modelId