diff --git a/src/renderer/src/pages/home/Inputbar/AgentSessionInputbar.tsx b/src/renderer/src/pages/home/Inputbar/AgentSessionInputbar.tsx index 850be7f727..2aedab0b21 100644 --- a/src/renderer/src/pages/home/Inputbar/AgentSessionInputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/AgentSessionInputbar.tsx @@ -23,6 +23,7 @@ import { abortCompletion } from '@renderer/utils/abortController' import { buildAgentSessionTopicId } from '@renderer/utils/agentSession' import { getSendMessageShortcutLabel } from '@renderer/utils/input' import { createMainTextBlock, createMessage } from '@renderer/utils/messageUtils/create' +import { parseModelId } from '@renderer/utils/model' import { documentExts, imageExts, textExts } from '@shared/config/constant' import type { FC } from 'react' import React, { useCallback, useEffect, useMemo, useRef } from 'react' @@ -67,8 +68,9 @@ const AgentSessionInputbar: FC = ({ agentId, sessionId }) => { if (!session) return null // Extract model info - const [providerId, actualModelId] = session.model?.split(':') ?? [undefined, undefined] - const actualModel = actualModelId ? getModel(actualModelId, providerId) : undefined + // Use parseModelId to handle model IDs with colons (e.g., "openrouter:anthropic/claude:free") + const parsed = parseModelId(session.model) + const actualModel = parsed ? getModel(parsed.modelId, parsed.providerId) : undefined const model: Model | undefined = actualModel ? { diff --git a/src/renderer/src/utils/__tests__/model.test.ts b/src/renderer/src/utils/__tests__/model.test.ts index fe1697e3ed..79f7948fb1 100644 --- a/src/renderer/src/utils/__tests__/model.test.ts +++ b/src/renderer/src/utils/__tests__/model.test.ts @@ -1,7 +1,7 @@ import type { Model, ModelTag } from '@renderer/types' import { describe, expect, it, vi } from 'vitest' -import { getModelTags, isFreeModel } from '../model' +import { getModelTags, isFreeModel, parseModelId } from '../model' // Mock the model checking functions from @renderer/config/models vi.mock('@renderer/config/models', () => ({ @@ -92,4 +92,56 @@ describe('model', () => { expect(getModelTags(models_2)).toStrictEqual(expected_2) }) }) + + describe('parseModelId', () => { + it('should parse model identifiers with single colon', () => { + expect(parseModelId('anthropic:claude-3-sonnet')).toEqual({ + providerId: 'anthropic', + modelId: 'claude-3-sonnet' + }) + + expect(parseModelId('openai:gpt-4')).toEqual({ + providerId: 'openai', + modelId: 'gpt-4' + }) + }) + + it('should parse model identifiers with multiple colons', () => { + expect(parseModelId('openrouter:anthropic/claude-3.5-sonnet:free')).toEqual({ + providerId: 'openrouter', + modelId: 'anthropic/claude-3.5-sonnet:free' + }) + + expect(parseModelId('provider:model:suffix:extra')).toEqual({ + providerId: 'provider', + modelId: 'model:suffix:extra' + }) + }) + + it('should return undefined for invalid inputs', () => { + expect(parseModelId(undefined)).toBeUndefined() + expect(parseModelId('')).toBeUndefined() + expect(parseModelId('no-colon')).toBeUndefined() + expect(parseModelId(':missing-provider')).toBeUndefined() + expect(parseModelId('missing-model:')).toBeUndefined() + expect(parseModelId(':')).toBeUndefined() + }) + + it('should handle edge cases', () => { + expect(parseModelId('a:b')).toEqual({ + providerId: 'a', + modelId: 'b' + }) + + expect(parseModelId('provider:model-with-dashes')).toEqual({ + providerId: 'provider', + modelId: 'model-with-dashes' + }) + + expect(parseModelId('provider:model/with/slashes')).toEqual({ + providerId: 'provider', + modelId: 'model/with/slashes' + }) + }) + }) }) diff --git a/src/renderer/src/utils/model.ts b/src/renderer/src/utils/model.ts index a74ffab25f..23e955eff3 100644 --- a/src/renderer/src/utils/model.ts +++ b/src/renderer/src/utils/model.ts @@ -81,3 +81,39 @@ export const apiModelAdapter = (model: ApiModel): AdaptedApiModel => { origin: model } } + +/** + * Parse a model identifier in the format "provider:modelId" + * where modelId may contain additional colons (e.g., "openrouter:anthropic/claude-3.5-sonnet:free") + * + * @param modelIdentifier - The full model identifier string + * @returns Object with providerId and modelId, or undefined if invalid + * + * @example + * parseModelId("openrouter:anthropic/claude-3.5-sonnet:free") + * // => { providerId: "openrouter", modelId: "anthropic/claude-3.5-sonnet:free" } + * + * @example + * parseModelId("anthropic:claude-3-sonnet") + * // => { providerId: "anthropic", modelId: "claude-3-sonnet" } + * + * @example + * parseModelId("invalid") // => undefined + */ +export function parseModelId(modelIdentifier: string | undefined): { providerId: string; modelId: string } | undefined { + if (!modelIdentifier || typeof modelIdentifier !== 'string') { + return undefined + } + + const colonIndex = modelIdentifier.indexOf(':') + + // Must contain at least one colon and have content on both sides + if (colonIndex <= 0 || colonIndex >= modelIdentifier.length - 1) { + return undefined + } + + return { + providerId: modelIdentifier.substring(0, colonIndex), + modelId: modelIdentifier.substring(colonIndex + 1) + } +}