mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-08 22:39:36 +08:00
feat: enhance Anthropic API support for compatible providers
- Add support for anthropicApiHost configuration in providers - Improve model filtering for Anthropic-compatible providers - Add isAnthropicModel function to validate Anthropic models - Update ClaudeCode service to support compatible providers - Enhance logging and error handling in API routes - Fix model transformation and validation logic
This commit is contained in:
parent
35b885798b
commit
4d133d59ea
@ -10,7 +10,9 @@
|
|||||||
|
|
||||||
import Anthropic from '@anthropic-ai/sdk'
|
import Anthropic from '@anthropic-ai/sdk'
|
||||||
import { TextBlockParam } from '@anthropic-ai/sdk/resources'
|
import { TextBlockParam } from '@anthropic-ai/sdk/resources'
|
||||||
|
import { loggerService } from '@logger'
|
||||||
import { Provider } from '@types'
|
import { Provider } from '@types'
|
||||||
|
const logger = loggerService.withContext('anthropic-sdk')
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates and configures an Anthropic SDK client based on the provider configuration.
|
* Creates and configures an Anthropic SDK client based on the provider configuration.
|
||||||
@ -75,6 +77,7 @@ export function getSdkClient(provider: Provider, oauthToken?: string | null): An
|
|||||||
? provider.apiHost
|
? provider.apiHost
|
||||||
: (provider.anthropicApiHost && provider.anthropicApiHost.trim()) || provider.apiHost
|
: (provider.anthropicApiHost && provider.anthropicApiHost.trim()) || provider.apiHost
|
||||||
|
|
||||||
|
logger.debug('Anthropic API baseURL', { baseURL })
|
||||||
return new Anthropic({
|
return new Anthropic({
|
||||||
apiKey: provider.apiKey,
|
apiKey: provider.apiKey,
|
||||||
authToken: provider.apiKey,
|
authToken: provider.apiKey,
|
||||||
|
|||||||
@ -2,7 +2,8 @@ import { MessageCreateParams } from '@anthropic-ai/sdk/resources'
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import express, { Request, Response } from 'express'
|
import express, { Request, Response } from 'express'
|
||||||
|
|
||||||
import { messagesService } from '../services/messages'
|
import { Provider } from '../../../renderer/src/types/provider'
|
||||||
|
import { MessagesService, messagesService } from '../services/messages'
|
||||||
import { getProviderById, validateModelId } from '../utils'
|
import { getProviderById, validateModelId } from '../utils'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
||||||
@ -33,9 +34,8 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
|
|||||||
async function handleStreamingResponse(
|
async function handleStreamingResponse(
|
||||||
res: Response,
|
res: Response,
|
||||||
request: MessageCreateParams,
|
request: MessageCreateParams,
|
||||||
provider: any,
|
provider: Provider,
|
||||||
messagesService: any,
|
messagesService: MessagesService
|
||||||
logger: any
|
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
|
res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
|
||||||
res.setHeader('Cache-Control', 'no-cache, no-transform')
|
res.setHeader('Cache-Control', 'no-cache, no-transform')
|
||||||
@ -80,7 +80,7 @@ async function handleStreamingResponse(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleErrorResponse(res: Response, error: any, logger: any): Response {
|
function handleErrorResponse(res: Response, error: any): Response {
|
||||||
logger.error('Message processing error', { error })
|
logger.error('Message processing error', { error })
|
||||||
|
|
||||||
let statusCode = 500
|
let statusCode = 500
|
||||||
@ -133,7 +133,7 @@ function handleErrorResponse(res: Response, error: any, logger: any): Response {
|
|||||||
async function processMessageRequest(
|
async function processMessageRequest(
|
||||||
req: Request,
|
req: Request,
|
||||||
res: Response,
|
res: Response,
|
||||||
provider: any,
|
provider: Provider,
|
||||||
modelId?: string
|
modelId?: string
|
||||||
): Promise<Response | void> {
|
): Promise<Response | void> {
|
||||||
try {
|
try {
|
||||||
@ -144,17 +144,6 @@ async function processMessageRequest(
|
|||||||
request.model = modelId
|
request.model = modelId
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure provider is Anthropic type
|
|
||||||
if (provider.type !== 'anthropic') {
|
|
||||||
return res.status(400).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'invalid_request_error',
|
|
||||||
message: `Invalid provider type '${provider.type}' for messages endpoint. Expected 'anthropic' provider.`
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate request
|
// Validate request
|
||||||
const validation = messagesService.validateRequest(request)
|
const validation = messagesService.validateRequest(request)
|
||||||
if (!validation.isValid) {
|
if (!validation.isValid) {
|
||||||
@ -167,9 +156,14 @@ async function processMessageRequest(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.silly('Processing message request', {
|
||||||
|
request,
|
||||||
|
provider: provider.id
|
||||||
|
})
|
||||||
|
|
||||||
// Handle streaming
|
// Handle streaming
|
||||||
if (request.stream) {
|
if (request.stream) {
|
||||||
await handleStreamingResponse(res, request, provider, messagesService, logger)
|
await handleStreamingResponse(res, request, provider, messagesService)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,7 +171,7 @@ async function processMessageRequest(
|
|||||||
const response = await messagesService.processMessage(request, provider)
|
const response = await messagesService.processMessage(request, provider)
|
||||||
return res.json(response)
|
return res.json(response)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
return handleErrorResponse(res, error, logger)
|
return handleErrorResponse(res, error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -337,7 +331,7 @@ router.post('/', async (req: Request, res: Response) => {
|
|||||||
// Use shared processing function
|
// Use shared processing function
|
||||||
return await processMessageRequest(req, res, provider, modelId)
|
return await processMessageRequest(req, res, provider, modelId)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
return handleErrorResponse(res, error, logger)
|
return handleErrorResponse(res, error)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -492,7 +486,7 @@ providerRouter.post('/', async (req: Request, res: Response) => {
|
|||||||
// Use shared processing function (no modelId override needed)
|
// Use shared processing function (no modelId override needed)
|
||||||
return await processMessageRequest(req, res, provider)
|
return await processMessageRequest(req, res, provider)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
return handleErrorResponse(res, error, logger)
|
return handleErrorResponse(res, error)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@ -13,14 +13,24 @@ export class ModelsService {
|
|||||||
try {
|
try {
|
||||||
logger.debug('Getting available models from providers', { filter })
|
logger.debug('Getting available models from providers', { filter })
|
||||||
|
|
||||||
const models = await listAllAvailableModels()
|
let providers = await getAvailableProviders()
|
||||||
const providers = await getAvailableProviders()
|
|
||||||
|
|
||||||
|
if (filter.providerType === 'anthropic') {
|
||||||
|
providers = providers.filter(
|
||||||
|
(p) => p.type === 'anthropic' || (p.anthropicApiHost !== undefined && p.anthropicApiHost.trim() !== '')
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const models = await listAllAvailableModels(providers)
|
||||||
// Use Map to deduplicate models by their full ID (provider:model_id)
|
// Use Map to deduplicate models by their full ID (provider:model_id)
|
||||||
const uniqueModels = new Map<string, ApiModel>()
|
const uniqueModels = new Map<string, ApiModel>()
|
||||||
|
|
||||||
for (const model of models) {
|
for (const model of models) {
|
||||||
const openAIModel = transformModelToOpenAI(model, providers)
|
const provider = providers.find((p) => p.id === model.provider)
|
||||||
|
if (!provider || (provider.isAnthropicModel && !provider.isAnthropicModel(model))) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const openAIModel = transformModelToOpenAI(model, provider)
|
||||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||||
|
|
||||||
// Only add if not already present (first occurrence wins)
|
// Only add if not already present (first occurrence wins)
|
||||||
@ -32,16 +42,6 @@ export class ModelsService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let modelData = Array.from(uniqueModels.values())
|
let modelData = Array.from(uniqueModels.values())
|
||||||
if (filter.providerType) {
|
|
||||||
// Apply filters
|
|
||||||
const providerType = filter.providerType
|
|
||||||
modelData = modelData.filter((model) => {
|
|
||||||
// Find the provider for this model and check its type
|
|
||||||
return model.provider_type === providerType
|
|
||||||
})
|
|
||||||
logger.debug(`Filtered by provider type '${providerType}': ${modelData.length} models`)
|
|
||||||
}
|
|
||||||
|
|
||||||
const total = modelData.length
|
const total = modelData.length
|
||||||
|
|
||||||
// Apply pagination
|
// Apply pagination
|
||||||
|
|||||||
@ -47,9 +47,11 @@ export async function getAvailableProviders(): Promise<Provider[]> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function listAllAvailableModels(): Promise<Model[]> {
|
export async function listAllAvailableModels(providers?: Provider[]): Promise<Model[]> {
|
||||||
try {
|
try {
|
||||||
const providers = await getAvailableProviders()
|
if (!providers) {
|
||||||
|
providers = await getAvailableProviders()
|
||||||
|
}
|
||||||
return providers.map((p: Provider) => p.models || []).flat()
|
return providers.map((p: Provider) => p.models || []).flat()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Failed to list available models', { error })
|
logger.error('Failed to list available models', { error })
|
||||||
@ -107,9 +109,12 @@ export interface ModelValidationError {
|
|||||||
code: string
|
code: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function validateModelId(
|
export async function validateModelId(model: string): Promise<{
|
||||||
model: string
|
valid: boolean
|
||||||
): Promise<{ valid: boolean; error?: ModelValidationError; provider?: Provider; modelId?: string }> {
|
error?: ModelValidationError
|
||||||
|
provider?: Provider
|
||||||
|
modelId?: string
|
||||||
|
}> {
|
||||||
try {
|
try {
|
||||||
if (!model || typeof model !== 'string') {
|
if (!model || typeof model !== 'string') {
|
||||||
return {
|
return {
|
||||||
@ -192,8 +197,7 @@ export async function validateModelId(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function transformModelToOpenAI(model: Model, providers: Provider[]): ApiModel {
|
export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel {
|
||||||
const provider = providers.find((p) => p.id === model.provider)
|
|
||||||
const providerDisplayName = provider?.name
|
const providerDisplayName = provider?.name
|
||||||
return {
|
return {
|
||||||
id: `${model.provider}:${model.id}`,
|
id: `${model.provider}:${model.id}`,
|
||||||
@ -268,7 +272,10 @@ export function validateProvider(provider: Provider): boolean {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error validating provider', { error, providerId: provider?.id })
|
logger.error('Error validating provider', {
|
||||||
|
error,
|
||||||
|
providerId: provider?.id
|
||||||
|
})
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import { type Client, createClient } from '@libsql/client'
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { mcpApiService } from '@main/apiServer/services/mcp'
|
import { mcpApiService } from '@main/apiServer/services/mcp'
|
||||||
import { ModelValidationError, validateModelId } from '@main/apiServer/utils'
|
import { ModelValidationError, validateModelId } from '@main/apiServer/utils'
|
||||||
import { AgentType, MCPTool, objectKeys, Provider, SlashCommand, Tool } from '@types'
|
import { AgentType, MCPTool, objectKeys, SlashCommand, Tool } from '@types'
|
||||||
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
|
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
|
||||||
import fs from 'fs'
|
import fs from 'fs'
|
||||||
import path from 'path'
|
import path from 'path'
|
||||||
@ -306,23 +306,6 @@ export abstract class BaseService {
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// different agent types may have different provider requirements
|
|
||||||
const agentTypeProviderRequirements: Record<AgentType, Provider['type']> = {
|
|
||||||
'claude-code': 'anthropic'
|
|
||||||
}
|
|
||||||
for (const [ak, pk] of Object.entries(agentTypeProviderRequirements)) {
|
|
||||||
if (agentType === ak && validation.provider.type !== pk) {
|
|
||||||
throw new AgentModelValidationError(
|
|
||||||
{ agentType, field, model: modelValue },
|
|
||||||
{
|
|
||||||
type: 'unsupported_provider_type',
|
|
||||||
message: `Provider type '${validation.provider.type}' is not supported for agent type '${agentType}'. Expected '${pk}'`,
|
|
||||||
code: 'unsupported_provider_type'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -60,7 +60,15 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
})
|
})
|
||||||
return aiStream
|
return aiStream
|
||||||
}
|
}
|
||||||
if (modelInfo.provider?.type !== 'anthropic' || modelInfo.provider.apiKey === '') {
|
if (
|
||||||
|
(modelInfo.provider?.type !== 'anthropic' &&
|
||||||
|
(modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
|
||||||
|
modelInfo.provider.apiKey === ''
|
||||||
|
) {
|
||||||
|
logger.error('Anthropic provider configuration is missing', {
|
||||||
|
modelInfo
|
||||||
|
})
|
||||||
|
|
||||||
aiStream.emit('data', {
|
aiStream.emit('data', {
|
||||||
type: 'error',
|
type: 'error',
|
||||||
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
|
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
|
||||||
|
|||||||
@ -20,7 +20,7 @@ export const ApiModelLabel: React.FC<ModelLabelProps> = ({ model, className, cla
|
|||||||
<Avatar src={model ? getModelLogo(model.id) : undefined} className={cn('h-4 w-4', classNames?.avatar)} />
|
<Avatar src={model ? getModelLogo(model.id) : undefined} className={cn('h-4 w-4', classNames?.avatar)} />
|
||||||
<span className={classNames?.modelName}>{model?.name}</span>
|
<span className={classNames?.modelName}>{model?.name}</span>
|
||||||
<span className={classNames?.divider}> | </span>
|
<span className={classNames?.divider}> | </span>
|
||||||
<span className={classNames?.providerName}>{model?.provider_name}</span>
|
<span className={classNames?.providerName}>{model?.provider}</span>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -56,6 +56,7 @@ import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
|
|||||||
import {
|
import {
|
||||||
AtLeast,
|
AtLeast,
|
||||||
isSystemProvider,
|
isSystemProvider,
|
||||||
|
Model,
|
||||||
OpenAIServiceTiers,
|
OpenAIServiceTiers,
|
||||||
Provider,
|
Provider,
|
||||||
ProviderType,
|
ProviderType,
|
||||||
@ -105,6 +106,7 @@ export const SYSTEM_PROVIDERS_CONFIG: Record<SystemProviderId, SystemProvider> =
|
|||||||
apiKey: '',
|
apiKey: '',
|
||||||
apiHost: 'https://aihubmix.com',
|
apiHost: 'https://aihubmix.com',
|
||||||
anthropicApiHost: 'https://aihubmix.com/anthropic',
|
anthropicApiHost: 'https://aihubmix.com/anthropic',
|
||||||
|
isAnthropicModel: (m: Model) => m.id.includes('claude'),
|
||||||
models: SYSTEM_MODELS.aihubmix,
|
models: SYSTEM_MODELS.aihubmix,
|
||||||
isSystem: true,
|
isSystem: true,
|
||||||
enabled: false
|
enabled: false
|
||||||
|
|||||||
@ -2618,6 +2618,7 @@ const migrateConfig = {
|
|||||||
break
|
break
|
||||||
case 'aihubmix':
|
case 'aihubmix':
|
||||||
provider.anthropicApiHost = 'https://aihubmix.com/anthropic'
|
provider.anthropicApiHost = 'https://aihubmix.com/anthropic'
|
||||||
|
provider.isAnthropicModel = (m: Model) => m.id.includes('claude')
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@ -253,6 +253,7 @@ export type Provider = {
|
|||||||
apiKey: string
|
apiKey: string
|
||||||
apiHost: string
|
apiHost: string
|
||||||
anthropicApiHost?: string
|
anthropicApiHost?: string
|
||||||
|
isAnthropicModel?: (m: Model) => boolean
|
||||||
apiVersion?: string
|
apiVersion?: string
|
||||||
models: Model[]
|
models: Model[]
|
||||||
enabled?: boolean
|
enabled?: boolean
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user