diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index f934fbbb81..64b32c6699 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -145,7 +145,7 @@ import YoudaoLogo from '@renderer/assets/images/providers/netease-youdao.svg' import NomicLogo from '@renderer/assets/images/providers/nomic.png' import { getProviderByModel } from '@renderer/services/AssistantService' import { Model } from '@renderer/types' -import { getBaseModelName } from '@renderer/utils' +import { getLowerBaseModelName } from '@renderer/utils' import OpenAI from 'openai' import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from './prompts' @@ -2503,7 +2503,7 @@ export function isSupportedThinkingTokenQwenModel(model?: Model): boolean { return false } - const baseName = getBaseModelName(model.id, '/').toLowerCase() + const baseName = getLowerBaseModelName(model.id, '/') return ( baseName.startsWith('qwen3') || @@ -2616,7 +2616,7 @@ export function isWebSearchModel(model: Model): boolean { return false } - const baseName = getBaseModelName(model.id, '/').toLowerCase() + const baseName = getLowerBaseModelName(model.id, '/') // 不管哪个供应商都判断了 if (model.id.includes('claude')) { @@ -2710,7 +2710,7 @@ export function isGenerateImageModel(model: Model): boolean { return false } - const baseName = getBaseModelName(model.id, '/').toLowerCase() + const baseName = getLowerBaseModelName(model.id, '/') if (GENERATE_IMAGE_MODELS.includes(baseName)) { return true } @@ -2722,7 +2722,7 @@ export function isSupportedDisableGenerationModel(model: Model): boolean { return false } - return SUPPORTED_DISABLE_GENERATION_MODELS.includes(getBaseModelName(model.id)) + return SUPPORTED_DISABLE_GENERATION_MODELS.includes(getLowerBaseModelName(model.id)) } export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record { diff --git a/src/renderer/src/utils/__tests__/naming.test.ts b/src/renderer/src/utils/__tests__/naming.test.ts index 1d4560a3ab..3765e1e3b8 100644 --- a/src/renderer/src/utils/__tests__/naming.test.ts +++ b/src/renderer/src/utils/__tests__/naming.test.ts @@ -8,6 +8,7 @@ import { getDefaultGroupName, getFirstCharacter, getLeadingEmoji, + getLowerBaseModelName, isEmoji, removeLeadingEmoji, removeSpecialCharactersForTopicName @@ -190,6 +191,35 @@ describe('naming', () => { }) }) + describe('getLowerBaseModelName', () => { + it('should convert base model name to lowercase', () => { + // 验证将基础模型名称转换为小写 + expect(getLowerBaseModelName('DeepSeek/DeepSeek-R1')).toBe('deepseek-r1') + expect(getLowerBaseModelName('openai/GPT-4.1')).toBe('gpt-4.1') + expect(getLowerBaseModelName('Anthropic/Claude-3.5-Sonnet')).toBe('claude-3.5-sonnet') + }) + + it('should handle multiple levels of paths', () => { + // 验证处理多层路径 + expect(getLowerBaseModelName('Pro/DeepSeek-AI/DeepSeek-R1')).toBe('deepseek-r1') + expect(getLowerBaseModelName('Org/Team/Group/Model')).toBe('model') + }) + + it('should return lowercase original id if no delimiter found', () => { + // 验证没有分隔符时返回小写原始ID + expect(getLowerBaseModelName('DeepSeek-R1')).toBe('deepseek-r1') + expect(getLowerBaseModelName('GPT-4:Free')).toBe('gpt-4:free') + }) + + it('should handle edge cases', () => { + // 验证边缘情况 + expect(getLowerBaseModelName('')).toBe('') + expect(getLowerBaseModelName('Model/')).toBe('') + expect(getLowerBaseModelName('/Model')).toBe('model') + expect(getLowerBaseModelName('Model//Name')).toBe('name') + }) + }) + describe('generateColorFromChar', () => { it('should generate a valid hex color code', () => { // 验证生成有效的十六进制颜色代码 diff --git a/src/renderer/src/utils/naming.ts b/src/renderer/src/utils/naming.ts index df178104de..e26813fa41 100644 --- a/src/renderer/src/utils/naming.ts +++ b/src/renderer/src/utils/naming.ts @@ -60,6 +60,19 @@ export const getBaseModelName = (id: string, delimiter: string = '/'): string => return parts[parts.length - 1] } +/** + * 从模型 ID 中提取基础名称并转换为小写。 + * 例如: + * - 'deepseek/DeepSeek-R1' => 'deepseek-r1' + * - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1' + * @param {string} id 模型 ID + * @param {string} [delimiter='/'] 分隔符,默认为 '/' + * @returns {string} 小写的基础名称 + */ +export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => { + return getBaseModelName(id, delimiter).toLowerCase() +} + /** * 用于获取 avatar 名字的辅助函数,会取出字符串的第一个字符,支持表情符号。 * @param {string} str 输入字符串