refactor(modelResolver): replace ':' with '|' as the default separator for model IDs

Updated the ModelResolver and related components to use '|' as the default separator instead of ':'. This change improves compatibility and resolves potential conflicts with model ID suffixes. Adjusted model resolution logic accordingly to ensure consistent behavior across the application.
This commit is contained in:
suyao 2025-08-29 04:12:48 +08:00
parent 9551c49452
commit 3004f84be3
No known key found for this signature in database
7 changed files with 72 additions and 54 deletions

View File

@ -9,7 +9,7 @@ import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middlew
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model' import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
import { wrapModelWithMiddlewares } from '../middleware/wrapper' import { wrapModelWithMiddlewares } from '../middleware/wrapper'
import { globalRegistryManagement } from '../providers/RegistryManagement' import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
export class ModelResolver { export class ModelResolver {
/** /**
@ -39,7 +39,7 @@ export class ModelResolver {
} }
// 检查是否是命名空间格式 // 检查是否是命名空间格式
if (modelId.includes(':')) { if (modelId.includes(DEFAULT_SEPARATOR)) {
model = this.resolveNamespacedModel(modelId) model = this.resolveNamespacedModel(modelId)
} else { } else {
// 传统格式:使用处理后的 providerId + modelId // 传统格式:使用处理后的 providerId + modelId
@ -58,7 +58,7 @@ export class ModelResolver {
* *
*/ */
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> { async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> {
if (modelId.includes(':')) { if (modelId.includes(DEFAULT_SEPARATOR)) {
return this.resolveNamespacedEmbeddingModel(modelId) return this.resolveNamespacedEmbeddingModel(modelId)
} }
@ -69,7 +69,7 @@ export class ModelResolver {
* *
*/ */
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> { async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> {
if (modelId.includes(':')) { if (modelId.includes(DEFAULT_SEPARATOR)) {
return this.resolveNamespacedImageModel(modelId) return this.resolveNamespacedImageModel(modelId)
} }
@ -89,7 +89,7 @@ export class ModelResolver {
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4') * providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
*/ */
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 { private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
const fullModelId = `${providerId}:${modelId}` const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
return globalRegistryManagement.languageModel(fullModelId as any) return globalRegistryManagement.languageModel(fullModelId as any)
} }
@ -104,7 +104,7 @@ export class ModelResolver {
* *
*/ */
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> { private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> {
const fullModelId = `${providerId}:${modelId}` const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
return globalRegistryManagement.textEmbeddingModel(fullModelId as any) return globalRegistryManagement.textEmbeddingModel(fullModelId as any)
} }
@ -119,7 +119,7 @@ export class ModelResolver {
* *
*/ */
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 { private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 {
const fullModelId = `${providerId}:${modelId}` const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
return globalRegistryManagement.imageModel(fullModelId as any) return globalRegistryManagement.imageModel(fullModelId as any)
} }
} }

View File

@ -5,10 +5,11 @@
* 例如: aihubmix:anthropic:claude-3.5-sonnet * 例如: aihubmix:anthropic:claude-3.5-sonnet
*/ */
import { EmbeddingModelV2, ImageModelV2, ProviderV2, SpeechModelV2, TranscriptionModelV2 } from '@ai-sdk/provider' import { ProviderV2 } from '@ai-sdk/provider'
import { customProvider } from 'ai' import { customProvider } from 'ai'
import { globalRegistryManagement } from './RegistryManagement' import { globalRegistryManagement } from './RegistryManagement'
import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types'
export interface HubProviderConfig { export interface HubProviderConfig {
/** Hub的唯一标识符 */ /** Hub的唯一标识符 */
@ -52,7 +53,7 @@ export function createHubProvider(config: HubProviderConfig): ProviderV2 {
function getTargetProvider(providerId: string): ProviderV2 { function getTargetProvider(providerId: string): ProviderV2 {
// 从全局注册表获取provider实例 // 从全局注册表获取provider实例
try { try {
const provider = (globalRegistryManagement as any).getProvider(providerId) const provider = globalRegistryManagement.getProvider(providerId)
if (!provider) { if (!provider) {
throw new HubProviderError( throw new HubProviderError(
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`, `Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
@ -71,26 +72,30 @@ export function createHubProvider(config: HubProviderConfig): ProviderV2 {
} }
} }
function resolveModel<T>(modelId: string, modelType: string, methodName: keyof ProviderV2): T { function resolveModel<T extends AiSdkModelType>(
modelId: string,
modelType: T,
methodName: AiSdkMethodName<T>
): AiSdkModelReturn<T> {
const { provider, actualModelId } = parseHubModelId(modelId) const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider) const targetProvider = getTargetProvider(provider)
if (!targetProvider[methodName]) { const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn<T>
if (!fn) {
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider) throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
} }
return (targetProvider[methodName] as any)(actualModelId) return fn(actualModelId)
} }
return customProvider({ return customProvider({
fallbackProvider: { fallbackProvider: {
languageModel: (modelId: string) => resolveModel(modelId, 'language models', 'languageModel'), languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'),
textEmbeddingModel: (modelId: string) => textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'),
resolveModel<EmbeddingModelV2<string>>(modelId, 'text embedding models', 'textEmbeddingModel'), imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'),
imageModel: (modelId: string) => resolveModel<ImageModelV2>(modelId, 'image models', 'imageModel'), transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'),
transcriptionModel: (modelId: string) => speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel')
resolveModel<TranscriptionModelV2>(modelId, 'transcription models', 'transcriptionModel'),
speechModel: (modelId: string) => resolveModel<SpeechModelV2>(modelId, 'speech models', 'speechModel')
} }
}) })
} }

View File

@ -9,7 +9,7 @@ import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
type PROVIDERS = Record<string, ProviderV2> type PROVIDERS = Record<string, ProviderV2>
export const DEFAULT_SEPARATOR = ':' export const DEFAULT_SEPARATOR = '|'
// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}` // export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
@ -216,5 +216,6 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
/** /**
* *
* 使 | 因为 : 会和 :free suffix冲突
*/ */
export const globalRegistryManagement = new RegistryManagement<':'>() export const globalRegistryManagement = new RegistryManagement()

View File

@ -4,6 +4,14 @@ import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google' import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type OpenAIProviderSettings } from '@ai-sdk/openai' import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible' import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import {
EmbeddingModelV2 as EmbeddingModel,
ImageModelV2 as ImageModel,
LanguageModelV2 as LanguageModel,
ProviderV2,
SpeechModelV2 as SpeechModel,
TranscriptionModelV2 as TranscriptionModel
} from '@ai-sdk/provider'
import { type XaiProviderSettings } from '@ai-sdk/xai' import { type XaiProviderSettings } from '@ai-sdk/xai'
// 导入基于 Zod 的 ProviderId 类型 // 导入基于 Zod 的 ProviderId 类型
@ -29,31 +37,6 @@ export interface DynamicProviderRegistry {
// 合并基础和动态provider类型 // 合并基础和动态provider类型
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
/**
* Provider
* 使 AI SDK
*/
// Provider 配置接口 - 支持灵活的创建方式
// export interface ProviderConfig {
// id: string
// name: string
// // 方式一:直接提供 creator 函数(推荐用于自定义)
// creator?: (options: any) => any
// // 方式二:动态导入 + 函数名(用于包导入)
// import?: () => Promise<any>
// creatorFunctionName?: string
// // 图片生成支持
// supportsImageGeneration?: boolean
// imageCreator?: (options: any) => any
// // 可选的验证函数
// validateOptions?: (options: any) => boolean
// }
// 错误类型 // 错误类型
export class ProviderError extends Error { export class ProviderError extends Error {
constructor( constructor(
@ -85,4 +68,29 @@ export type {
OpenAIProviderSettings, OpenAIProviderSettings,
XaiProviderSettings XaiProviderSettings
} }
// 新的provider类型已经在上面直接export不需要重复导出
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel<string> | TranscriptionModel | SpeechModel
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
export const METHOD_MAP = {
text: 'languageModel',
image: 'imageModel',
embedding: 'textEmbeddingModel',
transcription: 'transcriptionModel',
speech: 'speechModel'
} as const satisfies Record<AiSdkModelType, keyof ProviderV2>
export type AiSdkModelMethodMap = Record<AiSdkModelType, keyof ProviderV2>
export type AiSdkModelReturnMap = {
text: LanguageModel
image: ImageModel
embedding: EmbeddingModel<string>
transcription: TranscriptionModel
speech: SpeechModel
}
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]

View File

@ -6,6 +6,7 @@ import {
} from '@cherrystudio/ai-core/provider' } from '@cherrystudio/ai-core/provider'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import { loggerService } from '@renderer/services/LoggerService'
import type { Model, Provider } from '@renderer/types' import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api' import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep } from 'lodash' import { cloneDeep } from 'lodash'
@ -13,6 +14,8 @@ import { cloneDeep } from 'lodash'
import { aihubmixProviderCreator, newApiResolverCreator } from './config' import { aihubmixProviderCreator, newApiResolverCreator } from './config'
import { getAiSdkProviderId } from './factory' import { getAiSdkProviderId } from './factory'
const logger = loggerService.withContext('ProviderConfigProcessor')
/** /**
* provider的转换逻辑 * provider的转换逻辑
*/ */
@ -70,6 +73,7 @@ export function providerToAiSdkConfig(actualProvider: Provider): {
options: ProviderSettingsMap[keyof ProviderSettingsMap] options: ProviderSettingsMap[keyof ProviderSettingsMap]
} { } {
const aiSdkProviderId = getAiSdkProviderId(actualProvider) const aiSdkProviderId = getAiSdkProviderId(actualProvider)
logger.debug('providerToAiSdkConfig', { aiSdkProviderId })
// 构建基础配置 // 构建基础配置
const baseConfig = { const baseConfig = {

View File

@ -60,11 +60,13 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
} }
// 2. 尝试解析provider.type // 2. 尝试解析provider.type
const resolvedFromType = tryResolveProviderId(provider.type) // 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上
if (resolvedFromType) { if (provider.type !== 'openai') {
return resolvedFromType const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
} }
// 3. 最后的fallback通常会成为openai-compatible // 3. 最后的fallback通常会成为openai-compatible
return provider.id as ProviderId return provider.id as ProviderId
} }

View File

@ -7,9 +7,7 @@ const logger = loggerService.withContext('ProviderConfigs')
* Provider配置定义 * Provider配置定义
* AI Providers * AI Providers
*/ */
export const NEW_PROVIDER_CONFIGS: (ProviderConfig & { export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
mappings?: Record<string, string>
})[] = [
{ {
id: 'openrouter', id: 'openrouter',
name: 'OpenRouter', name: 'OpenRouter',