refactor: aisdk config (#11402)

* refactor: improve model filtering with todo for robust conversion

* refactor(aiCore): add AiSdkConfig type and update provider config handling

- Introduce new AiSdkConfig type in aiCoreTypes for better type safety
- Update provider factory and config to use AiSdkConfig consistently
- Simplify getAiSdkProviderId return type to string
- Add config validation in ModernAiProvider

* refactor(aiCore): move ai core types to dedicated module

Consolidate AI core type definitions into a dedicated module under aiCore/types. This improves code organization by keeping related types together and removes circular dependencies between modules. The change includes:
- Moving AiSdkConfig to aiCore/types
- Updating all imports to reference the new location
- Removing duplicate type definitions

* refactor(provider): add return type to createAiSdkProvider function
This commit is contained in:
Phantom 2025-11-23 21:12:57 +08:00 committed by GitHub
parent 49903a1567
commit fa361126b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 37 additions and 20 deletions

View File

@ -32,6 +32,7 @@ import {
prepareSpecialProviderConfig,
providerToAiSdkConfig
} from './provider/providerConfig'
import type { AiSdkConfig } from './types'
const logger = loggerService.withContext('ModernAiProvider')
@ -44,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private config?: ReturnType<typeof providerToAiSdkConfig>
private config?: AiSdkConfig
private actualProvider: Provider
private model?: Model
private localProvider: Awaited<AiSdkProvider> | null = null
@ -89,6 +90,11 @@ export default class ModernAiProvider {
// 每次请求时重新生成配置以确保API key轮换生效
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
logger.debug('Generated provider config for completions', this.config)
// 检查 config 是否存在
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with completions')
}
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
providerConfig.isImageGenerationEndpoint = true
}
@ -463,6 +469,11 @@ export default class ModernAiProvider {
// 如果支持新的 AI SDK使用现代化实现
if (isModernSdkSupported(this.actualProvider)) {
try {
// 确保 config 已定义
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with generateImage')
}
// 确保本地provider已创建
if (!this.localProvider) {
this.localProvider = await createAiSdkProvider(this.config)

View File

@ -4,6 +4,7 @@ import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types'
import { initializeNewProviders } from './providerInitialization'
const logger = loggerService.withContext('ProviderFactory')
@ -55,7 +56,7 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
* AI SDK Provider ID
*
*/
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
export function getAiSdkProviderId(provider: Provider): string {
// 1. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (resolvedFromId) {
@ -73,11 +74,11 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 3. 最后的fallback通常会成为openai-compatible
return provider.id as ProviderId
// 3. 最后的fallback使用provider本身的id
return provider.id
}
export async function createAiSdkProvider(config) {
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
let localProvider: Awaited<AiSdkProvider> | null = null
try {
if (config.providerId === 'openai' && config.options?.mode === 'chat') {

View File

@ -1,10 +1,4 @@
import {
formatPrivateKey,
hasProviderConfig,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap
} from '@cherrystudio/ai-core/provider'
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import {
getAwsBedrockAccessKeyId,
@ -29,6 +23,7 @@ import {
} from '@renderer/utils/provider'
import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
@ -132,13 +127,7 @@ export function getActualProvider(model: Model): Provider {
* Provider AI SDK
*
*/
export function providerToAiSdkConfig(
actualProvider: Provider,
model: Model
): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// 构建基础配置
@ -238,7 +227,7 @@ export function providerToAiSdkConfig(
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId as ProviderId,
providerId: aiSdkProviderId,
options
}
}

View File

@ -0,0 +1,15 @@
/**
* This type definition file is only for renderer.
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
* (ai-core package is set as browser-enviroment-only)
*
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
*/
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
export type AiSdkConfig = {
providerId: string
options: ProviderSettingsMap[keyof ProviderSettingsMap]
}

View File

@ -183,6 +183,7 @@ const PopupContainer: React.FC<Props> = ({ providerId, resolve }) => {
setLoadingModels(true)
try {
const models = await fetchModels(provider)
// TODO: More robust conversion
const filteredModels = models
.map((model) => ({
// @ts-ignore modelId