mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-11 16:39:15 +08:00
refactor(modelResolver): replace ':' with '|' as the default separator for model IDs
Updated the ModelResolver and related components to use '|' as the default separator instead of ':'. This change improves compatibility and resolves potential conflicts with model ID suffixes. Adjusted model resolution logic accordingly to ensure consistent behavior across the application.
This commit is contained in:
parent
9551c49452
commit
3004f84be3
@ -9,7 +9,7 @@ import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middlew
|
|||||||
|
|
||||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
||||||
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
|
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
|
||||||
import { globalRegistryManagement } from '../providers/RegistryManagement'
|
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
|
||||||
|
|
||||||
export class ModelResolver {
|
export class ModelResolver {
|
||||||
/**
|
/**
|
||||||
@ -39,7 +39,7 @@ export class ModelResolver {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否是命名空间格式
|
// 检查是否是命名空间格式
|
||||||
if (modelId.includes(':')) {
|
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||||
model = this.resolveNamespacedModel(modelId)
|
model = this.resolveNamespacedModel(modelId)
|
||||||
} else {
|
} else {
|
||||||
// 传统格式:使用处理后的 providerId + modelId
|
// 传统格式:使用处理后的 providerId + modelId
|
||||||
@ -58,7 +58,7 @@ export class ModelResolver {
|
|||||||
* 解析文本嵌入模型
|
* 解析文本嵌入模型
|
||||||
*/
|
*/
|
||||||
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> {
|
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> {
|
||||||
if (modelId.includes(':')) {
|
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||||
return this.resolveNamespacedEmbeddingModel(modelId)
|
return this.resolveNamespacedEmbeddingModel(modelId)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ export class ModelResolver {
|
|||||||
* 解析图像模型
|
* 解析图像模型
|
||||||
*/
|
*/
|
||||||
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> {
|
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> {
|
||||||
if (modelId.includes(':')) {
|
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||||
return this.resolveNamespacedImageModel(modelId)
|
return this.resolveNamespacedImageModel(modelId)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,7 +89,7 @@ export class ModelResolver {
|
|||||||
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
|
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
|
||||||
*/
|
*/
|
||||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||||
const fullModelId = `${providerId}:${modelId}`
|
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ export class ModelResolver {
|
|||||||
* 解析传统格式的嵌入模型
|
* 解析传统格式的嵌入模型
|
||||||
*/
|
*/
|
||||||
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> {
|
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> {
|
||||||
const fullModelId = `${providerId}:${modelId}`
|
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||||
return globalRegistryManagement.textEmbeddingModel(fullModelId as any)
|
return globalRegistryManagement.textEmbeddingModel(fullModelId as any)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,7 +119,7 @@ export class ModelResolver {
|
|||||||
* 解析传统格式的图像模型
|
* 解析传统格式的图像模型
|
||||||
*/
|
*/
|
||||||
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 {
|
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 {
|
||||||
const fullModelId = `${providerId}:${modelId}`
|
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||||
return globalRegistryManagement.imageModel(fullModelId as any)
|
return globalRegistryManagement.imageModel(fullModelId as any)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,10 +5,11 @@
|
|||||||
* 例如: aihubmix:anthropic:claude-3.5-sonnet
|
* 例如: aihubmix:anthropic:claude-3.5-sonnet
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { EmbeddingModelV2, ImageModelV2, ProviderV2, SpeechModelV2, TranscriptionModelV2 } from '@ai-sdk/provider'
|
import { ProviderV2 } from '@ai-sdk/provider'
|
||||||
import { customProvider } from 'ai'
|
import { customProvider } from 'ai'
|
||||||
|
|
||||||
import { globalRegistryManagement } from './RegistryManagement'
|
import { globalRegistryManagement } from './RegistryManagement'
|
||||||
|
import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types'
|
||||||
|
|
||||||
export interface HubProviderConfig {
|
export interface HubProviderConfig {
|
||||||
/** Hub的唯一标识符 */
|
/** Hub的唯一标识符 */
|
||||||
@ -52,7 +53,7 @@ export function createHubProvider(config: HubProviderConfig): ProviderV2 {
|
|||||||
function getTargetProvider(providerId: string): ProviderV2 {
|
function getTargetProvider(providerId: string): ProviderV2 {
|
||||||
// 从全局注册表获取provider实例
|
// 从全局注册表获取provider实例
|
||||||
try {
|
try {
|
||||||
const provider = (globalRegistryManagement as any).getProvider(providerId)
|
const provider = globalRegistryManagement.getProvider(providerId)
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
throw new HubProviderError(
|
throw new HubProviderError(
|
||||||
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
|
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
|
||||||
@ -71,26 +72,30 @@ export function createHubProvider(config: HubProviderConfig): ProviderV2 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function resolveModel<T>(modelId: string, modelType: string, methodName: keyof ProviderV2): T {
|
function resolveModel<T extends AiSdkModelType>(
|
||||||
|
modelId: string,
|
||||||
|
modelType: T,
|
||||||
|
methodName: AiSdkMethodName<T>
|
||||||
|
): AiSdkModelReturn<T> {
|
||||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||||
const targetProvider = getTargetProvider(provider)
|
const targetProvider = getTargetProvider(provider)
|
||||||
|
|
||||||
if (!targetProvider[methodName]) {
|
const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn<T>
|
||||||
|
|
||||||
|
if (!fn) {
|
||||||
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
|
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
return (targetProvider[methodName] as any)(actualModelId)
|
return fn(actualModelId)
|
||||||
}
|
}
|
||||||
|
|
||||||
return customProvider({
|
return customProvider({
|
||||||
fallbackProvider: {
|
fallbackProvider: {
|
||||||
languageModel: (modelId: string) => resolveModel(modelId, 'language models', 'languageModel'),
|
languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'),
|
||||||
textEmbeddingModel: (modelId: string) =>
|
textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'),
|
||||||
resolveModel<EmbeddingModelV2<string>>(modelId, 'text embedding models', 'textEmbeddingModel'),
|
imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'),
|
||||||
imageModel: (modelId: string) => resolveModel<ImageModelV2>(modelId, 'image models', 'imageModel'),
|
transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'),
|
||||||
transcriptionModel: (modelId: string) =>
|
speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel')
|
||||||
resolveModel<TranscriptionModelV2>(modelId, 'transcription models', 'transcriptionModel'),
|
|
||||||
speechModel: (modelId: string) => resolveModel<SpeechModelV2>(modelId, 'speech models', 'speechModel')
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
|
|||||||
|
|
||||||
type PROVIDERS = Record<string, ProviderV2>
|
type PROVIDERS = Record<string, ProviderV2>
|
||||||
|
|
||||||
export const DEFAULT_SEPARATOR = ':'
|
export const DEFAULT_SEPARATOR = '|'
|
||||||
|
|
||||||
// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
|
// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
|
||||||
|
|
||||||
@ -216,5 +216,6 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 全局注册表管理器实例
|
* 全局注册表管理器实例
|
||||||
|
* 使用 | 作为分隔符,因为 : 会和 :free 等suffix冲突
|
||||||
*/
|
*/
|
||||||
export const globalRegistryManagement = new RegistryManagement<':'>()
|
export const globalRegistryManagement = new RegistryManagement()
|
||||||
|
|||||||
@ -4,6 +4,14 @@ import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
|||||||
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
||||||
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||||
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
||||||
|
import {
|
||||||
|
EmbeddingModelV2 as EmbeddingModel,
|
||||||
|
ImageModelV2 as ImageModel,
|
||||||
|
LanguageModelV2 as LanguageModel,
|
||||||
|
ProviderV2,
|
||||||
|
SpeechModelV2 as SpeechModel,
|
||||||
|
TranscriptionModelV2 as TranscriptionModel
|
||||||
|
} from '@ai-sdk/provider'
|
||||||
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||||
|
|
||||||
// 导入基于 Zod 的 ProviderId 类型
|
// 导入基于 Zod 的 ProviderId 类型
|
||||||
@ -29,31 +37,6 @@ export interface DynamicProviderRegistry {
|
|||||||
// 合并基础和动态provider类型
|
// 合并基础和动态provider类型
|
||||||
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
|
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
|
||||||
|
|
||||||
/**
|
|
||||||
* Provider 相关核心类型定义
|
|
||||||
* 只定义必要的接口,其他类型直接使用 AI SDK
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Provider 配置接口 - 支持灵活的创建方式
|
|
||||||
// export interface ProviderConfig {
|
|
||||||
// id: string
|
|
||||||
// name: string
|
|
||||||
|
|
||||||
// // 方式一:直接提供 creator 函数(推荐用于自定义)
|
|
||||||
// creator?: (options: any) => any
|
|
||||||
|
|
||||||
// // 方式二:动态导入 + 函数名(用于包导入)
|
|
||||||
// import?: () => Promise<any>
|
|
||||||
// creatorFunctionName?: string
|
|
||||||
|
|
||||||
// // 图片生成支持
|
|
||||||
// supportsImageGeneration?: boolean
|
|
||||||
// imageCreator?: (options: any) => any
|
|
||||||
|
|
||||||
// // 可选的验证函数
|
|
||||||
// validateOptions?: (options: any) => boolean
|
|
||||||
// }
|
|
||||||
|
|
||||||
// 错误类型
|
// 错误类型
|
||||||
export class ProviderError extends Error {
|
export class ProviderError extends Error {
|
||||||
constructor(
|
constructor(
|
||||||
@ -85,4 +68,29 @@ export type {
|
|||||||
OpenAIProviderSettings,
|
OpenAIProviderSettings,
|
||||||
XaiProviderSettings
|
XaiProviderSettings
|
||||||
}
|
}
|
||||||
// 新的provider类型已经在上面直接export,不需要重复导出
|
|
||||||
|
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel<string> | TranscriptionModel | SpeechModel
|
||||||
|
|
||||||
|
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
|
||||||
|
|
||||||
|
export const METHOD_MAP = {
|
||||||
|
text: 'languageModel',
|
||||||
|
image: 'imageModel',
|
||||||
|
embedding: 'textEmbeddingModel',
|
||||||
|
transcription: 'transcriptionModel',
|
||||||
|
speech: 'speechModel'
|
||||||
|
} as const satisfies Record<AiSdkModelType, keyof ProviderV2>
|
||||||
|
|
||||||
|
export type AiSdkModelMethodMap = Record<AiSdkModelType, keyof ProviderV2>
|
||||||
|
|
||||||
|
export type AiSdkModelReturnMap = {
|
||||||
|
text: LanguageModel
|
||||||
|
image: ImageModel
|
||||||
|
embedding: EmbeddingModel<string>
|
||||||
|
transcription: TranscriptionModel
|
||||||
|
speech: SpeechModel
|
||||||
|
}
|
||||||
|
|
||||||
|
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
|
||||||
|
|
||||||
|
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import {
|
|||||||
} from '@cherrystudio/ai-core/provider'
|
} from '@cherrystudio/ai-core/provider'
|
||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
|
import { loggerService } from '@renderer/services/LoggerService'
|
||||||
import type { Model, Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
import { formatApiHost } from '@renderer/utils/api'
|
import { formatApiHost } from '@renderer/utils/api'
|
||||||
import { cloneDeep } from 'lodash'
|
import { cloneDeep } from 'lodash'
|
||||||
@ -13,6 +14,8 @@ import { cloneDeep } from 'lodash'
|
|||||||
import { aihubmixProviderCreator, newApiResolverCreator } from './config'
|
import { aihubmixProviderCreator, newApiResolverCreator } from './config'
|
||||||
import { getAiSdkProviderId } from './factory'
|
import { getAiSdkProviderId } from './factory'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('ProviderConfigProcessor')
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 处理特殊provider的转换逻辑
|
* 处理特殊provider的转换逻辑
|
||||||
*/
|
*/
|
||||||
@ -70,6 +73,7 @@ export function providerToAiSdkConfig(actualProvider: Provider): {
|
|||||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||||
} {
|
} {
|
||||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||||
|
logger.debug('providerToAiSdkConfig', { aiSdkProviderId })
|
||||||
|
|
||||||
// 构建基础配置
|
// 构建基础配置
|
||||||
const baseConfig = {
|
const baseConfig = {
|
||||||
|
|||||||
@ -60,11 +60,13 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. 尝试解析provider.type
|
// 2. 尝试解析provider.type
|
||||||
const resolvedFromType = tryResolveProviderId(provider.type)
|
// 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上
|
||||||
if (resolvedFromType) {
|
if (provider.type !== 'openai') {
|
||||||
return resolvedFromType
|
const resolvedFromType = tryResolveProviderId(provider.type)
|
||||||
|
if (resolvedFromType) {
|
||||||
|
return resolvedFromType
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 最后的fallback(通常会成为openai-compatible)
|
// 3. 最后的fallback(通常会成为openai-compatible)
|
||||||
return provider.id as ProviderId
|
return provider.id as ProviderId
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,9 +7,7 @@ const logger = loggerService.withContext('ProviderConfigs')
|
|||||||
* 新Provider配置定义
|
* 新Provider配置定义
|
||||||
* 定义了需要动态注册的AI Providers
|
* 定义了需要动态注册的AI Providers
|
||||||
*/
|
*/
|
||||||
export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
|
||||||
mappings?: Record<string, string>
|
|
||||||
})[] = [
|
|
||||||
{
|
{
|
||||||
id: 'openrouter',
|
id: 'openrouter',
|
||||||
name: 'OpenRouter',
|
name: 'OpenRouter',
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user