diff --git a/src/main/apiServer/services/models.ts b/src/main/apiServer/services/models.ts index 6c0056b27e..660686ef45 100644 --- a/src/main/apiServer/services/models.ts +++ b/src/main/apiServer/services/models.ts @@ -2,7 +2,12 @@ import { isEmpty } from 'lodash' import { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels' import { loggerService } from '../../services/LoggerService' -import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils' +import { + getAvailableProviders, + getProviderAnthropicModelChecker, + listAllAvailableModels, + transformModelToOpenAI +} from '../utils' const logger = loggerService.withContext('ModelsService') @@ -10,10 +15,6 @@ const logger = loggerService.withContext('ModelsService') export type ModelsFilter = ApiModelsFilter -const isAnthropicProvider = (provider: { type: string; anthropicApiHost?: string }) => { - return provider.type === 'anthropic' || !isEmpty(provider.anthropicApiHost?.trim()) -} - export class ModelsService { async getModels(filter: ModelsFilter): Promise { try { @@ -22,7 +23,7 @@ export class ModelsService { let providers = await getAvailableProviders() if (filter.providerType === 'anthropic') { - providers = providers.filter(isAnthropicProvider) + providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim())) } const models = await listAllAvailableModels(providers) @@ -31,22 +32,18 @@ export class ModelsService { for (const model of models) { const provider = providers.find((p) => p.id === model.provider) - logger.debug(`Processing model ${model.id} from provider ${model.provider}`, { - isAnthropicModel: provider?.isAnthropicModel - }) - if ( - !provider || - (filter.providerType === 'anthropic' && provider.isAnthropicModel && !provider.isAnthropicModel(model)) - ) { - continue - } - // Special case: For "aihubmix", it should be covered by above condition, but just in case - if (provider.id === 'aihubmix' && filter.providerType === 'anthropic' && !model.id.includes('claude')) { + logger.debug(`Processing model ${model.id}`) + if (!provider) { + logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`) continue } - if (filter.supportAnthropic && model.endpoint_type !== 'anthropic' && !isAnthropicProvider(provider)) { - continue + if (filter.providerType === 'anthropic') { + const checker = getProviderAnthropicModelChecker(provider.id) + if (!checker(model)) { + logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`) + continue + } } const openAIModel = transformModelToOpenAI(model, provider) diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index 865f961db9..7fb0c3511f 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -1,7 +1,7 @@ import { CacheService } from '@main/services/CacheService' import { loggerService } from '@main/services/LoggerService' import { reduxService } from '@main/services/ReduxService' -import { ApiModel, EndpointType, Model, Provider } from '@types' +import { ApiModel, Model, Provider } from '@types' const logger = loggerService.withContext('ApiServerUtils') @@ -114,7 +114,6 @@ export async function validateModelId(model: string): Promise<{ error?: ModelValidationError provider?: Provider modelId?: string - modelEndpointType?: EndpointType }> { try { if (!model || typeof model !== 'string') { @@ -167,8 +166,7 @@ export async function validateModelId(model: string): Promise<{ } // Check if model exists in provider - const modelInProvider = provider.models?.find((m) => m.id === modelId) - const modelExists = !!modelInProvider + const modelExists = provider.models?.some((m) => m.id === modelId) if (!modelExists) { const availableModels = provider.models?.map((m) => m.id).join(', ') || 'none' return { @@ -181,13 +179,10 @@ export async function validateModelId(model: string): Promise<{ } } - const modelEndpointType = modelInProvider?.endpoint_type - return { valid: true, provider, - modelId, - modelEndpointType + modelId } } catch (error: any) { logger.error('Error validating model ID', { error, model }) @@ -284,3 +279,16 @@ export function validateProvider(provider: Provider): boolean { return false } } + +export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model) => boolean) => { + switch (providerId) { + case 'cherryin': + case 'new-api': + return (m: Model) => m.endpoint_type === 'anthropic' + case 'aihubmix': + return (m: Model) => m.id.includes('claude') + default: + // allow all models when checker not configured + return () => true + } +} diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index 7dae2f9e9e..7b2f119afb 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -8,7 +8,6 @@ import { config as apiConfigService } from '@main/apiServer/config' import { validateModelId } from '@main/apiServer/utils' import getLoginShellEnvironment from '@main/utils/shell-env' import { app } from 'electron' -import { isEmpty } from 'lodash' import { GetAgentSessionResponse } from '../..' import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface' @@ -61,20 +60,11 @@ class ClaudeCodeService implements AgentServiceInterface { }) return aiStream } - - const validateModelInfo: (m: typeof modelInfo) => boolean = (m) => { - const { provider, modelEndpointType } = m - if (!provider) return false - if (isEmpty(provider.apiKey?.trim())) return false - - const isAnthropicType = provider.type === 'anthropic' - const isAnthropicEndpoint = modelEndpointType === 'anthropic' - const hasValidApiHost = !isEmpty(provider.anthropicApiHost?.trim()) - - return !(!isAnthropicType && !isAnthropicEndpoint && !hasValidApiHost) - } - - if (!modelInfo.provider || !validateModelInfo(modelInfo)) { + if ( + (modelInfo.provider?.type !== 'anthropic' && + (modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) || + modelInfo.provider.apiKey === '' + ) { logger.error('Anthropic provider configuration is missing', { modelInfo }) diff --git a/src/renderer/src/components/Popups/agent/AgentModal.tsx b/src/renderer/src/components/Popups/agent/AgentModal.tsx index 63b614944a..ea772fb1e6 100644 --- a/src/renderer/src/components/Popups/agent/AgentModal.tsx +++ b/src/renderer/src/components/Popups/agent/AgentModal.tsx @@ -17,7 +17,7 @@ import { import { loggerService } from '@logger' import type { Selection } from '@react-types/shared' import ClaudeIcon from '@renderer/assets/images/models/claude.png' -import { getModelLogo } from '@renderer/config/models' +import { agentModelFilter, getModelLogo } from '@renderer/config/models' import { permissionModeCards } from '@renderer/constants/permissionModes' import { useAgents } from '@renderer/hooks/agents/useAgents' import { useApiModels } from '@renderer/hooks/agents/useModels' @@ -100,7 +100,7 @@ export const AgentModal: React.FC = ({ agent, trigger, isOpen: _isOpen, o const { addAgent } = useAgents() const { updateAgent } = useUpdateAgent() // hard-coded. We only support anthropic for now. - const { models } = useApiModels({ supportAnthropic: true }) + const { models } = useApiModels({ providerType: 'anthropic' }) const isEditing = (agent?: AgentWithTools) => agent !== undefined const [form, setForm] = useState(() => buildAgentForm(agent)) @@ -245,14 +245,23 @@ export const AgentModal: React.FC = ({ agent, trigger, isOpen: _isOpen, o const modelOptions = useMemo(() => { // mocked data. not final version - return (models ?? []).map((model) => ({ - type: 'model', - key: model.id, - label: model.name, - avatar: getModelLogo(model.id), - providerId: model.provider, - providerName: model.provider_name - })) satisfies ModelOption[] + return (models ?? []) + .filter((m) => + agentModelFilter({ + id: m.id, + provider: m.provider || '', + name: m.name, + group: '' + }) + ) + .map((model) => ({ + type: 'model', + key: model.id, + label: model.name, + avatar: getModelLogo(model.id), + providerId: model.provider, + providerName: model.provider_name + })) satisfies ModelOption[] }, [models]) const onModelChange = useCallback((e: ChangeEvent) => { diff --git a/src/renderer/src/config/models/utils.ts b/src/renderer/src/config/models/utils.ts index 39078e2924..1759a93d18 100644 --- a/src/renderer/src/config/models/utils.ts +++ b/src/renderer/src/config/models/utils.ts @@ -1,3 +1,4 @@ +import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding' import { Model } from '@renderer/types' import { getLowerBaseModelName } from '@renderer/utils' import OpenAI from 'openai' @@ -5,7 +6,7 @@ import OpenAI from 'openai' import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '../prompts' import { getWebSearchTools } from '../tools' import { isOpenAIReasoningModel } from './reasoning' -import { isGenerateImageModel, isVisionModel } from './vision' +import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision' import { isOpenAIWebSearchChatCompletionOnlyModel } from './websearch' export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i @@ -246,3 +247,7 @@ export const isOpenAIOpenWeightModel = (model: Model) => { // zhipu 视觉推理模型用这组 special token 标记推理结果 export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const + +export const agentModelFilter = (model: Model): boolean => { + return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model) +} diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index 22faf0fb0e..7f8d95dcd1 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -58,7 +58,6 @@ import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png' import { AtLeast, isSystemProvider, - Model, OpenAIServiceTiers, Provider, ProviderType, @@ -88,6 +87,7 @@ export const SYSTEM_PROVIDERS_CONFIG: Record = type: 'openai', apiKey: '', apiHost: 'https://open.cherryin.net', + anthropicApiHost: 'https://open.cherryin.net', models: [], isSystem: true, enabled: true @@ -109,7 +109,6 @@ export const SYSTEM_PROVIDERS_CONFIG: Record = apiKey: '', apiHost: 'https://aihubmix.com', anthropicApiHost: 'https://aihubmix.com/anthropic', - isAnthropicModel: (m: Model) => m.id.includes('claude'), models: SYSTEM_MODELS.aihubmix, isSystem: true, enabled: false diff --git a/src/renderer/src/pages/home/components/SelectAgentBaseModelButton.tsx b/src/renderer/src/pages/home/components/SelectAgentBaseModelButton.tsx index 7e21d1b47a..74f69cf15e 100644 --- a/src/renderer/src/pages/home/components/SelectAgentBaseModelButton.tsx +++ b/src/renderer/src/pages/home/components/SelectAgentBaseModelButton.tsx @@ -1,10 +1,10 @@ import { Button } from '@heroui/react' import ModelAvatar from '@renderer/components/Avatar/ModelAvatar' import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup' -import { isEmbeddingModel, isRerankModel, isTextToImageModel } from '@renderer/config/models' +import { agentModelFilter } from '@renderer/config/models' import { useApiModel } from '@renderer/hooks/agents/useModel' import { getProviderNameById } from '@renderer/services/ProviderService' -import { AgentBaseWithId, ApiModel, isAgentEntity, Model } from '@renderer/types' +import { AgentBaseWithId, ApiModel, isAgentEntity } from '@renderer/types' import { getModelFilterByAgentType } from '@renderer/utils/agentSession' import { apiModelAdapter } from '@renderer/utils/model' import { ChevronsUpDown } from 'lucide-react' @@ -22,12 +22,11 @@ const SelectAgentBaseModelButton: FC = ({ agentBase: agent, onSelect, isD const model = useApiModel({ id: agent?.model }) const apiFilter = isAgentEntity(agent) ? getModelFilterByAgentType(agent.type) : undefined - const modelFilter = (model: Model) => !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model) if (!agent) return null const onSelectModel = async () => { - const selectedModel = await SelectApiModelPopup.show({ model, apiFilter: apiFilter, modelFilter }) + const selectedModel = await SelectApiModelPopup.show({ model, apiFilter: apiFilter, modelFilter: agentModelFilter }) if (selectedModel && selectedModel.id !== agent.model) { onSelect(selectedModel) } diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index 0d53c2ca5b..0fee3196b2 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -65,7 +65,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 162, + version: 163, blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs'], migrate }, diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 5cca66b47a..46ac128c66 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -2671,6 +2671,11 @@ const migrateConfig = { '163': (state: RootState) => { try { addOcrProvider(state, BUILTIN_OCR_PROVIDERS_MAP.ovocr) + state.llm.providers.forEach((provider) => { + if (provider.id === 'cherryin') { + provider.anthropicApiHost = 'https://open.cherryin.net' + } + }) return state } catch (error) { logger.error('migrate 163 error', error as Error) diff --git a/src/renderer/src/types/apiModels.ts b/src/renderer/src/types/apiModels.ts index 7b4ec96c9d..68141bf68c 100644 --- a/src/renderer/src/types/apiModels.ts +++ b/src/renderer/src/types/apiModels.ts @@ -6,7 +6,6 @@ import { ProviderTypeSchema } from './provider' // Request schema for /v1/models export const ApiModelsFilterSchema = z.object({ providerType: ProviderTypeSchema.optional(), - supportAnthropic: z.coerce.boolean().optional(), offset: z.coerce.number().min(0).default(0).optional(), limit: z.coerce.number().min(1).default(20).optional() }) diff --git a/src/renderer/src/utils/agentSession.ts b/src/renderer/src/utils/agentSession.ts index b2cf14d174..df34413641 100644 --- a/src/renderer/src/utils/agentSession.ts +++ b/src/renderer/src/utils/agentSession.ts @@ -18,7 +18,7 @@ export const getModelFilterByAgentType = (type: AgentType): ApiModelsFilter => { switch (type) { case 'claude-code': return { - supportAnthropic: true + providerType: 'anthropic' } default: return {}