diff --git a/packages/shared/anthropic/index.ts b/packages/shared/anthropic/index.ts index 251e634edf..82d85c54fa 100644 --- a/packages/shared/anthropic/index.ts +++ b/packages/shared/anthropic/index.ts @@ -10,7 +10,9 @@ import Anthropic from '@anthropic-ai/sdk' import { TextBlockParam } from '@anthropic-ai/sdk/resources' +import { loggerService } from '@logger' import { Provider } from '@types' +const logger = loggerService.withContext('anthropic-sdk') /** * Creates and configures an Anthropic SDK client based on the provider configuration. @@ -75,6 +77,7 @@ export function getSdkClient(provider: Provider, oauthToken?: string | null): An ? provider.apiHost : (provider.anthropicApiHost && provider.anthropicApiHost.trim()) || provider.apiHost + logger.debug('Anthropic API baseURL', { baseURL }) return new Anthropic({ apiKey: provider.apiKey, authToken: provider.apiKey, diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts index 19d83fb57a..994ee3edd6 100644 --- a/src/main/apiServer/routes/messages.ts +++ b/src/main/apiServer/routes/messages.ts @@ -2,7 +2,8 @@ import { MessageCreateParams } from '@anthropic-ai/sdk/resources' import { loggerService } from '@logger' import express, { Request, Response } from 'express' -import { messagesService } from '../services/messages' +import { Provider } from '../../../renderer/src/types/provider' +import { MessagesService, messagesService } from '../services/messages' import { getProviderById, validateModelId } from '../utils' const logger = loggerService.withContext('ApiServerMessagesRoutes') @@ -33,9 +34,8 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro async function handleStreamingResponse( res: Response, request: MessageCreateParams, - provider: any, - messagesService: any, - logger: any + provider: Provider, + messagesService: MessagesService ): Promise { res.setHeader('Content-Type', 'text/event-stream; charset=utf-8') res.setHeader('Cache-Control', 'no-cache, no-transform') @@ -80,7 +80,7 @@ async function handleStreamingResponse( } } -function handleErrorResponse(res: Response, error: any, logger: any): Response { +function handleErrorResponse(res: Response, error: any): Response { logger.error('Message processing error', { error }) let statusCode = 500 @@ -133,7 +133,7 @@ function handleErrorResponse(res: Response, error: any, logger: any): Response { async function processMessageRequest( req: Request, res: Response, - provider: any, + provider: Provider, modelId?: string ): Promise { try { @@ -144,17 +144,6 @@ async function processMessageRequest( request.model = modelId } - // Ensure provider is Anthropic type - if (provider.type !== 'anthropic') { - return res.status(400).json({ - type: 'error', - error: { - type: 'invalid_request_error', - message: `Invalid provider type '${provider.type}' for messages endpoint. Expected 'anthropic' provider.` - } - }) - } - // Validate request const validation = messagesService.validateRequest(request) if (!validation.isValid) { @@ -167,9 +156,14 @@ async function processMessageRequest( }) } + logger.silly('Processing message request', { + request, + provider: provider.id + }) + // Handle streaming if (request.stream) { - await handleStreamingResponse(res, request, provider, messagesService, logger) + await handleStreamingResponse(res, request, provider, messagesService) return } @@ -177,7 +171,7 @@ async function processMessageRequest( const response = await messagesService.processMessage(request, provider) return res.json(response) } catch (error: any) { - return handleErrorResponse(res, error, logger) + return handleErrorResponse(res, error) } } @@ -337,7 +331,7 @@ router.post('/', async (req: Request, res: Response) => { // Use shared processing function return await processMessageRequest(req, res, provider, modelId) } catch (error: any) { - return handleErrorResponse(res, error, logger) + return handleErrorResponse(res, error) } }) @@ -492,7 +486,7 @@ providerRouter.post('/', async (req: Request, res: Response) => { // Use shared processing function (no modelId override needed) return await processMessageRequest(req, res, provider) } catch (error: any) { - return handleErrorResponse(res, error, logger) + return handleErrorResponse(res, error) } }) diff --git a/src/main/apiServer/services/models.ts b/src/main/apiServer/services/models.ts index 846687a77e..15b88c2154 100644 --- a/src/main/apiServer/services/models.ts +++ b/src/main/apiServer/services/models.ts @@ -13,14 +13,24 @@ export class ModelsService { try { logger.debug('Getting available models from providers', { filter }) - const models = await listAllAvailableModels() - const providers = await getAvailableProviders() + let providers = await getAvailableProviders() + if (filter.providerType === 'anthropic') { + providers = providers.filter( + (p) => p.type === 'anthropic' || (p.anthropicApiHost !== undefined && p.anthropicApiHost.trim() !== '') + ) + } + + const models = await listAllAvailableModels(providers) // Use Map to deduplicate models by their full ID (provider:model_id) const uniqueModels = new Map() for (const model of models) { - const openAIModel = transformModelToOpenAI(model, providers) + const provider = providers.find((p) => p.id === model.provider) + if (!provider || (provider.isAnthropicModel && !provider.isAnthropicModel(model))) { + continue + } + const openAIModel = transformModelToOpenAI(model, provider) const fullModelId = openAIModel.id // This is already in format "provider:model_id" // Only add if not already present (first occurrence wins) @@ -32,16 +42,6 @@ export class ModelsService { } let modelData = Array.from(uniqueModels.values()) - if (filter.providerType) { - // Apply filters - const providerType = filter.providerType - modelData = modelData.filter((model) => { - // Find the provider for this model and check its type - return model.provider_type === providerType - }) - logger.debug(`Filtered by provider type '${providerType}': ${modelData.length} models`) - } - const total = modelData.length // Apply pagination diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index 6663918927..85167d8c7d 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -47,9 +47,11 @@ export async function getAvailableProviders(): Promise { } } -export async function listAllAvailableModels(): Promise { +export async function listAllAvailableModels(providers?: Provider[]): Promise { try { - const providers = await getAvailableProviders() + if (!providers) { + providers = await getAvailableProviders() + } return providers.map((p: Provider) => p.models || []).flat() } catch (error: any) { logger.error('Failed to list available models', { error }) @@ -107,9 +109,12 @@ export interface ModelValidationError { code: string } -export async function validateModelId( - model: string -): Promise<{ valid: boolean; error?: ModelValidationError; provider?: Provider; modelId?: string }> { +export async function validateModelId(model: string): Promise<{ + valid: boolean + error?: ModelValidationError + provider?: Provider + modelId?: string +}> { try { if (!model || typeof model !== 'string') { return { @@ -192,8 +197,7 @@ export async function validateModelId( } } -export function transformModelToOpenAI(model: Model, providers: Provider[]): ApiModel { - const provider = providers.find((p) => p.id === model.provider) +export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel { const providerDisplayName = provider?.name return { id: `${model.provider}:${model.id}`, @@ -268,7 +272,10 @@ export function validateProvider(provider: Provider): boolean { return true } catch (error: any) { - logger.error('Error validating provider', { error, providerId: provider?.id }) + logger.error('Error validating provider', { + error, + providerId: provider?.id + }) return false } } diff --git a/src/main/services/agents/BaseService.ts b/src/main/services/agents/BaseService.ts index 73a1a1828a..cadf45b43d 100644 --- a/src/main/services/agents/BaseService.ts +++ b/src/main/services/agents/BaseService.ts @@ -2,7 +2,7 @@ import { type Client, createClient } from '@libsql/client' import { loggerService } from '@logger' import { mcpApiService } from '@main/apiServer/services/mcp' import { ModelValidationError, validateModelId } from '@main/apiServer/utils' -import { AgentType, MCPTool, objectKeys, Provider, SlashCommand, Tool } from '@types' +import { AgentType, MCPTool, objectKeys, SlashCommand, Tool } from '@types' import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql' import fs from 'fs' import path from 'path' @@ -306,23 +306,6 @@ export abstract class BaseService { } ) } - - // different agent types may have different provider requirements - const agentTypeProviderRequirements: Record = { - 'claude-code': 'anthropic' - } - for (const [ak, pk] of Object.entries(agentTypeProviderRequirements)) { - if (agentType === ak && validation.provider.type !== pk) { - throw new AgentModelValidationError( - { agentType, field, model: modelValue }, - { - type: 'unsupported_provider_type', - message: `Provider type '${validation.provider.type}' is not supported for agent type '${agentType}'. Expected '${pk}'`, - code: 'unsupported_provider_type' - } - ) - } - } } } diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index cc72b3c5a7..3c7ab40bc4 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -60,7 +60,15 @@ class ClaudeCodeService implements AgentServiceInterface { }) return aiStream } - if (modelInfo.provider?.type !== 'anthropic' || modelInfo.provider.apiKey === '') { + if ( + (modelInfo.provider?.type !== 'anthropic' && + (modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) || + modelInfo.provider.apiKey === '' + ) { + logger.error('Anthropic provider configuration is missing', { + modelInfo + }) + aiStream.emit('data', { type: 'error', error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`) diff --git a/src/renderer/src/components/ApiModelLabel.tsx b/src/renderer/src/components/ApiModelLabel.tsx index 4e6d318ebd..b2e08635fe 100644 --- a/src/renderer/src/components/ApiModelLabel.tsx +++ b/src/renderer/src/components/ApiModelLabel.tsx @@ -20,7 +20,7 @@ export const ApiModelLabel: React.FC = ({ model, className, cla {model?.name} | - {model?.provider_name} + {model?.provider} ) } diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index b0692e7f17..20be65ba7b 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -56,6 +56,7 @@ import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png' import { AtLeast, isSystemProvider, + Model, OpenAIServiceTiers, Provider, ProviderType, @@ -105,6 +106,7 @@ export const SYSTEM_PROVIDERS_CONFIG: Record = apiKey: '', apiHost: 'https://aihubmix.com', anthropicApiHost: 'https://aihubmix.com/anthropic', + isAnthropicModel: (m: Model) => m.id.includes('claude'), models: SYSTEM_MODELS.aihubmix, isSystem: true, enabled: false diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 35dd216565..4986b911fe 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -2618,6 +2618,7 @@ const migrateConfig = { break case 'aihubmix': provider.anthropicApiHost = 'https://aihubmix.com/anthropic' + provider.isAnthropicModel = (m: Model) => m.id.includes('claude') break } }) diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 56bc883fe8..dff9563b7f 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -253,6 +253,7 @@ export type Provider = { apiKey: string apiHost: string anthropicApiHost?: string + isAnthropicModel?: (m: Model) => boolean apiVersion?: string models: Model[] enabled?: boolean