mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
feat(model): add parseModelId function to handle model identifiers with colons
This commit is contained in:
parent
b33e595955
commit
4173fcbb98
@ -23,6 +23,7 @@ import { abortCompletion } from '@renderer/utils/abortController'
|
||||
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||
import { getSendMessageShortcutLabel } from '@renderer/utils/input'
|
||||
import { createMainTextBlock, createMessage } from '@renderer/utils/messageUtils/create'
|
||||
import { parseModelId } from '@renderer/utils/model'
|
||||
import { documentExts, imageExts, textExts } from '@shared/config/constant'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useEffect, useMemo, useRef } from 'react'
|
||||
@ -67,8 +68,9 @@ const AgentSessionInputbar: FC<Props> = ({ agentId, sessionId }) => {
|
||||
if (!session) return null
|
||||
|
||||
// Extract model info
|
||||
const [providerId, actualModelId] = session.model?.split(':') ?? [undefined, undefined]
|
||||
const actualModel = actualModelId ? getModel(actualModelId, providerId) : undefined
|
||||
// Use parseModelId to handle model IDs with colons (e.g., "openrouter:anthropic/claude:free")
|
||||
const parsed = parseModelId(session.model)
|
||||
const actualModel = parsed ? getModel(parsed.modelId, parsed.providerId) : undefined
|
||||
|
||||
const model: Model | undefined = actualModel
|
||||
? {
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import type { Model, ModelTag } from '@renderer/types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getModelTags, isFreeModel } from '../model'
|
||||
import { getModelTags, isFreeModel, parseModelId } from '../model'
|
||||
|
||||
// Mock the model checking functions from @renderer/config/models
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
@ -92,4 +92,56 @@ describe('model', () => {
|
||||
expect(getModelTags(models_2)).toStrictEqual(expected_2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('parseModelId', () => {
|
||||
it('should parse model identifiers with single colon', () => {
|
||||
expect(parseModelId('anthropic:claude-3-sonnet')).toEqual({
|
||||
providerId: 'anthropic',
|
||||
modelId: 'claude-3-sonnet'
|
||||
})
|
||||
|
||||
expect(parseModelId('openai:gpt-4')).toEqual({
|
||||
providerId: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
})
|
||||
|
||||
it('should parse model identifiers with multiple colons', () => {
|
||||
expect(parseModelId('openrouter:anthropic/claude-3.5-sonnet:free')).toEqual({
|
||||
providerId: 'openrouter',
|
||||
modelId: 'anthropic/claude-3.5-sonnet:free'
|
||||
})
|
||||
|
||||
expect(parseModelId('provider:model:suffix:extra')).toEqual({
|
||||
providerId: 'provider',
|
||||
modelId: 'model:suffix:extra'
|
||||
})
|
||||
})
|
||||
|
||||
it('should return undefined for invalid inputs', () => {
|
||||
expect(parseModelId(undefined)).toBeUndefined()
|
||||
expect(parseModelId('')).toBeUndefined()
|
||||
expect(parseModelId('no-colon')).toBeUndefined()
|
||||
expect(parseModelId(':missing-provider')).toBeUndefined()
|
||||
expect(parseModelId('missing-model:')).toBeUndefined()
|
||||
expect(parseModelId(':')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should handle edge cases', () => {
|
||||
expect(parseModelId('a:b')).toEqual({
|
||||
providerId: 'a',
|
||||
modelId: 'b'
|
||||
})
|
||||
|
||||
expect(parseModelId('provider:model-with-dashes')).toEqual({
|
||||
providerId: 'provider',
|
||||
modelId: 'model-with-dashes'
|
||||
})
|
||||
|
||||
expect(parseModelId('provider:model/with/slashes')).toEqual({
|
||||
providerId: 'provider',
|
||||
modelId: 'model/with/slashes'
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -81,3 +81,39 @@ export const apiModelAdapter = (model: ApiModel): AdaptedApiModel => {
|
||||
origin: model
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a model identifier in the format "provider:modelId"
|
||||
* where modelId may contain additional colons (e.g., "openrouter:anthropic/claude-3.5-sonnet:free")
|
||||
*
|
||||
* @param modelIdentifier - The full model identifier string
|
||||
* @returns Object with providerId and modelId, or undefined if invalid
|
||||
*
|
||||
* @example
|
||||
* parseModelId("openrouter:anthropic/claude-3.5-sonnet:free")
|
||||
* // => { providerId: "openrouter", modelId: "anthropic/claude-3.5-sonnet:free" }
|
||||
*
|
||||
* @example
|
||||
* parseModelId("anthropic:claude-3-sonnet")
|
||||
* // => { providerId: "anthropic", modelId: "claude-3-sonnet" }
|
||||
*
|
||||
* @example
|
||||
* parseModelId("invalid") // => undefined
|
||||
*/
|
||||
export function parseModelId(modelIdentifier: string | undefined): { providerId: string; modelId: string } | undefined {
|
||||
if (!modelIdentifier || typeof modelIdentifier !== 'string') {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const colonIndex = modelIdentifier.indexOf(':')
|
||||
|
||||
// Must contain at least one colon and have content on both sides
|
||||
if (colonIndex <= 0 || colonIndex >= modelIdentifier.length - 1) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: modelIdentifier.substring(0, colonIndex),
|
||||
modelId: modelIdentifier.substring(colonIndex + 1)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user