feat: implement comprehensive Claude Code OAuth integration and API enhancements

- Add shared Anthropic utilities package with OAuth and API key client creation
- Implement provider-specific message routing alongside existing v1 API
- Enhance authentication middleware with priority handling (API key > Bearer token)
- Add comprehensive auth middleware test suite with timing attack protection
- Update session handling and message transformation for Claude Code integration
- Improve error handling and validation across message processing pipeline
- Standardize import formatting and code structure across affected modules

This establishes the foundation for Claude Code OAuth authentication while maintaining
backward compatibility with existing API key authentication methods.
This commit is contained in:
Vaayne 2025-09-21 16:42:46 +08:00
parent c3b2af5a15
commit b869869e26
12 changed files with 1053 additions and 322 deletions

View File

@ -0,0 +1,146 @@
/**
* @fileoverview Shared Anthropic AI client utilities for Cherry Studio
*
* This module provides functions for creating Anthropic SDK clients with different
* authentication methods (OAuth, API key) and building Claude Code system messages.
* It supports both standard Anthropic API and Anthropic Vertex AI endpoints.
*
* This shared module can be used by both main and renderer processes.
*/
import Anthropic from "@anthropic-ai/sdk";
import {TextBlockParam} from "@anthropic-ai/sdk/resources";
import {Provider} from "@types";
/**
* Creates and configures an Anthropic SDK client based on the provider configuration.
*
* This function supports two authentication methods:
* 1. OAuth: Uses OAuth tokens passed as parameter
* 2. API Key: Uses traditional API key authentication
*
* For OAuth authentication, it includes Claude Code specific headers and beta features.
* For API key authentication, it uses the provider's configuration with custom headers.
*
* @param provider - The provider configuration containing authentication details
* @param oauthToken - Optional OAuth token for OAuth authentication
* @returns An initialized Anthropic or AnthropicVertex client
* @throws Error when OAuth token is not available for OAuth authentication
*
* @example
* ```typescript
* // OAuth authentication
* const oauthProvider = { authType: 'oauth' };
* const oauthClient = getSdkClient(oauthProvider, 'oauth-token-here');
*
* // API key authentication
* const apiKeyProvider = {
* authType: 'apikey',
* apiKey: 'your-api-key',
* apiHost: 'https://api.anthropic.com'
* };
* const apiKeyClient = getSdkClient(apiKeyProvider);
* ```
*/
export function getSdkClient(provider: Provider, oauthToken?: string | null): Anthropic {
if (provider.authType === 'oauth') {
if (!oauthToken) {
throw new Error('OAuth token is not available')
}
return new Anthropic({
authToken: oauthToken,
baseURL: 'https://api.anthropic.com',
dangerouslyAllowBrowser: true,
defaultHeaders: {
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
'anthropic-beta': 'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14',
'anthropic-dangerous-direct-browser-access': 'true',
'user-agent': 'claude-cli/1.0.118 (external, sdk-ts)',
'x-app': 'cli',
'x-stainless-retry-count': '0',
'x-stainless-timeout': '600',
'x-stainless-lang': 'js',
'x-stainless-package-version': '0.60.0',
'x-stainless-os': 'MacOS',
'x-stainless-arch': 'arm64',
'x-stainless-runtime': 'node',
'x-stainless-runtime-version': 'v22.18.0'
}
})
}
return new Anthropic({
apiKey: provider.apiKey,
baseURL: provider.apiHost,
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'output-128k-2025-02-19',
...provider.extra_headers
}
})
}
/**
* Builds and prepends the Claude Code system message to user-provided system messages.
*
* This function ensures that all interactions with Claude include the official Claude Code
* system prompt, which identifies the assistant as "Claude Code, Anthropic's official CLI for Claude."
*
* The function handles three cases:
* 1. No system message provided: Returns only the default Claude Code system message
* 2. String system message: Converts to array format and prepends Claude Code message
* 3. Array system message: Checks if Claude Code message exists and prepends if missing
*
* @param system - Optional user-provided system message (string or TextBlockParam array)
* @returns Combined system message with Claude Code prompt prepended
*
* @example
* ```typescript
* // No system message
* const result1 = buildClaudeCodeSystemMessage();
* // Returns: "You are Claude Code, Anthropic's official CLI for Claude."
*
* // String system message
* const result2 = buildClaudeCodeSystemMessage("You are a helpful assistant.");
* // Returns: [
* // { type: 'text', text: "You are Claude Code, Anthropic's official CLI for Claude." },
* // { type: 'text', text: "You are a helpful assistant." }
* // ]
*
* // Array system message
* const systemArray = [{ type: 'text', text: 'Custom instructions' }];
* const result3 = buildClaudeCodeSystemMessage(systemArray);
* // Returns: Array with Claude Code message prepended
* ```
*/
export function buildClaudeCodeSystemMessage(system?: string | Array<TextBlockParam>): string | Array<TextBlockParam> {
const defaultClaudeCodeSystem = `You are Claude Code, Anthropic's official CLI for Claude.`
if (!system) {
return defaultClaudeCodeSystem
}
if (typeof system === 'string') {
if (system.trim() === defaultClaudeCodeSystem) {
return system
}
return [
{
type: 'text',
text: defaultClaudeCodeSystem
},
{
type: 'text',
text: system
}
]
}
if (system[0].text.trim() != defaultClaudeCodeSystem) {
system.unshift({
type: 'text',
text: defaultClaudeCodeSystem
})
}
return system
}

View File

@ -1,16 +1,16 @@
import { loggerService } from '@main/services/LoggerService' import {loggerService} from '@main/services/LoggerService'
import cors from 'cors' import cors from 'cors'
import express from 'express' import express from 'express'
import { v4 as uuidv4 } from 'uuid' import {v4 as uuidv4} from 'uuid'
import { authMiddleware } from './middleware/auth' import {authMiddleware} from './middleware/auth'
import { errorHandler } from './middleware/error' import {errorHandler} from './middleware/error'
import { setupOpenAPIDocumentation } from './middleware/openapi' import {setupOpenAPIDocumentation} from './middleware/openapi'
import { agentsRoutes } from './routes/agents' import {agentsRoutes} from './routes/agents'
import { chatRoutes } from './routes/chat' import {chatRoutes} from './routes/chat'
import { mcpRoutes } from './routes/mcp' import {mcpRoutes} from './routes/mcp'
import { messagesRoutes } from './routes/messages' import {messagesProviderRoutes, messagesRoutes} from './routes/messages'
import { modelsRoutes } from './routes/models' import {modelsRoutes} from './routes/models'
const logger = loggerService.withContext('ApiServer') const logger = loggerService.withContext('ApiServer')
@ -108,6 +108,14 @@ app.get('/', (_req, res) => {
}) })
}) })
// Provider-specific API routes with auth (must be before /v1 to avoid conflicts)
const providerRouter = express.Router({mergeParams: true})
providerRouter.use(authMiddleware)
providerRouter.use(express.json())
// Mount provider-specific messages route
providerRouter.use('/v1/messages', messagesProviderRoutes)
app.use('/:provider', providerRouter)
// API v1 routes with auth // API v1 routes with auth
const apiRouter = express.Router() const apiRouter = express.Router()
apiRouter.use(authMiddleware) apiRouter.use(authMiddleware)
@ -120,10 +128,11 @@ apiRouter.use('/models', modelsRoutes)
apiRouter.use('/agents', agentsRoutes) apiRouter.use('/agents', agentsRoutes)
app.use('/v1', apiRouter) app.use('/v1', apiRouter)
// Setup OpenAPI documentation // Setup OpenAPI documentation
setupOpenAPIDocumentation(app) setupOpenAPIDocumentation(app)
// Error handling (must be last) // Error handling (must be last)
app.use(errorHandler) app.use(errorHandler)
export { app } export {app}

View File

@ -0,0 +1,368 @@
import type {NextFunction, Request, Response} from 'express'
import {beforeEach, describe, expect, it, vi} from 'vitest'
import {config} from '../../config'
import {authMiddleware} from '../auth'
// Mock the config module
vi.mock('../../config', () => ({
config: {
get: vi.fn()
}
}))
// Mock the logger
vi.mock('@logger', () => ({
loggerService: {
withContext: vi.fn(() => ({
debug: vi.fn()
}))
}
}))
const mockConfig = config as any
describe('authMiddleware', () => {
let req: Partial<Request>
let res: Partial<Response>
let next: NextFunction
let jsonMock: ReturnType<typeof vi.fn>
let statusMock: ReturnType<typeof vi.fn>
beforeEach(() => {
jsonMock = vi.fn()
statusMock = vi.fn(() => ({json: jsonMock}))
req = {
header: vi.fn()
}
res = {
status: statusMock
}
next = vi.fn()
vi.clearAllMocks()
})
describe('Missing credentials', () => {
it('should return 401 when both auth headers are missing', async () => {
;(req.header as any).mockReturnValue('')
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'})
expect(next).not.toHaveBeenCalled()
})
it('should return 401 when both auth headers are empty strings', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return ''
if (header === 'x-api-key') return ''
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'})
expect(next).not.toHaveBeenCalled()
})
})
describe('Server configuration', () => {
it('should return 403 when API key is not configured', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return 'some-key'
return ''
})
mockConfig.get.mockResolvedValue({apiKey: ''})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
expect(next).not.toHaveBeenCalled()
})
it('should return 403 when API key is null', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return 'some-key'
return ''
})
mockConfig.get.mockResolvedValue({apiKey: null})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
expect(next).not.toHaveBeenCalled()
})
})
describe('API Key authentication (priority)', () => {
const validApiKey = 'valid-api-key-123'
beforeEach(() => {
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
})
it('should authenticate successfully with valid API key', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return validApiKey
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(next).toHaveBeenCalled()
expect(statusMock).not.toHaveBeenCalled()
})
it('should return 403 with invalid API key', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return 'invalid-key'
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
expect(next).not.toHaveBeenCalled()
})
it('should return 401 with empty API key', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return ' '
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: empty x-api-key'})
expect(next).not.toHaveBeenCalled()
})
it('should handle API key with whitespace', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return ` ${validApiKey} `
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(next).toHaveBeenCalled()
expect(statusMock).not.toHaveBeenCalled()
})
it('should prioritize API key over Bearer token when both are present', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return validApiKey
if (header === 'authorization') return 'Bearer invalid-token'
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(next).toHaveBeenCalled()
expect(statusMock).not.toHaveBeenCalled()
})
it('should return 403 when API key is invalid even if Bearer token is valid', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return 'invalid-key'
if (header === 'authorization') return `Bearer ${validApiKey}`
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
expect(next).not.toHaveBeenCalled()
})
})
describe('Bearer token authentication (fallback)', () => {
const validApiKey = 'valid-api-key-123'
beforeEach(() => {
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
})
it('should authenticate successfully with valid Bearer token when no API key', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return `Bearer ${validApiKey}`
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(next).toHaveBeenCalled()
expect(statusMock).not.toHaveBeenCalled()
})
it('should return 403 with invalid Bearer token', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return 'Bearer invalid-token'
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
expect(next).not.toHaveBeenCalled()
})
it('should return 401 with malformed authorization header', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return 'Basic sometoken'
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
expect(next).not.toHaveBeenCalled()
})
it('should return 401 with Bearer without space', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return 'Bearer'
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
expect(next).not.toHaveBeenCalled()
})
it('should handle Bearer token with only trailing spaces (edge case)', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
expect(next).not.toHaveBeenCalled()
})
it('should handle Bearer token with case insensitive prefix', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return `bearer ${validApiKey}`
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(next).toHaveBeenCalled()
expect(statusMock).not.toHaveBeenCalled()
})
it('should handle Bearer token with whitespace', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return ` Bearer ${validApiKey} `
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(next).toHaveBeenCalled()
expect(statusMock).not.toHaveBeenCalled()
})
})
describe('Edge cases', () => {
const validApiKey = 'valid-api-key-123'
beforeEach(() => {
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
})
it('should handle config.get() rejection', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return validApiKey
return ''
})
mockConfig.get.mockRejectedValue(new Error('Config error'))
await expect(authMiddleware(req as Request, res as Response, next)).rejects.toThrow('Config error')
})
it('should use timing-safe comparison for different length tokens', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return 'short'
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
expect(next).not.toHaveBeenCalled()
})
it('should return 401 when neither credential format is valid', async () => {
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return 'Invalid format'
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
expect(next).not.toHaveBeenCalled()
})
})
describe('Timing attack protection', () => {
const validApiKey = 'valid-api-key-123'
beforeEach(() => {
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
})
it('should handle similar length but different API keys securely', async () => {
const similarKey = 'valid-api-key-124' // Same length, different last char
;(req.header as any).mockImplementation((header: string) => {
if (header === 'x-api-key') return similarKey
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
expect(next).not.toHaveBeenCalled()
})
it('should handle similar length but different Bearer tokens securely', async () => {
const similarKey = 'valid-api-key-124' // Same length, different last char
;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return `Bearer ${similarKey}`
return ''
})
await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
expect(next).not.toHaveBeenCalled()
})
})
})

View File

@ -1,62 +1,67 @@
import crypto from 'crypto' import crypto from 'crypto'
import { NextFunction, Request, Response } from 'express' import {NextFunction, Request, Response} from 'express'
import { config } from '../config' import {config} from '../config'
const isValidToken = (token: string, apiKey: string): boolean => {
if (token.length !== apiKey.length) {
return false
}
const tokenBuf = Buffer.from(token)
const keyBuf = Buffer.from(apiKey)
return crypto.timingSafeEqual(tokenBuf, keyBuf)
}
export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => { export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => {
const auth = req.header('Authorization') || '' const auth = req.header('authorization') || ''
const xApiKey = req.header('x-api-key') || '' const xApiKey = req.header('x-api-key') || ''
// Fast rejection if neither credential header provided // Fast rejection if neither credential header provided
if (!auth && !xApiKey) { if (!auth && !xApiKey) {
return res.status(401).json({ error: 'Unauthorized: missing credentials' }) return res.status(401).json({error: 'Unauthorized: missing credentials'})
} }
let token: string | undefined const {apiKey} = await config.get()
// Prefer Bearer if wellformed if (!apiKey) {
return res.status(403).json({error: 'Forbidden'})
}
// Check API key first (priority)
if (xApiKey) {
const trimmedApiKey = xApiKey.trim()
if (!trimmedApiKey) {
return res.status(401).json({error: 'Unauthorized: empty x-api-key'})
}
if (isValidToken(trimmedApiKey, apiKey)) {
return next()
} else {
return res.status(403).json({error: 'Forbidden'})
}
}
// Fallback to Bearer token
if (auth) { if (auth) {
const trimmed = auth.trim() const trimmed = auth.trim()
const bearerPrefix = /^Bearer\s+/i const bearerPrefix = /^Bearer\s+/i
if (bearerPrefix.test(trimmed)) {
const candidate = trimmed.replace(bearerPrefix, '').trim() if (!bearerPrefix.test(trimmed)) {
if (!candidate) { return res.status(401).json({error: 'Unauthorized: invalid authorization format'})
return res.status(401).json({ error: 'Unauthorized: empty bearer token' }) }
}
token = candidate const token = trimmed.replace(bearerPrefix, '').trim()
if (!token) {
return res.status(401).json({error: 'Unauthorized: empty bearer token'})
}
if (isValidToken(token, apiKey)) {
return next()
} else {
return res.status(403).json({error: 'Forbidden'})
} }
} }
// Fallback to x-api-key if token still not resolved return res.status(401).json({error: 'Unauthorized: invalid credentials format'})
if (!token && xApiKey) {
if (!xApiKey.trim()) {
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
}
token = xApiKey.trim()
}
if (!token) {
// At this point we had at least one header, but none yielded a usable token
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
}
const { apiKey } = await config.get()
if (!apiKey) {
// If server not configured, treat as forbidden (or could be 500). Choose 403 to avoid leaking config state.
return res.status(403).json({ error: 'Forbidden' })
}
// Timing-safe compare when lengths match, else immediate forbidden
if (token.length !== apiKey.length) {
return res.status(403).json({ error: 'Forbidden' })
}
const tokenBuf = Buffer.from(token)
const keyBuf = Buffer.from(apiKey)
if (!crypto.timingSafeEqual(tokenBuf, keyBuf)) {
return res.status(403).json({ error: 'Forbidden' })
}
return next()
} }

View File

@ -1,13 +1,185 @@
import { MessageCreateParams } from '@anthropic-ai/sdk/resources' import {MessageCreateParams} from '@anthropic-ai/sdk/resources'
import express, { Request, Response } from 'express' import {loggerService} from '@logger'
import express, {Request, Response} from 'express'
import { loggerService } from '../../services/LoggerService' import {messagesService} from '../services/messages'
import { messagesService } from '../services/messages' import {getProviderById, validateModelId} from '../utils'
import { validateModelId } from '../utils'
const logger = loggerService.withContext('ApiServerMessagesRoutes') const logger = loggerService.withContext('ApiServerMessagesRoutes')
const router = express.Router() const router = express.Router()
const providerRouter = express.Router({mergeParams: true})
// Helper functions for shared logic
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
logger.info('Validating request body', { body: req.body })
const request: MessageCreateParams = req.body
if (!request) {
return {
valid: false,
error: {
type: 'error',
error: {
type: 'invalid_request_error',
message: 'Request body is required'
}
}
}
}
return {valid: true}
}
async function handleStreamingResponse(
res: Response,
request: MessageCreateParams,
provider: any,
messagesService: any,
logger: any
): Promise<void> {
res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
res.setHeader('Cache-Control', 'no-cache, no-transform')
res.setHeader('Connection', 'keep-alive')
res.setHeader('X-Accel-Buffering', 'no')
res.flushHeaders()
try {
for await (const chunk of messagesService.processStreamingMessage(request, provider)) {
res.write(`data: ${JSON.stringify(chunk)}\n\n`)
}
res.write('data: [DONE]\n\n')
} catch (streamError: any) {
logger.error('Stream error:', streamError)
res.write(
`data: ${JSON.stringify({
type: 'error',
error: {
type: 'api_error',
message: 'Stream processing error'
}
})}\n\n`
)
} finally {
res.end()
}
}
function handleErrorResponse(res: Response, error: any, logger: any): Response {
logger.error('Message processing error:', error)
let statusCode = 500
let errorType = 'api_error'
let errorMessage = 'Internal server error'
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
const anthropicError = error?.error
if (anthropicStatus) {
statusCode = anthropicStatus
}
if (anthropicError?.type) {
errorType = anthropicError.type
}
if (anthropicError?.message) {
errorMessage = anthropicError.message
} else if (error instanceof Error && error.message) {
errorMessage = error.message
}
if (!anthropicStatus && error instanceof Error) {
if (error.message.includes('API key') || error.message.includes('authentication')) {
statusCode = 401
errorType = 'authentication_error'
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
statusCode = 429
errorType = 'rate_limit_error'
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
statusCode = 502
errorType = 'api_error'
} else if (error.message.includes('validation') || error.message.includes('invalid')) {
statusCode = 400
errorType = 'invalid_request_error'
}
}
return res.status(statusCode).json({
type: 'error',
error: {
type: errorType,
message: errorMessage,
requestId: error?.request_id
}
})
}
async function processMessageRequest(
req: Request,
res: Response,
provider: any,
modelId?: string
): Promise<Response | void> {
try {
const request: MessageCreateParams = req.body
// Use provided modelId or keep original model
if (modelId) {
request.model = modelId
}
logger.info('Processing message request:', {
provider: provider.id,
model: request.model,
messageCount: request.messages?.length || 0,
stream: request.stream,
max_tokens: request.max_tokens,
temperature: request.temperature
})
// Ensure provider is Anthropic type
if (provider.type !== 'anthropic') {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: `Invalid provider type '${provider.type}' for messages endpoint. Expected 'anthropic' provider.`
}
})
}
logger.info('Provider validation successful:', {
provider: provider.id,
providerType: provider.type,
modelId: request.model
})
// Validate request
const validation = messagesService.validateRequest(request)
if (!validation.isValid) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: validation.errors.join('; ')
}
})
}
// Handle streaming
if (request.stream) {
await handleStreamingResponse(res, request, provider, messagesService, logger)
return
}
// Handle non-streaming
const response = await messagesService.processMessage(request, provider)
return res.json(response)
} catch (error: any) {
return handleErrorResponse(res, error, logger)
}
}
/** /**
* @swagger * @swagger
@ -133,25 +305,20 @@ const router = express.Router()
* description: Internal server error * description: Internal server error
*/ */
router.post('/', async (req: Request, res: Response) => { router.post('/', async (req: Request, res: Response) => {
// Validate request body
const bodyValidation = await validateRequestBody(req)
if (!bodyValidation.valid) {
return res.status(400).json(bodyValidation.error)
}
try { try {
const request: MessageCreateParams = req.body const request: MessageCreateParams = req.body
if (!request) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'Request body is required'
}
})
}
logger.info('Anthropic message request:', { logger.info('Anthropic message request:', {
model: request.model, model: request.model,
messageCount: request.messages?.length || 0, messageCount: request.messages?.length || 0,
stream: request.stream, stream: request.stream,
max_tokens: request.max_tokens, max_tokens: request.max_tokens,
temperature: request.temperature
}) })
// Validate model ID and get provider // Validate model ID and get provider
@ -169,20 +336,7 @@ router.post('/', async (req: Request, res: Response) => {
} }
const provider = modelValidation.provider! const provider = modelValidation.provider!
// Ensure provider is Anthropic type
if (provider.type !== 'anthropic') {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: `Invalid provider type '${provider.type}' for messages endpoint. Expected 'anthropic' provider.`
}
})
}
const modelId = modelValidation.modelId! const modelId = modelValidation.modelId!
request.model = modelId
logger.info('Model validation successful:', { logger.info('Model validation successful:', {
provider: provider.id, provider: provider.id,
@ -191,100 +345,181 @@ router.post('/', async (req: Request, res: Response) => {
fullModelId: request.model fullModelId: request.model
}) })
// Validate request // Use shared processing function
const validation = messagesService.validateRequest(request) return await processMessageRequest(req, res, provider, modelId)
if (!validation.isValid) { } catch (error: any) {
return handleErrorResponse(res, error, logger)
}
})
/**
* @swagger
* /{provider_id}/v1/messages:
* post:
* summary: Create message with provider in path
* description: Create a message response using provider ID from URL path
* tags: [Messages]
* parameters:
* - in: path
* name: provider_id
* required: true
* schema:
* type: string
* description: Provider ID (e.g., "my-anthropic")
* example: "my-anthropic"
* requestBody:
* required: true
* content:
* application/json:
* schema:
* type: object
* required:
* - model
* - max_tokens
* - messages
* properties:
* model:
* type: string
* description: Model ID without provider prefix
* example: "claude-3-5-sonnet-20241022"
* max_tokens:
* type: integer
* minimum: 1
* description: Maximum number of tokens to generate
* example: 1024
* messages:
* type: array
* items:
* type: object
* properties:
* role:
* type: string
* enum: [user, assistant]
* content:
* oneOf:
* - type: string
* - type: array
* system:
* type: string
* description: System message
* temperature:
* type: number
* minimum: 0
* maximum: 1
* description: Sampling temperature
* top_p:
* type: number
* minimum: 0
* maximum: 1
* description: Nucleus sampling
* top_k:
* type: integer
* minimum: 0
* description: Top-k sampling
* stream:
* type: boolean
* description: Whether to stream the response
* tools:
* type: array
* description: Available tools for the model
* responses:
* 200:
* description: Message response
* content:
* application/json:
* schema:
* type: object
* properties:
* id:
* type: string
* type:
* type: string
* example: message
* role:
* type: string
* example: assistant
* content:
* type: array
* items:
* type: object
* model:
* type: string
* stop_reason:
* type: string
* stop_sequence:
* type: string
* usage:
* type: object
* properties:
* input_tokens:
* type: integer
* output_tokens:
* type: integer
* text/event-stream:
* schema:
* type: string
* description: Server-sent events stream (when stream=true)
* 400:
* description: Bad request
* 401:
* description: Unauthorized
* 429:
* description: Rate limit exceeded
* 500:
* description: Internal server error
*/
providerRouter.post('/', async (req: Request, res: Response) => {
// Validate request body
const bodyValidation = await validateRequestBody(req)
if (!bodyValidation.valid) {
return res.status(400).json(bodyValidation.error)
}
try {
const providerId = req.params.provider
const request: MessageCreateParams = req.body
if (!providerId) {
return res.status(400).json({ return res.status(400).json({
type: 'error', type: 'error',
error: { error: {
type: 'invalid_request_error', type: 'invalid_request_error',
message: validation.errors.join('; ') message: 'Provider ID is required in URL path'
} }
}) })
} }
// Handle streaming logger.info('Provider-specific message request:', {
if (request.stream) { providerId,
res.setHeader('Content-Type', 'text/event-stream; charset=utf-8') model: request.model,
res.setHeader('Cache-Control', 'no-cache, no-transform') messageCount: request.messages?.length || 0,
res.setHeader('Connection', 'keep-alive') stream: request.stream,
res.setHeader('X-Accel-Buffering', 'no') max_tokens: request.max_tokens
res.flushHeaders()
try {
for await (const chunk of messagesService.processStreamingMessage(request, provider)) {
res.write(`data: ${JSON.stringify(chunk)}\n\n`)
}
res.write('data: [DONE]\n\n')
} catch (streamError: any) {
logger.error('Stream error:', streamError)
res.write(
`data: ${JSON.stringify({
type: 'error',
error: {
type: 'api_error',
message: 'Stream processing error'
}
})}\n\n`
)
} finally {
res.end()
}
return
}
// Handle non-streaming
const response = await messagesService.processMessage(request, provider)
return res.json(response)
} catch (error: any) {
logger.error('Anthropic message error:', error)
let statusCode = 500
let errorType = 'api_error'
let errorMessage = 'Internal server error'
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
const anthropicError = error?.error
if (anthropicStatus) {
statusCode = anthropicStatus
}
if (anthropicError?.type) {
errorType = anthropicError.type
}
if (anthropicError?.message) {
errorMessage = anthropicError.message
} else if (error instanceof Error && error.message) {
errorMessage = error.message
}
if (!anthropicStatus && error instanceof Error) {
if (error.message.includes('API key') || error.message.includes('authentication')) {
statusCode = 401
errorType = 'authentication_error'
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
statusCode = 429
errorType = 'rate_limit_error'
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
statusCode = 502
errorType = 'api_error'
} else if (error.message.includes('validation') || error.message.includes('invalid')) {
statusCode = 400
errorType = 'invalid_request_error'
}
}
return res.status(statusCode).json({
type: 'error',
error: {
type: errorType,
message: errorMessage,
requestId: error?.request_id
}
}) })
// Get provider directly by ID from URL path
const provider = await getProviderById(providerId)
if (!provider) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: `Provider '${providerId}' not found or not enabled`
}
})
}
logger.info('Provider validation successful:', {
provider: provider.id,
providerType: provider.type,
modelId: request.model
})
// Use shared processing function (no modelId override needed)
return await processMessageRequest(req, res, provider)
} catch (error: any) {
return handleErrorResponse(res, error, logger)
} }
}) })
export { router as messagesRoutes } export {providerRouter as messagesProviderRoutes, router as messagesRoutes}

View File

@ -1,8 +1,10 @@
import Anthropic from '@anthropic-ai/sdk' import Anthropic from "@anthropic-ai/sdk";
import { Message, MessageCreateParams, RawMessageStreamEvent } from '@anthropic-ai/sdk/resources' import {Message, MessageCreateParams, RawMessageStreamEvent} from '@anthropic-ai/sdk/resources'
import { Provider } from '@types' import {loggerService} from '@logger'
import anthropicService from "@main/services/AnthropicService";
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import {Provider} from '@types'
import { loggerService } from '../../services/LoggerService'
const logger = loggerService.withContext('MessagesService') const logger = loggerService.withContext('MessagesService')
@ -35,6 +37,16 @@ export class MessagesService {
} }
} }
async getClient(provider: Provider): Promise<Anthropic> {
// Create Anthropic client for the provider
if (provider.authType === 'oauth') {
const oauthToken = await anthropicService.getValidAccessToken()
return getSdkClient(provider, oauthToken)
}
return getSdkClient(provider)
}
async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> { async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> {
logger.info('Processing Anthropic message request:', { logger.info('Processing Anthropic message request:', {
model: request.model, model: request.model,
@ -44,10 +56,7 @@ export class MessagesService {
}) })
// Create Anthropic client for the provider // Create Anthropic client for the provider
const client = new Anthropic({ const client = await this.getClient(provider)
baseURL: provider.apiHost,
apiKey: provider.apiKey
})
// Prepare request with the actual model ID // Prepare request with the actual model ID
const anthropicRequest: MessageCreateParams = { const anthropicRequest: MessageCreateParams = {
@ -55,6 +64,10 @@ export class MessagesService {
stream: false stream: false
} }
if (provider.authType === 'oauth') {
anthropicRequest.system = buildClaudeCodeSystemMessage(request.system || '')
}
logger.debug('Sending request to Anthropic provider:', { logger.debug('Sending request to Anthropic provider:', {
provider: provider.id, provider: provider.id,
apiHost: provider.apiHost apiHost: provider.apiHost
@ -66,7 +79,7 @@ export class MessagesService {
return response return response
} }
async *processStreamingMessage( async* processStreamingMessage(
request: MessageCreateParams, request: MessageCreateParams,
provider: Provider provider: Provider
): AsyncIterable<RawMessageStreamEvent> { ): AsyncIterable<RawMessageStreamEvent> {
@ -76,10 +89,7 @@ export class MessagesService {
}) })
// Create Anthropic client for the provider // Create Anthropic client for the provider
const client = new Anthropic({ const client = await this.getClient(provider)
baseURL: provider.apiHost,
apiKey: provider.apiKey
})
// Prepare streaming request // Prepare streaming request
const streamingRequest: MessageCreateParams = { const streamingRequest: MessageCreateParams = {
@ -87,6 +97,10 @@ export class MessagesService {
stream: true stream: true
} }
if (provider.authType === 'oauth') {
streamingRequest.system = buildClaudeCodeSystemMessage(request.system || '')
}
logger.debug('Sending streaming request to Anthropic provider:', { logger.debug('Sending streaming request to Anthropic provider:', {
provider: provider.id, provider: provider.id,
apiHost: provider.apiHost apiHost: provider.apiHost

View File

@ -204,6 +204,31 @@ export function transformModelToOpenAI(model: Model, providers: Provider[]): Api
} }
} }
export async function getProviderById(providerId: string): Promise<Provider | undefined> {
try {
if (!providerId || typeof providerId !== 'string') {
logger.warn(`Invalid provider ID parameter: ${providerId}`)
return undefined
}
const providers = await getAvailableProviders()
const provider = providers.find((p: Provider) => p.id === providerId)
if (!provider) {
logger.warn(
`Provider '${providerId}' not found or not enabled. Available providers: ${providers.map((p) => p.id).join(', ')}`
)
return undefined
}
logger.debug(`Found provider '${providerId}'`)
return provider
} catch (error: any) {
logger.error('Failed to get provider by ID:', error)
return undefined
}
}
export function validateProvider(provider: Provider): boolean { export function validateProvider(provider: Provider): boolean {
try { try {
if (!provider) { if (!provider) {

View File

@ -1,16 +1,11 @@
import { loggerService } from '@logger' import {loggerService} from '@logger'
import type { import type {AgentSessionMessageEntity, CreateSessionMessageRequest, GetAgentSessionResponse, ListOptions} from '@types'
AgentSessionMessageEntity, import {TextStreamPart} from 'ai'
CreateSessionMessageRequest, import {desc, eq} from 'drizzle-orm'
GetAgentSessionResponse,
ListOptions
} from '@types'
import { ModelMessage, TextStreamPart } from 'ai'
import { desc, eq } from 'drizzle-orm'
import { BaseService } from '../BaseService' import {BaseService} from '../BaseService'
import { sessionMessagesTable } from '../database/schema' import {sessionMessagesTable} from '../database/schema'
import { AgentStreamEvent } from '../interfaces/AgentStreamInterface' import {AgentStreamEvent} from '../interfaces/AgentStreamInterface'
import ClaudeCodeService from './claudecode' import ClaudeCodeService from './claudecode'
const logger = loggerService.withContext('SessionMessageService') const logger = loggerService.withContext('SessionMessageService')
@ -34,7 +29,7 @@ function serializeError(error: unknown): { message: string; name?: string; stack
} }
if (typeof error === 'string') { if (typeof error === 'string') {
return { message: error } return {message: error}
} }
return { return {
@ -104,7 +99,7 @@ export class SessionMessageService extends BaseService {
this.ensureInitialized() this.ensureInitialized()
const result = await this.database const result = await this.database
.select({ id: sessionMessagesTable.id }) .select({id: sessionMessagesTable.id})
.from(sessionMessagesTable) .from(sessionMessagesTable)
.where(eq(sessionMessagesTable.id, id)) .where(eq(sessionMessagesTable.id, id))
.limit(1) .limit(1)
@ -134,7 +129,7 @@ export class SessionMessageService extends BaseService {
const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[] const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[]
return { messages } return {messages}
} }
async createSessionMessage( async createSessionMessage(
@ -153,11 +148,11 @@ export class SessionMessageService extends BaseService {
abortController: AbortController abortController: AbortController
): Promise<SessionStreamResult> { ): Promise<SessionStreamResult> {
const agentSessionId = await this.getLastAgentSessionId(session.id) const agentSessionId = await this.getLastAgentSessionId(session.id)
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId }) logger.debug('Session Message stream message data:', {message: req, session_id: agentSessionId})
if (session.agent_type !== 'claude-code') { if (session.agent_type !== 'claude-code') {
// TODO: Implement support for other agent types // TODO: Implement support for other agent types
logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type }) logger.error('Unsupported agent type for streaming:', {agent_type: session.agent_type})
throw new Error('Unsupported agent type for streaming') throw new Error('Unsupported agent type for streaming')
} }
@ -248,7 +243,7 @@ export class SessionMessageService extends BaseService {
} }
}) })
return { stream, completion } return {stream, completion}
} }
private async getLastAgentSessionId(sessionId: string): Promise<string> { private async getLastAgentSessionId(sessionId: string): Promise<string> {
@ -256,7 +251,7 @@ export class SessionMessageService extends BaseService {
try { try {
const result = await this.database const result = await this.database
.select({ agent_session_id: sessionMessagesTable.agent_session_id }) .select({agent_session_id: sessionMessagesTable.agent_session_id})
.from(sessionMessagesTable) .from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId)) .where(eq(sessionMessagesTable.session_id, sessionId))
.orderBy(desc(sessionMessagesTable.created_at)) .orderBy(desc(sessionMessagesTable.created_at))
@ -275,7 +270,7 @@ export class SessionMessageService extends BaseService {
private deserializeSessionMessage(data: any): AgentSessionMessageEntity { private deserializeSessionMessage(data: any): AgentSessionMessageEntity {
if (!data) return data if (!data) return data
const deserialized = { ...data } const deserialized = {...data}
// Parse content JSON // Parse content JSON
if (deserialized.content && typeof deserialized.content === 'string') { if (deserialized.content && typeof deserialized.content === 'string') {

View File

@ -69,17 +69,9 @@ class ClaudeCodeService implements AgentServiceInterface {
// process.env.ANTHROPIC_BASE_URL = `http://${apiConfig.host}:${apiConfig.port}` // process.env.ANTHROPIC_BASE_URL = `http://${apiConfig.host}:${apiConfig.port}`
const env = { const env = {
...process.env, ...process.env,
ELECTRON_RUN_AS_NODE: '1', ANTHROPIC_API_KEY: apiConfig.apiKey,
} ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
ELECTRON_RUN_AS_NODE: '1'
if (modelInfo.provider.authType === 'oauth') {
// TODO: support claude code max oauth
// env['ANTHROPIC_AUTH_TOKEN'] = await anthropicService.getValidAccessToken()
// env['ANTHROPIC_BASE_URL'] = 'https://api.anthropic.com'
} else {
env['ANTHROPIC_AUTH_TOKEN'] = modelInfo.provider.apiKey
env['ANTHROPIC_API_KEY'] = modelInfo.provider.apiKey
env['ANTHROPIC_BASE_URL'] = modelInfo.provider.apiHost
} }
// Build SDK options from parameters // Build SDK options from parameters
@ -121,7 +113,7 @@ class ClaudeCodeService implements AgentServiceInterface {
options.resume = lastAgentSessionId options.resume = lastAgentSessionId
} }
logger.info('Starting Claude Code SDK query', { logger.silly('Starting Claude Code SDK query', {
prompt, prompt,
options options
}) })

View File

@ -150,8 +150,7 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }
break break
case 'content_block_start': case 'content_block_start':
const contentBlockType = event.content_block.type switch (event.content_block.type) {
switch (contentBlockType) {
case 'text': { case 'text': {
contentBlockState.set(blockKey, { type: 'text' }) contentBlockState.set(blockKey, { type: 'text' })
chunks.push({ chunks.push({

View File

@ -22,14 +22,14 @@ import {
WebSearchToolResultBlockParam, WebSearchToolResultBlockParam,
WebSearchToolResultError WebSearchToolResultError
} from '@anthropic-ai/sdk/resources/messages' } from '@anthropic-ai/sdk/resources/messages'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages' import {MessageStream} from '@anthropic-ai/sdk/resources/messages/messages'
import AnthropicVertex from '@anthropic-ai/vertex-sdk' import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger' import {loggerService} from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import {DEFAULT_MAX_TOKENS} from '@renderer/config/constant'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models' import {findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel} from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService' import {getAssistantSettings} from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager' import FileManager from '@renderer/services/FileManager'
import { estimateTextTokens } from '@renderer/services/TokenService' import {estimateTextTokens} from '@renderer/services/TokenService'
import { import {
Assistant, Assistant,
EFFORT_RATIO, EFFORT_RATIO,
@ -53,26 +53,17 @@ import {
ThinkingDeltaChunk, ThinkingDeltaChunk,
ThinkingStartChunk ThinkingStartChunk
} from '@renderer/types/chunk' } from '@renderer/types/chunk'
import { type Message } from '@renderer/types/newMessage' import {type Message} from '@renderer/types/newMessage'
import { import {AnthropicSdkMessageParam, AnthropicSdkParams, AnthropicSdkRawChunk, AnthropicSdkRawOutput} from '@renderer/types/sdk'
AnthropicSdkMessageParam, import {addImageFileToContents} from '@renderer/utils/formats'
AnthropicSdkParams, import {anthropicToolUseToMcpTool, isSupportedToolUse, mcpToolCallResponseToAnthropicMessage, mcpToolsToAnthropicTools} from '@renderer/utils/mcp-tools'
AnthropicSdkRawChunk, import {findFileBlocks, findImageBlocks} from '@renderer/utils/messageUtils/find'
AnthropicSdkRawOutput import {buildClaudeCodeSystemMessage, getSdkClient} from "@shared/anthropic";
} from '@renderer/types/sdk' import {t} from 'i18next'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
anthropicToolUseToMcpTool,
isSupportedToolUse,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { t } from 'i18next'
import { GenericChunk } from '../../middleware/schemas' import {GenericChunk} from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient' import {BaseApiClient} from '../BaseApiClient'
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types' import {AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer} from '../types'
const logger = loggerService.withContext('AnthropicAPIClient') const logger = loggerService.withContext('AnthropicAPIClient')
@ -86,8 +77,8 @@ export class AnthropicAPIClient extends BaseApiClient<
ToolUnion ToolUnion
> { > {
oauthToken: string | undefined = undefined oauthToken: string | undefined = undefined
isOAuthMode: boolean = false
sdkInstance: Anthropic | AnthropicVertex | undefined = undefined sdkInstance: Anthropic | AnthropicVertex | undefined = undefined
constructor(provider: Provider) { constructor(provider: Provider) {
super(provider) super(provider)
} }
@ -96,84 +87,25 @@ export class AnthropicAPIClient extends BaseApiClient<
if (this.sdkInstance) { if (this.sdkInstance) {
return this.sdkInstance return this.sdkInstance
} }
if (this.provider.authType === 'oauth') { if (this.provider.authType === 'oauth') {
if (!this.oauthToken) { this.oauthToken = await window.api.anthropic_oauth.getAccessToken()
throw new Error('OAuth token is not available')
}
this.sdkInstance = new Anthropic({
authToken: this.oauthToken,
baseURL: 'https://api.anthropic.com',
dangerouslyAllowBrowser: true,
defaultHeaders: {
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
'anthropic-beta': 'oauth-2025-04-20'
// ...this.provider.extra_headers
}
})
} else {
this.sdkInstance = new Anthropic({
apiKey: this.apiKey,
baseURL: this.getBaseURL(),
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'output-128k-2025-02-19',
...this.provider.extra_headers
}
})
} }
this.sdkInstance = getSdkClient(this.provider, this.oauthToken)
return this.sdkInstance return this.sdkInstance
} }
private buildClaudeCodeSystemMessage(system?: string | Array<TextBlockParam>): string | Array<TextBlockParam> {
const defaultClaudeCodeSystem = `You are Claude Code, Anthropic's official CLI for Claude.`
if (!system) {
return defaultClaudeCodeSystem
}
if (typeof system === 'string') {
if (system.trim() === defaultClaudeCodeSystem) {
return system
}
return [
{
type: 'text',
text: defaultClaudeCodeSystem
},
{
type: 'text',
text: system
}
]
}
if (system[0].text.trim() != defaultClaudeCodeSystem) {
system.unshift({
type: 'text',
text: defaultClaudeCodeSystem
})
}
return system
}
override async createCompletions( override async createCompletions(
payload: AnthropicSdkParams, payload: AnthropicSdkParams,
options?: Anthropic.RequestOptions options?: Anthropic.RequestOptions
): Promise<AnthropicSdkRawOutput> { ): Promise<AnthropicSdkRawOutput> {
if (this.provider.authType === 'oauth') { if (this.provider.authType === 'oauth') {
this.oauthToken = await window.api.anthropic_oauth.getAccessToken() payload.system = buildClaudeCodeSystemMessage(payload.system)
this.isOAuthMode = true
logger.info('[Anthropic Provider] Using OAuth token for authentication')
payload.system = this.buildClaudeCodeSystemMessage(payload.system)
} }
const sdk = (await this.getSdkInstance()) as Anthropic const sdk = (await this.getSdkInstance()) as Anthropic
if (payload.stream) { if (payload.stream) {
return sdk.messages.stream(payload, options) return sdk.messages.stream(payload, options)
} }
return await sdk.messages.create(payload, options) return sdk.messages.create(payload, options);
} }
// @ts-ignore sdk未提供 // @ts-ignore sdk未提供
@ -183,14 +115,8 @@ export class AnthropicAPIClient extends BaseApiClient<
} }
override async listModels(): Promise<Anthropic.ModelInfo[]> { override async listModels(): Promise<Anthropic.ModelInfo[]> {
if (this.provider.authType === 'oauth') {
this.oauthToken = await window.api.anthropic_oauth.getAccessToken()
this.isOAuthMode = true
logger.info('[Anthropic Provider] Using OAuth token for authentication')
}
const sdk = (await this.getSdkInstance()) as Anthropic const sdk = (await this.getSdkInstance()) as Anthropic
const response = await sdk.models.list() const response = await sdk.models.list()
return response.data return response.data
} }
@ -223,7 +149,7 @@ export class AnthropicAPIClient extends BaseApiClient<
if (!isReasoningModel(model)) { if (!isReasoningModel(model)) {
return undefined return undefined
} }
const { maxTokens } = getAssistantSettings(assistant) const {maxTokens} = getAssistantSettings(assistant)
const reasoningEffort = assistant?.settings?.reasoning_effort const reasoningEffort = assistant?.settings?.reasoning_effort
@ -240,7 +166,7 @@ export class AnthropicAPIClient extends BaseApiClient<
Math.floor( Math.floor(
Math.min( Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!, findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
) )
) )
@ -262,7 +188,7 @@ export class AnthropicAPIClient extends BaseApiClient<
* @returns The message parameter * @returns The message parameter
*/ */
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> { public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
const { textContent, imageContents } = await this.getMessageContent(message) const {textContent, imageContents} = await this.getMessageContent(message)
const parts: MessageParam['content'] = [ const parts: MessageParam['content'] = [
{ {
@ -285,7 +211,7 @@ export class AnthropicAPIClient extends BaseApiClient<
} }
}) })
} else { } else {
logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime }) logger.warn('Unsupported image type, ignored.', {mime: base64Data.mime})
} }
} }
} }
@ -310,7 +236,7 @@ export class AnthropicAPIClient extends BaseApiClient<
// Get and process file blocks // Get and process file blocks
const fileBlocks = findFileBlocks(message) const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) { for (const fileBlock of fileBlocks) {
const { file } = fileBlock const {file} = fileBlock
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) { if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
const base64Data = await FileManager.readBase64File(file) const base64Data = await FileManager.readBase64File(file)
@ -538,25 +464,25 @@ export class AnthropicAPIClient extends BaseApiClient<
messages: AnthropicSdkMessageParam[] messages: AnthropicSdkMessageParam[]
metadata: Record<string, any> metadata: Record<string, any>
}> => { }> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest const {messages, mcpTools, maxTokens, streamOutput, enableWebSearch} = coreRequest
// 1. 处理系统消息 // 1. 处理系统消息
const systemPrompt = assistant.prompt const systemPrompt = assistant.prompt
// 2. 设置工具 // 2. 设置工具
const { tools } = this.setupToolsConfig({ const {tools} = this.setupToolsConfig({
mcpTools: mcpTools, mcpTools: mcpTools,
model, model,
enableToolUse: isSupportedToolUse(assistant) enableToolUse: isSupportedToolUse(assistant)
}) })
const systemMessage: TextBlockParam | undefined = systemPrompt const systemMessage: TextBlockParam | undefined = systemPrompt
? { type: 'text', text: systemPrompt } ? {type: 'text', text: systemPrompt}
: undefined : undefined
// 3. 处理用户消息 // 3. 处理用户消息
const sdkMessages: AnthropicSdkMessageParam[] = [] const sdkMessages: AnthropicSdkMessageParam[] = []
if (typeof messages === 'string') { if (typeof messages === 'string') {
sdkMessages.push({ role: 'user', content: messages }) sdkMessages.push({role: 'user', content: messages})
} else { } else {
const processedMessages = addImageFileToContents(messages) const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) { for (const message of processedMessages) {
@ -590,7 +516,7 @@ export class AnthropicAPIClient extends BaseApiClient<
} }
const timeout = this.getTimeout(model) const timeout = this.getTimeout(model)
return { payload: commonParams, messages: sdkMessages, metadata: { timeout } } return {payload: commonParams, messages: sdkMessages, metadata: {timeout}}
} }
} }
} }
@ -605,7 +531,7 @@ export class AnthropicAPIClient extends BaseApiClient<
try { try {
rawChunk = JSON.parse(rawChunk) rawChunk = JSON.parse(rawChunk)
} catch (error) { } catch (error) {
logger.error('invalid chunk', { rawChunk, error }) logger.error('invalid chunk', {rawChunk, error})
throw new Error(t('error.chat.chunk.non_json')) throw new Error(t('error.chat.chunk.non_json'))
} }
} }

View File

@ -77,3 +77,20 @@ Content-Type: application/json
} }
] ]
} }
### Anthropic Chat Message with streaming
POST {{host}}/anthropic/v1/messages
Authorization: Bearer {{token}}
Content-Type: application/json
{
"model": "claude-sonnet-4-20250514",
"stream": true,
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Explain the theory of relativity in simple terms."
}
]
}