diff --git a/src/main/apiServer/services/models.ts b/src/main/apiServer/services/models.ts index 684d7f10a8..6c0056b27e 100644 --- a/src/main/apiServer/services/models.ts +++ b/src/main/apiServer/services/models.ts @@ -1,3 +1,5 @@ +import { isEmpty } from 'lodash' + import { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels' import { loggerService } from '../../services/LoggerService' import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils' @@ -8,6 +10,10 @@ const logger = loggerService.withContext('ModelsService') export type ModelsFilter = ApiModelsFilter +const isAnthropicProvider = (provider: { type: string; anthropicApiHost?: string }) => { + return provider.type === 'anthropic' || !isEmpty(provider.anthropicApiHost?.trim()) +} + export class ModelsService { async getModels(filter: ModelsFilter): Promise { try { @@ -16,9 +22,7 @@ export class ModelsService { let providers = await getAvailableProviders() if (filter.providerType === 'anthropic') { - providers = providers.filter( - (p) => p.type === 'anthropic' || (p.anthropicApiHost !== undefined && p.anthropicApiHost.trim() !== '') - ) + providers = providers.filter(isAnthropicProvider) } const models = await listAllAvailableModels(providers) @@ -41,6 +45,10 @@ export class ModelsService { continue } + if (filter.supportAnthropic && model.endpoint_type !== 'anthropic' && !isAnthropicProvider(provider)) { + continue + } + const openAIModel = transformModelToOpenAI(model, provider) const fullModelId = openAIModel.id // This is already in format "provider:model_id" diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index bd80c73f51..865f961db9 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -1,7 +1,7 @@ import { CacheService } from '@main/services/CacheService' import { loggerService } from '@main/services/LoggerService' import { reduxService } from '@main/services/ReduxService' -import { ApiModel, Model, Provider } from '@types' +import { ApiModel, EndpointType, Model, Provider } from '@types' const logger = loggerService.withContext('ApiServerUtils') @@ -114,6 +114,7 @@ export async function validateModelId(model: string): Promise<{ error?: ModelValidationError provider?: Provider modelId?: string + modelEndpointType?: EndpointType }> { try { if (!model || typeof model !== 'string') { @@ -166,7 +167,8 @@ export async function validateModelId(model: string): Promise<{ } // Check if model exists in provider - const modelExists = provider.models?.some((m) => m.id === modelId) + const modelInProvider = provider.models?.find((m) => m.id === modelId) + const modelExists = !!modelInProvider if (!modelExists) { const availableModels = provider.models?.map((m) => m.id).join(', ') || 'none' return { @@ -179,10 +181,13 @@ export async function validateModelId(model: string): Promise<{ } } + const modelEndpointType = modelInProvider?.endpoint_type + return { valid: true, provider, - modelId + modelId, + modelEndpointType } } catch (error: any) { logger.error('Error validating model ID', { error, model }) diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index 7b2f119afb..7dae2f9e9e 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -8,6 +8,7 @@ import { config as apiConfigService } from '@main/apiServer/config' import { validateModelId } from '@main/apiServer/utils' import getLoginShellEnvironment from '@main/utils/shell-env' import { app } from 'electron' +import { isEmpty } from 'lodash' import { GetAgentSessionResponse } from '../..' import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface' @@ -60,11 +61,20 @@ class ClaudeCodeService implements AgentServiceInterface { }) return aiStream } - if ( - (modelInfo.provider?.type !== 'anthropic' && - (modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) || - modelInfo.provider.apiKey === '' - ) { + + const validateModelInfo: (m: typeof modelInfo) => boolean = (m) => { + const { provider, modelEndpointType } = m + if (!provider) return false + if (isEmpty(provider.apiKey?.trim())) return false + + const isAnthropicType = provider.type === 'anthropic' + const isAnthropicEndpoint = modelEndpointType === 'anthropic' + const hasValidApiHost = !isEmpty(provider.anthropicApiHost?.trim()) + + return !(!isAnthropicType && !isAnthropicEndpoint && !hasValidApiHost) + } + + if (!modelInfo.provider || !validateModelInfo(modelInfo)) { logger.error('Anthropic provider configuration is missing', { modelInfo }) diff --git a/src/renderer/src/components/Popups/agent/AgentModal.tsx b/src/renderer/src/components/Popups/agent/AgentModal.tsx index fa76e0e330..63b614944a 100644 --- a/src/renderer/src/components/Popups/agent/AgentModal.tsx +++ b/src/renderer/src/components/Popups/agent/AgentModal.tsx @@ -100,7 +100,7 @@ export const AgentModal: React.FC = ({ agent, trigger, isOpen: _isOpen, o const { addAgent } = useAgents() const { updateAgent } = useUpdateAgent() // hard-coded. We only support anthropic for now. - const { models } = useApiModels({ providerType: 'anthropic' }) + const { models } = useApiModels({ supportAnthropic: true }) const isEditing = (agent?: AgentWithTools) => agent !== undefined const [form, setForm] = useState(() => buildAgentForm(agent)) diff --git a/src/renderer/src/types/apiModels.ts b/src/renderer/src/types/apiModels.ts index 68141bf68c..7b4ec96c9d 100644 --- a/src/renderer/src/types/apiModels.ts +++ b/src/renderer/src/types/apiModels.ts @@ -6,6 +6,7 @@ import { ProviderTypeSchema } from './provider' // Request schema for /v1/models export const ApiModelsFilterSchema = z.object({ providerType: ProviderTypeSchema.optional(), + supportAnthropic: z.coerce.boolean().optional(), offset: z.coerce.number().min(0).default(0).optional(), limit: z.coerce.number().min(1).default(20).optional() }) diff --git a/src/renderer/src/utils/agentSession.ts b/src/renderer/src/utils/agentSession.ts index df34413641..b2cf14d174 100644 --- a/src/renderer/src/utils/agentSession.ts +++ b/src/renderer/src/utils/agentSession.ts @@ -18,7 +18,7 @@ export const getModelFilterByAgentType = (type: AgentType): ApiModelsFilter => { switch (type) { case 'claude-code': return { - providerType: 'anthropic' + supportAnthropic: true } default: return {}