mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat: enhance Anthropic API support for compatible providers
- Add support for anthropicApiHost configuration in providers - Improve model filtering for Anthropic-compatible providers - Add isAnthropicModel function to validate Anthropic models - Update ClaudeCode service to support compatible providers - Enhance logging and error handling in API routes - Fix model transformation and validation logic
This commit is contained in:
parent
35b885798b
commit
4d133d59ea
@ -10,7 +10,9 @@
|
||||
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { TextBlockParam } from '@anthropic-ai/sdk/resources'
|
||||
import { loggerService } from '@logger'
|
||||
import { Provider } from '@types'
|
||||
const logger = loggerService.withContext('anthropic-sdk')
|
||||
|
||||
/**
|
||||
* Creates and configures an Anthropic SDK client based on the provider configuration.
|
||||
@ -75,6 +77,7 @@ export function getSdkClient(provider: Provider, oauthToken?: string | null): An
|
||||
? provider.apiHost
|
||||
: (provider.anthropicApiHost && provider.anthropicApiHost.trim()) || provider.apiHost
|
||||
|
||||
logger.debug('Anthropic API baseURL', { baseURL })
|
||||
return new Anthropic({
|
||||
apiKey: provider.apiKey,
|
||||
authToken: provider.apiKey,
|
||||
|
||||
@ -2,7 +2,8 @@ import { MessageCreateParams } from '@anthropic-ai/sdk/resources'
|
||||
import { loggerService } from '@logger'
|
||||
import express, { Request, Response } from 'express'
|
||||
|
||||
import { messagesService } from '../services/messages'
|
||||
import { Provider } from '../../../renderer/src/types/provider'
|
||||
import { MessagesService, messagesService } from '../services/messages'
|
||||
import { getProviderById, validateModelId } from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
||||
@ -33,9 +34,8 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
|
||||
async function handleStreamingResponse(
|
||||
res: Response,
|
||||
request: MessageCreateParams,
|
||||
provider: any,
|
||||
messagesService: any,
|
||||
logger: any
|
||||
provider: Provider,
|
||||
messagesService: MessagesService
|
||||
): Promise<void> {
|
||||
res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform')
|
||||
@ -80,7 +80,7 @@ async function handleStreamingResponse(
|
||||
}
|
||||
}
|
||||
|
||||
function handleErrorResponse(res: Response, error: any, logger: any): Response {
|
||||
function handleErrorResponse(res: Response, error: any): Response {
|
||||
logger.error('Message processing error', { error })
|
||||
|
||||
let statusCode = 500
|
||||
@ -133,7 +133,7 @@ function handleErrorResponse(res: Response, error: any, logger: any): Response {
|
||||
async function processMessageRequest(
|
||||
req: Request,
|
||||
res: Response,
|
||||
provider: any,
|
||||
provider: Provider,
|
||||
modelId?: string
|
||||
): Promise<Response | void> {
|
||||
try {
|
||||
@ -144,17 +144,6 @@ async function processMessageRequest(
|
||||
request.model = modelId
|
||||
}
|
||||
|
||||
// Ensure provider is Anthropic type
|
||||
if (provider.type !== 'anthropic') {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: `Invalid provider type '${provider.type}' for messages endpoint. Expected 'anthropic' provider.`
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Validate request
|
||||
const validation = messagesService.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
@ -167,9 +156,14 @@ async function processMessageRequest(
|
||||
})
|
||||
}
|
||||
|
||||
logger.silly('Processing message request', {
|
||||
request,
|
||||
provider: provider.id
|
||||
})
|
||||
|
||||
// Handle streaming
|
||||
if (request.stream) {
|
||||
await handleStreamingResponse(res, request, provider, messagesService, logger)
|
||||
await handleStreamingResponse(res, request, provider, messagesService)
|
||||
return
|
||||
}
|
||||
|
||||
@ -177,7 +171,7 @@ async function processMessageRequest(
|
||||
const response = await messagesService.processMessage(request, provider)
|
||||
return res.json(response)
|
||||
} catch (error: any) {
|
||||
return handleErrorResponse(res, error, logger)
|
||||
return handleErrorResponse(res, error)
|
||||
}
|
||||
}
|
||||
|
||||
@ -337,7 +331,7 @@ router.post('/', async (req: Request, res: Response) => {
|
||||
// Use shared processing function
|
||||
return await processMessageRequest(req, res, provider, modelId)
|
||||
} catch (error: any) {
|
||||
return handleErrorResponse(res, error, logger)
|
||||
return handleErrorResponse(res, error)
|
||||
}
|
||||
})
|
||||
|
||||
@ -492,7 +486,7 @@ providerRouter.post('/', async (req: Request, res: Response) => {
|
||||
// Use shared processing function (no modelId override needed)
|
||||
return await processMessageRequest(req, res, provider)
|
||||
} catch (error: any) {
|
||||
return handleErrorResponse(res, error, logger)
|
||||
return handleErrorResponse(res, error)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@ -13,14 +13,24 @@ export class ModelsService {
|
||||
try {
|
||||
logger.debug('Getting available models from providers', { filter })
|
||||
|
||||
const models = await listAllAvailableModels()
|
||||
const providers = await getAvailableProviders()
|
||||
let providers = await getAvailableProviders()
|
||||
|
||||
if (filter.providerType === 'anthropic') {
|
||||
providers = providers.filter(
|
||||
(p) => p.type === 'anthropic' || (p.anthropicApiHost !== undefined && p.anthropicApiHost.trim() !== '')
|
||||
)
|
||||
}
|
||||
|
||||
const models = await listAllAvailableModels(providers)
|
||||
// Use Map to deduplicate models by their full ID (provider:model_id)
|
||||
const uniqueModels = new Map<string, ApiModel>()
|
||||
|
||||
for (const model of models) {
|
||||
const openAIModel = transformModelToOpenAI(model, providers)
|
||||
const provider = providers.find((p) => p.id === model.provider)
|
||||
if (!provider || (provider.isAnthropicModel && !provider.isAnthropicModel(model))) {
|
||||
continue
|
||||
}
|
||||
const openAIModel = transformModelToOpenAI(model, provider)
|
||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||
|
||||
// Only add if not already present (first occurrence wins)
|
||||
@ -32,16 +42,6 @@ export class ModelsService {
|
||||
}
|
||||
|
||||
let modelData = Array.from(uniqueModels.values())
|
||||
if (filter.providerType) {
|
||||
// Apply filters
|
||||
const providerType = filter.providerType
|
||||
modelData = modelData.filter((model) => {
|
||||
// Find the provider for this model and check its type
|
||||
return model.provider_type === providerType
|
||||
})
|
||||
logger.debug(`Filtered by provider type '${providerType}': ${modelData.length} models`)
|
||||
}
|
||||
|
||||
const total = modelData.length
|
||||
|
||||
// Apply pagination
|
||||
|
||||
@ -47,9 +47,11 @@ export async function getAvailableProviders(): Promise<Provider[]> {
|
||||
}
|
||||
}
|
||||
|
||||
export async function listAllAvailableModels(): Promise<Model[]> {
|
||||
export async function listAllAvailableModels(providers?: Provider[]): Promise<Model[]> {
|
||||
try {
|
||||
const providers = await getAvailableProviders()
|
||||
if (!providers) {
|
||||
providers = await getAvailableProviders()
|
||||
}
|
||||
return providers.map((p: Provider) => p.models || []).flat()
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to list available models', { error })
|
||||
@ -107,9 +109,12 @@ export interface ModelValidationError {
|
||||
code: string
|
||||
}
|
||||
|
||||
export async function validateModelId(
|
||||
model: string
|
||||
): Promise<{ valid: boolean; error?: ModelValidationError; provider?: Provider; modelId?: string }> {
|
||||
export async function validateModelId(model: string): Promise<{
|
||||
valid: boolean
|
||||
error?: ModelValidationError
|
||||
provider?: Provider
|
||||
modelId?: string
|
||||
}> {
|
||||
try {
|
||||
if (!model || typeof model !== 'string') {
|
||||
return {
|
||||
@ -192,8 +197,7 @@ export async function validateModelId(
|
||||
}
|
||||
}
|
||||
|
||||
export function transformModelToOpenAI(model: Model, providers: Provider[]): ApiModel {
|
||||
const provider = providers.find((p) => p.id === model.provider)
|
||||
export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel {
|
||||
const providerDisplayName = provider?.name
|
||||
return {
|
||||
id: `${model.provider}:${model.id}`,
|
||||
@ -268,7 +272,10 @@ export function validateProvider(provider: Provider): boolean {
|
||||
|
||||
return true
|
||||
} catch (error: any) {
|
||||
logger.error('Error validating provider', { error, providerId: provider?.id })
|
||||
logger.error('Error validating provider', {
|
||||
error,
|
||||
providerId: provider?.id
|
||||
})
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ import { type Client, createClient } from '@libsql/client'
|
||||
import { loggerService } from '@logger'
|
||||
import { mcpApiService } from '@main/apiServer/services/mcp'
|
||||
import { ModelValidationError, validateModelId } from '@main/apiServer/utils'
|
||||
import { AgentType, MCPTool, objectKeys, Provider, SlashCommand, Tool } from '@types'
|
||||
import { AgentType, MCPTool, objectKeys, SlashCommand, Tool } from '@types'
|
||||
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
@ -306,23 +306,6 @@ export abstract class BaseService {
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// different agent types may have different provider requirements
|
||||
const agentTypeProviderRequirements: Record<AgentType, Provider['type']> = {
|
||||
'claude-code': 'anthropic'
|
||||
}
|
||||
for (const [ak, pk] of Object.entries(agentTypeProviderRequirements)) {
|
||||
if (agentType === ak && validation.provider.type !== pk) {
|
||||
throw new AgentModelValidationError(
|
||||
{ agentType, field, model: modelValue },
|
||||
{
|
||||
type: 'unsupported_provider_type',
|
||||
message: `Provider type '${validation.provider.type}' is not supported for agent type '${agentType}'. Expected '${pk}'`,
|
||||
code: 'unsupported_provider_type'
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -60,7 +60,15 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
})
|
||||
return aiStream
|
||||
}
|
||||
if (modelInfo.provider?.type !== 'anthropic' || modelInfo.provider.apiKey === '') {
|
||||
if (
|
||||
(modelInfo.provider?.type !== 'anthropic' &&
|
||||
(modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
|
||||
modelInfo.provider.apiKey === ''
|
||||
) {
|
||||
logger.error('Anthropic provider configuration is missing', {
|
||||
modelInfo
|
||||
})
|
||||
|
||||
aiStream.emit('data', {
|
||||
type: 'error',
|
||||
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
|
||||
|
||||
@ -20,7 +20,7 @@ export const ApiModelLabel: React.FC<ModelLabelProps> = ({ model, className, cla
|
||||
<Avatar src={model ? getModelLogo(model.id) : undefined} className={cn('h-4 w-4', classNames?.avatar)} />
|
||||
<span className={classNames?.modelName}>{model?.name}</span>
|
||||
<span className={classNames?.divider}> | </span>
|
||||
<span className={classNames?.providerName}>{model?.provider_name}</span>
|
||||
<span className={classNames?.providerName}>{model?.provider}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -56,6 +56,7 @@ import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
|
||||
import {
|
||||
AtLeast,
|
||||
isSystemProvider,
|
||||
Model,
|
||||
OpenAIServiceTiers,
|
||||
Provider,
|
||||
ProviderType,
|
||||
@ -105,6 +106,7 @@ export const SYSTEM_PROVIDERS_CONFIG: Record<SystemProviderId, SystemProvider> =
|
||||
apiKey: '',
|
||||
apiHost: 'https://aihubmix.com',
|
||||
anthropicApiHost: 'https://aihubmix.com/anthropic',
|
||||
isAnthropicModel: (m: Model) => m.id.includes('claude'),
|
||||
models: SYSTEM_MODELS.aihubmix,
|
||||
isSystem: true,
|
||||
enabled: false
|
||||
|
||||
@ -2618,6 +2618,7 @@ const migrateConfig = {
|
||||
break
|
||||
case 'aihubmix':
|
||||
provider.anthropicApiHost = 'https://aihubmix.com/anthropic'
|
||||
provider.isAnthropicModel = (m: Model) => m.id.includes('claude')
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
@ -253,6 +253,7 @@ export type Provider = {
|
||||
apiKey: string
|
||||
apiHost: string
|
||||
anthropicApiHost?: string
|
||||
isAnthropicModel?: (m: Model) => boolean
|
||||
apiVersion?: string
|
||||
models: Model[]
|
||||
enabled?: boolean
|
||||
|
||||
Loading…
Reference in New Issue
Block a user