💄 style: format code with yarn format

This commit is contained in:
Vaayne 2025-09-21 16:44:54 +08:00
parent b869869e26
commit a09c52424f
19 changed files with 159 additions and 154 deletions

View File

@ -8,9 +8,9 @@
* This shared module can be used by both main and renderer processes. * This shared module can be used by both main and renderer processes.
*/ */
import Anthropic from "@anthropic-ai/sdk"; import Anthropic from '@anthropic-ai/sdk'
import {TextBlockParam} from "@anthropic-ai/sdk/resources"; import { TextBlockParam } from '@anthropic-ai/sdk/resources'
import {Provider} from "@types"; import { Provider } from '@types'
/** /**
* Creates and configures an Anthropic SDK client based on the provider configuration. * Creates and configures an Anthropic SDK client based on the provider configuration.
@ -54,7 +54,8 @@ export function getSdkClient(provider: Provider, oauthToken?: string | null): An
defaultHeaders: { defaultHeaders: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'anthropic-version': '2023-06-01', 'anthropic-version': '2023-06-01',
'anthropic-beta': 'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14', 'anthropic-beta':
'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14',
'anthropic-dangerous-direct-browser-access': 'true', 'anthropic-dangerous-direct-browser-access': 'true',
'user-agent': 'claude-cli/1.0.118 (external, sdk-ts)', 'user-agent': 'claude-cli/1.0.118 (external, sdk-ts)',
'x-app': 'cli', 'x-app': 'cli',

View File

@ -336,4 +336,4 @@
"internal": { "internal": {
"indexes": {} "indexes": {}
} }
} }

View File

@ -1,16 +1,16 @@
import {loggerService} from '@main/services/LoggerService' import { loggerService } from '@main/services/LoggerService'
import cors from 'cors' import cors from 'cors'
import express from 'express' import express from 'express'
import {v4 as uuidv4} from 'uuid' import { v4 as uuidv4 } from 'uuid'
import {authMiddleware} from './middleware/auth' import { authMiddleware } from './middleware/auth'
import {errorHandler} from './middleware/error' import { errorHandler } from './middleware/error'
import {setupOpenAPIDocumentation} from './middleware/openapi' import { setupOpenAPIDocumentation } from './middleware/openapi'
import {agentsRoutes} from './routes/agents' import { agentsRoutes } from './routes/agents'
import {chatRoutes} from './routes/chat' import { chatRoutes } from './routes/chat'
import {mcpRoutes} from './routes/mcp' import { mcpRoutes } from './routes/mcp'
import {messagesProviderRoutes, messagesRoutes} from './routes/messages' import { messagesProviderRoutes, messagesRoutes } from './routes/messages'
import {modelsRoutes} from './routes/models' import { modelsRoutes } from './routes/models'
const logger = loggerService.withContext('ApiServer') const logger = loggerService.withContext('ApiServer')
@ -109,7 +109,7 @@ app.get('/', (_req, res) => {
}) })
// Provider-specific API routes with auth (must be before /v1 to avoid conflicts) // Provider-specific API routes with auth (must be before /v1 to avoid conflicts)
const providerRouter = express.Router({mergeParams: true}) const providerRouter = express.Router({ mergeParams: true })
providerRouter.use(authMiddleware) providerRouter.use(authMiddleware)
providerRouter.use(express.json()) providerRouter.use(express.json())
// Mount provider-specific messages route // Mount provider-specific messages route
@ -128,11 +128,10 @@ apiRouter.use('/models', modelsRoutes)
apiRouter.use('/agents', agentsRoutes) apiRouter.use('/agents', agentsRoutes)
app.use('/v1', apiRouter) app.use('/v1', apiRouter)
// Setup OpenAPI documentation // Setup OpenAPI documentation
setupOpenAPIDocumentation(app) setupOpenAPIDocumentation(app)
// Error handling (must be last) // Error handling (must be last)
app.use(errorHandler) app.use(errorHandler)
export {app} export { app }

View File

@ -1,8 +1,8 @@
import type {NextFunction, Request, Response} from 'express' import type { NextFunction, Request, Response } from 'express'
import {beforeEach, describe, expect, it, vi} from 'vitest' import { beforeEach, describe, expect, it, vi } from 'vitest'
import {config} from '../../config' import { config } from '../../config'
import {authMiddleware} from '../auth' import { authMiddleware } from '../auth'
// Mock the config module // Mock the config module
vi.mock('../../config', () => ({ vi.mock('../../config', () => ({
@ -31,7 +31,7 @@ describe('authMiddleware', () => {
beforeEach(() => { beforeEach(() => {
jsonMock = vi.fn() jsonMock = vi.fn()
statusMock = vi.fn(() => ({json: jsonMock})) statusMock = vi.fn(() => ({ json: jsonMock }))
req = { req = {
header: vi.fn() header: vi.fn()
@ -51,7 +51,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401) expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
@ -65,7 +65,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401) expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
}) })
@ -77,12 +77,12 @@ describe('authMiddleware', () => {
return '' return ''
}) })
mockConfig.get.mockResolvedValue({apiKey: ''}) mockConfig.get.mockResolvedValue({ apiKey: '' })
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403) expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
@ -92,12 +92,12 @@ describe('authMiddleware', () => {
return '' return ''
}) })
mockConfig.get.mockResolvedValue({apiKey: null}) mockConfig.get.mockResolvedValue({ apiKey: null })
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403) expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
}) })
@ -106,7 +106,7 @@ describe('authMiddleware', () => {
const validApiKey = 'valid-api-key-123' const validApiKey = 'valid-api-key-123'
beforeEach(() => { beforeEach(() => {
mockConfig.get.mockResolvedValue({apiKey: validApiKey}) mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
}) })
it('should authenticate successfully with valid API key', async () => { it('should authenticate successfully with valid API key', async () => {
@ -130,7 +130,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403) expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
@ -143,7 +143,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401) expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: empty x-api-key'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: empty x-api-key' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
@ -182,7 +182,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403) expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
}) })
@ -191,7 +191,7 @@ describe('authMiddleware', () => {
const validApiKey = 'valid-api-key-123' const validApiKey = 'valid-api-key-123'
beforeEach(() => { beforeEach(() => {
mockConfig.get.mockResolvedValue({apiKey: validApiKey}) mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
}) })
it('should authenticate successfully with valid Bearer token when no API key', async () => { it('should authenticate successfully with valid Bearer token when no API key', async () => {
@ -215,7 +215,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403) expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
@ -228,7 +228,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401) expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
@ -241,20 +241,20 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401) expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
it('should handle Bearer token with only trailing spaces (edge case)', async () => { it('should handle Bearer token with only trailing spaces (edge case)', async () => {
;(req.header as any).mockImplementation((header: string) => { ;(req.header as any).mockImplementation((header: string) => {
if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check
return '' return ''
}) })
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401) expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
@ -287,7 +287,7 @@ describe('authMiddleware', () => {
const validApiKey = 'valid-api-key-123' const validApiKey = 'valid-api-key-123'
beforeEach(() => { beforeEach(() => {
mockConfig.get.mockResolvedValue({apiKey: validApiKey}) mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
}) })
it('should handle config.get() rejection', async () => { it('should handle config.get() rejection', async () => {
@ -310,7 +310,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403) expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
@ -323,7 +323,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(401) expect(statusMock).toHaveBeenCalledWith(401)
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
}) })
@ -332,7 +332,7 @@ describe('authMiddleware', () => {
const validApiKey = 'valid-api-key-123' const validApiKey = 'valid-api-key-123'
beforeEach(() => { beforeEach(() => {
mockConfig.get.mockResolvedValue({apiKey: validApiKey}) mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
}) })
it('should handle similar length but different API keys securely', async () => { it('should handle similar length but different API keys securely', async () => {
@ -346,7 +346,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403) expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
@ -361,7 +361,7 @@ describe('authMiddleware', () => {
await authMiddleware(req as Request, res as Response, next) await authMiddleware(req as Request, res as Response, next)
expect(statusMock).toHaveBeenCalledWith(403) expect(statusMock).toHaveBeenCalledWith(403)
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'}) expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
expect(next).not.toHaveBeenCalled() expect(next).not.toHaveBeenCalled()
}) })
}) })

View File

@ -1,8 +1,7 @@
import crypto from 'crypto' import crypto from 'crypto'
import {NextFunction, Request, Response} from 'express' import { NextFunction, Request, Response } from 'express'
import {config} from '../config'
import { config } from '../config'
const isValidToken = (token: string, apiKey: string): boolean => { const isValidToken = (token: string, apiKey: string): boolean => {
if (token.length !== apiKey.length) { if (token.length !== apiKey.length) {
@ -19,26 +18,26 @@ export const authMiddleware = async (req: Request, res: Response, next: NextFunc
// Fast rejection if neither credential header provided // Fast rejection if neither credential header provided
if (!auth && !xApiKey) { if (!auth && !xApiKey) {
return res.status(401).json({error: 'Unauthorized: missing credentials'}) return res.status(401).json({ error: 'Unauthorized: missing credentials' })
} }
const {apiKey} = await config.get() const { apiKey } = await config.get()
if (!apiKey) { if (!apiKey) {
return res.status(403).json({error: 'Forbidden'}) return res.status(403).json({ error: 'Forbidden' })
} }
// Check API key first (priority) // Check API key first (priority)
if (xApiKey) { if (xApiKey) {
const trimmedApiKey = xApiKey.trim() const trimmedApiKey = xApiKey.trim()
if (!trimmedApiKey) { if (!trimmedApiKey) {
return res.status(401).json({error: 'Unauthorized: empty x-api-key'}) return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
} }
if (isValidToken(trimmedApiKey, apiKey)) { if (isValidToken(trimmedApiKey, apiKey)) {
return next() return next()
} else { } else {
return res.status(403).json({error: 'Forbidden'}) return res.status(403).json({ error: 'Forbidden' })
} }
} }
@ -48,20 +47,20 @@ export const authMiddleware = async (req: Request, res: Response, next: NextFunc
const bearerPrefix = /^Bearer\s+/i const bearerPrefix = /^Bearer\s+/i
if (!bearerPrefix.test(trimmed)) { if (!bearerPrefix.test(trimmed)) {
return res.status(401).json({error: 'Unauthorized: invalid authorization format'}) return res.status(401).json({ error: 'Unauthorized: invalid authorization format' })
} }
const token = trimmed.replace(bearerPrefix, '').trim() const token = trimmed.replace(bearerPrefix, '').trim()
if (!token) { if (!token) {
return res.status(401).json({error: 'Unauthorized: empty bearer token'}) return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
} }
if (isValidToken(token, apiKey)) { if (isValidToken(token, apiKey)) {
return next() return next()
} else { } else {
return res.status(403).json({error: 'Forbidden'}) return res.status(403).json({ error: 'Forbidden' })
} }
} }
return res.status(401).json({error: 'Unauthorized: invalid credentials format'}) return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
} }

View File

@ -1,6 +1,6 @@
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { AgentModelValidationError, agentService } from '@main/services/agents' import { AgentModelValidationError, agentService } from '@main/services/agents'
import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types' import { ListAgentsResponse, type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
import { Request, Response } from 'express' import { Request, Response } from 'express'
import type { ValidationRequest } from '../validators/zodValidator' import type { ValidationRequest } from '../validators/zodValidator'

View File

@ -1,9 +1,5 @@
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { import { AgentModelValidationError, sessionMessageService, sessionService } from '@main/services/agents'
AgentModelValidationError,
sessionMessageService,
sessionService
} from '@main/services/agents'
import { import {
CreateSessionResponse, CreateSessionResponse,
ListAgentSessionsResponse, ListAgentSessionsResponse,

View File

@ -1,4 +1,4 @@
import { NextFunction,Request, Response } from 'express' import { NextFunction, Request, Response } from 'express'
import { ZodError, ZodType } from 'zod' import { ZodError, ZodType } from 'zod'
export interface ValidationRequest extends Request { export interface ValidationRequest extends Request {
@ -35,7 +35,7 @@ export const createZodValidator = (config: ZodValidationConfig) => {
type: 'field', type: 'field',
value: err.input, value: err.input,
msg: err.message, msg: err.message,
path: err.path.map(p => String(p)).join('.'), path: err.path.map((p) => String(p)).join('.'),
location: getLocationFromPath(err.path, config) location: getLocationFromPath(err.path, config)
})) }))
@ -65,4 +65,4 @@ function getLocationFromPath(path: (string | number | symbol)[], config: ZodVali
if (config.params && path.length > 0) return 'params' if (config.params && path.length > 0) return 'params'
if (config.query && path.length > 0) return 'query' if (config.query && path.length > 0) return 'query'
return 'unknown' return 'unknown'
} }

View File

@ -1,14 +1,14 @@
import {MessageCreateParams} from '@anthropic-ai/sdk/resources' import { MessageCreateParams } from '@anthropic-ai/sdk/resources'
import {loggerService} from '@logger' import { loggerService } from '@logger'
import express, {Request, Response} from 'express' import express, { Request, Response } from 'express'
import {messagesService} from '../services/messages' import { messagesService } from '../services/messages'
import {getProviderById, validateModelId} from '../utils' import { getProviderById, validateModelId } from '../utils'
const logger = loggerService.withContext('ApiServerMessagesRoutes') const logger = loggerService.withContext('ApiServerMessagesRoutes')
const router = express.Router() const router = express.Router()
const providerRouter = express.Router({mergeParams: true}) const providerRouter = express.Router({ mergeParams: true })
// Helper functions for shared logic // Helper functions for shared logic
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> { async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
@ -28,7 +28,7 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
} }
} }
return {valid: true} return { valid: true }
} }
async function handleStreamingResponse( async function handleStreamingResponse(
@ -318,7 +318,7 @@ router.post('/', async (req: Request, res: Response) => {
model: request.model, model: request.model,
messageCount: request.messages?.length || 0, messageCount: request.messages?.length || 0,
stream: request.stream, stream: request.stream,
max_tokens: request.max_tokens, max_tokens: request.max_tokens
}) })
// Validate model ID and get provider // Validate model ID and get provider
@ -522,4 +522,4 @@ providerRouter.post('/', async (req: Request, res: Response) => {
} }
}) })
export {providerRouter as messagesProviderRoutes, router as messagesRoutes} export { providerRouter as messagesProviderRoutes, router as messagesRoutes }

View File

@ -38,9 +38,10 @@ export type PrepareRequestResult =
} }
export class ChatCompletionService { export class ChatCompletionService {
async resolveProviderContext(model: string): Promise< async resolveProviderContext(
| { ok: false; error: ModelValidationError } model: string
| { ok: true; provider: Provider; modelId: string; client: OpenAI } ): Promise<
{ ok: false; error: ModelValidationError } | { ok: true; provider: Provider; modelId: string; client: OpenAI }
> { > {
const modelValidation = await validateModelId(model) const modelValidation = await validateModelId(model)
if (!modelValidation.valid) { if (!modelValidation.valid) {
@ -196,9 +197,7 @@ export class ChatCompletionService {
} }
} }
async processStreamingCompletion( async processStreamingCompletion(request: ChatCompletionCreateParams): Promise<{
request: ChatCompletionCreateParams
): Promise<{
provider: Provider provider: Provider
modelId: string modelId: string
stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
@ -227,9 +226,9 @@ export class ChatCompletionService {
}) })
const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming
const stream = (await client.chat.completions.create(streamRequest)) as AsyncIterable< const stream = (await client.chat.completions.create(
OpenAI.Chat.Completions.ChatCompletionChunk streamRequest
> )) as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
logger.info('Successfully started streaming chat completion') logger.info('Successfully started streaming chat completion')
return { return {

View File

@ -1,10 +1,9 @@
import Anthropic from "@anthropic-ai/sdk"; import Anthropic from '@anthropic-ai/sdk'
import {Message, MessageCreateParams, RawMessageStreamEvent} from '@anthropic-ai/sdk/resources' import { Message, MessageCreateParams, RawMessageStreamEvent } from '@anthropic-ai/sdk/resources'
import {loggerService} from '@logger' import { loggerService } from '@logger'
import anthropicService from "@main/services/AnthropicService"; import anthropicService from '@main/services/AnthropicService'
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic' import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import {Provider} from '@types' import { Provider } from '@types'
const logger = loggerService.withContext('MessagesService') const logger = loggerService.withContext('MessagesService')
@ -46,7 +45,6 @@ export class MessagesService {
return getSdkClient(provider) return getSdkClient(provider)
} }
async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> { async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> {
logger.info('Processing Anthropic message request:', { logger.info('Processing Anthropic message request:', {
model: request.model, model: request.model,
@ -79,7 +77,7 @@ export class MessagesService {
return response return response
} }
async* processStreamingMessage( async *processStreamingMessage(
request: MessageCreateParams, request: MessageCreateParams,
provider: Provider provider: Provider
): AsyncIterable<RawMessageStreamEvent> { ): AsyncIterable<RawMessageStreamEvent> {

View File

@ -20,4 +20,3 @@ export class AgentModelValidationError extends Error {
this.detail = detail this.detail = detail
} }
} }

View File

@ -1,11 +1,16 @@
import {loggerService} from '@logger' import { loggerService } from '@logger'
import type {AgentSessionMessageEntity, CreateSessionMessageRequest, GetAgentSessionResponse, ListOptions} from '@types' import type {
import {TextStreamPart} from 'ai' AgentSessionMessageEntity,
import {desc, eq} from 'drizzle-orm' CreateSessionMessageRequest,
GetAgentSessionResponse,
ListOptions
} from '@types'
import { TextStreamPart } from 'ai'
import { desc, eq } from 'drizzle-orm'
import {BaseService} from '../BaseService' import { BaseService } from '../BaseService'
import {sessionMessagesTable} from '../database/schema' import { sessionMessagesTable } from '../database/schema'
import {AgentStreamEvent} from '../interfaces/AgentStreamInterface' import { AgentStreamEvent } from '../interfaces/AgentStreamInterface'
import ClaudeCodeService from './claudecode' import ClaudeCodeService from './claudecode'
const logger = loggerService.withContext('SessionMessageService') const logger = loggerService.withContext('SessionMessageService')
@ -29,7 +34,7 @@ function serializeError(error: unknown): { message: string; name?: string; stack
} }
if (typeof error === 'string') { if (typeof error === 'string') {
return {message: error} return { message: error }
} }
return { return {
@ -99,7 +104,7 @@ export class SessionMessageService extends BaseService {
this.ensureInitialized() this.ensureInitialized()
const result = await this.database const result = await this.database
.select({id: sessionMessagesTable.id}) .select({ id: sessionMessagesTable.id })
.from(sessionMessagesTable) .from(sessionMessagesTable)
.where(eq(sessionMessagesTable.id, id)) .where(eq(sessionMessagesTable.id, id))
.limit(1) .limit(1)
@ -129,7 +134,7 @@ export class SessionMessageService extends BaseService {
const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[] const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[]
return {messages} return { messages }
} }
async createSessionMessage( async createSessionMessage(
@ -148,11 +153,11 @@ export class SessionMessageService extends BaseService {
abortController: AbortController abortController: AbortController
): Promise<SessionStreamResult> { ): Promise<SessionStreamResult> {
const agentSessionId = await this.getLastAgentSessionId(session.id) const agentSessionId = await this.getLastAgentSessionId(session.id)
logger.debug('Session Message stream message data:', {message: req, session_id: agentSessionId}) logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
if (session.agent_type !== 'claude-code') { if (session.agent_type !== 'claude-code') {
// TODO: Implement support for other agent types // TODO: Implement support for other agent types
logger.error('Unsupported agent type for streaming:', {agent_type: session.agent_type}) logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type })
throw new Error('Unsupported agent type for streaming') throw new Error('Unsupported agent type for streaming')
} }
@ -243,7 +248,7 @@ export class SessionMessageService extends BaseService {
} }
}) })
return {stream, completion} return { stream, completion }
} }
private async getLastAgentSessionId(sessionId: string): Promise<string> { private async getLastAgentSessionId(sessionId: string): Promise<string> {
@ -251,7 +256,7 @@ export class SessionMessageService extends BaseService {
try { try {
const result = await this.database const result = await this.database
.select({agent_session_id: sessionMessagesTable.agent_session_id}) .select({ agent_session_id: sessionMessagesTable.agent_session_id })
.from(sessionMessagesTable) .from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId)) .where(eq(sessionMessagesTable.session_id, sessionId))
.orderBy(desc(sessionMessagesTable.created_at)) .orderBy(desc(sessionMessagesTable.created_at))
@ -270,7 +275,7 @@ export class SessionMessageService extends BaseService {
private deserializeSessionMessage(data: any): AgentSessionMessageEntity { private deserializeSessionMessage(data: any): AgentSessionMessageEntity {
if (!data) return data if (!data) return data
const deserialized = {...data} const deserialized = { ...data }
// Parse content JSON // Parse content JSON
if (deserialized.content && typeof deserialized.content === 'string') { if (deserialized.content && typeof deserialized.content === 'string') {

View File

@ -119,7 +119,7 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
toolCallId: block.tool_use_id, toolCallId: block.tool_use_id,
toolName: '', toolName: '',
input: '', input: '',
output: block.content, output: block.content
}) })
break break
default: default:
@ -244,17 +244,18 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }
} }
break break
case 'content_block_stop': { case 'content_block_stop':
const contentBlock = contentBlockState.get(blockKey) {
if (contentBlock?.type === 'text') { const contentBlock = contentBlockState.get(blockKey)
chunks.push({ if (contentBlock?.type === 'text') {
type: 'text-end', chunks.push({
id: String(event.index) type: 'text-end',
}) id: String(event.index)
})
}
contentBlockState.delete(blockKey)
} }
contentBlockState.delete(blockKey) break
}
break
case 'message_delta': case 'message_delta':
// Handle usage updates or other message-level deltas // Handle usage updates or other message-level deltas
break break
@ -304,9 +305,7 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
usage = { usage = {
inputTokens: message.usage.input_tokens ?? 0, inputTokens: message.usage.input_tokens ?? 0,
outputTokens: message.usage.output_tokens ?? 0, outputTokens: message.usage.output_tokens ?? 0,
totalTokens: totalTokens: (message.usage.input_tokens ?? 0) + (message.usage.output_tokens ?? 0)
(message.usage.input_tokens ?? 0) +
(message.usage.output_tokens ?? 0)
} }
} }
if (message.subtype === 'success') { if (message.subtype === 'success') {

View File

@ -22,14 +22,14 @@ import {
WebSearchToolResultBlockParam, WebSearchToolResultBlockParam,
WebSearchToolResultError WebSearchToolResultError
} from '@anthropic-ai/sdk/resources/messages' } from '@anthropic-ai/sdk/resources/messages'
import {MessageStream} from '@anthropic-ai/sdk/resources/messages/messages' import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import AnthropicVertex from '@anthropic-ai/vertex-sdk' import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import {loggerService} from '@logger' import { loggerService } from '@logger'
import {DEFAULT_MAX_TOKENS} from '@renderer/config/constant' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel} from '@renderer/config/models' import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import {getAssistantSettings} from '@renderer/services/AssistantService' import { getAssistantSettings } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager' import FileManager from '@renderer/services/FileManager'
import {estimateTextTokens} from '@renderer/services/TokenService' import { estimateTextTokens } from '@renderer/services/TokenService'
import { import {
Assistant, Assistant,
EFFORT_RATIO, EFFORT_RATIO,
@ -53,17 +53,27 @@ import {
ThinkingDeltaChunk, ThinkingDeltaChunk,
ThinkingStartChunk ThinkingStartChunk
} from '@renderer/types/chunk' } from '@renderer/types/chunk'
import {type Message} from '@renderer/types/newMessage' import { type Message } from '@renderer/types/newMessage'
import {AnthropicSdkMessageParam, AnthropicSdkParams, AnthropicSdkRawChunk, AnthropicSdkRawOutput} from '@renderer/types/sdk' import {
import {addImageFileToContents} from '@renderer/utils/formats' AnthropicSdkMessageParam,
import {anthropicToolUseToMcpTool, isSupportedToolUse, mcpToolCallResponseToAnthropicMessage, mcpToolsToAnthropicTools} from '@renderer/utils/mcp-tools' AnthropicSdkParams,
import {findFileBlocks, findImageBlocks} from '@renderer/utils/messageUtils/find' AnthropicSdkRawChunk,
import {buildClaudeCodeSystemMessage, getSdkClient} from "@shared/anthropic"; AnthropicSdkRawOutput
import {t} from 'i18next' } from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
anthropicToolUseToMcpTool,
isSupportedToolUse,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import { t } from 'i18next'
import {GenericChunk} from '../../middleware/schemas' import { GenericChunk } from '../../middleware/schemas'
import {BaseApiClient} from '../BaseApiClient' import { BaseApiClient } from '../BaseApiClient'
import {AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer} from '../types' import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
const logger = loggerService.withContext('AnthropicAPIClient') const logger = loggerService.withContext('AnthropicAPIClient')
@ -105,7 +115,7 @@ export class AnthropicAPIClient extends BaseApiClient<
if (payload.stream) { if (payload.stream) {
return sdk.messages.stream(payload, options) return sdk.messages.stream(payload, options)
} }
return sdk.messages.create(payload, options); return sdk.messages.create(payload, options)
} }
// @ts-ignore sdk未提供 // @ts-ignore sdk未提供
@ -149,7 +159,7 @@ export class AnthropicAPIClient extends BaseApiClient<
if (!isReasoningModel(model)) { if (!isReasoningModel(model)) {
return undefined return undefined
} }
const {maxTokens} = getAssistantSettings(assistant) const { maxTokens } = getAssistantSettings(assistant)
const reasoningEffort = assistant?.settings?.reasoning_effort const reasoningEffort = assistant?.settings?.reasoning_effort
@ -166,7 +176,7 @@ export class AnthropicAPIClient extends BaseApiClient<
Math.floor( Math.floor(
Math.min( Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + (findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!, findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
) )
) )
@ -188,7 +198,7 @@ export class AnthropicAPIClient extends BaseApiClient<
* @returns The message parameter * @returns The message parameter
*/ */
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> { public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
const {textContent, imageContents} = await this.getMessageContent(message) const { textContent, imageContents } = await this.getMessageContent(message)
const parts: MessageParam['content'] = [ const parts: MessageParam['content'] = [
{ {
@ -211,7 +221,7 @@ export class AnthropicAPIClient extends BaseApiClient<
} }
}) })
} else { } else {
logger.warn('Unsupported image type, ignored.', {mime: base64Data.mime}) logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime })
} }
} }
} }
@ -236,7 +246,7 @@ export class AnthropicAPIClient extends BaseApiClient<
// Get and process file blocks // Get and process file blocks
const fileBlocks = findFileBlocks(message) const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) { for (const fileBlock of fileBlocks) {
const {file} = fileBlock const { file } = fileBlock
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) { if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
const base64Data = await FileManager.readBase64File(file) const base64Data = await FileManager.readBase64File(file)
@ -464,25 +474,25 @@ export class AnthropicAPIClient extends BaseApiClient<
messages: AnthropicSdkMessageParam[] messages: AnthropicSdkMessageParam[]
metadata: Record<string, any> metadata: Record<string, any>
}> => { }> => {
const {messages, mcpTools, maxTokens, streamOutput, enableWebSearch} = coreRequest const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
// 1. 处理系统消息 // 1. 处理系统消息
const systemPrompt = assistant.prompt const systemPrompt = assistant.prompt
// 2. 设置工具 // 2. 设置工具
const {tools} = this.setupToolsConfig({ const { tools } = this.setupToolsConfig({
mcpTools: mcpTools, mcpTools: mcpTools,
model, model,
enableToolUse: isSupportedToolUse(assistant) enableToolUse: isSupportedToolUse(assistant)
}) })
const systemMessage: TextBlockParam | undefined = systemPrompt const systemMessage: TextBlockParam | undefined = systemPrompt
? {type: 'text', text: systemPrompt} ? { type: 'text', text: systemPrompt }
: undefined : undefined
// 3. 处理用户消息 // 3. 处理用户消息
const sdkMessages: AnthropicSdkMessageParam[] = [] const sdkMessages: AnthropicSdkMessageParam[] = []
if (typeof messages === 'string') { if (typeof messages === 'string') {
sdkMessages.push({role: 'user', content: messages}) sdkMessages.push({ role: 'user', content: messages })
} else { } else {
const processedMessages = addImageFileToContents(messages) const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) { for (const message of processedMessages) {
@ -516,7 +526,7 @@ export class AnthropicAPIClient extends BaseApiClient<
} }
const timeout = this.getTimeout(model) const timeout = this.getTimeout(model)
return {payload: commonParams, messages: sdkMessages, metadata: {timeout}} return { payload: commonParams, messages: sdkMessages, metadata: { timeout } }
} }
} }
} }
@ -531,7 +541,7 @@ export class AnthropicAPIClient extends BaseApiClient<
try { try {
rawChunk = JSON.parse(rawChunk) rawChunk = JSON.parse(rawChunk)
} catch (error) { } catch (error) {
logger.error('invalid chunk', {rawChunk, error}) logger.error('invalid chunk', { rawChunk, error })
throw new Error(t('error.chat.chunk.non_json')) throw new Error(t('error.chat.chunk.non_json'))
} }
} }

View File

@ -1,5 +1,5 @@
import { useAppDispatch } from '@renderer/store' import { useAppDispatch } from '@renderer/store'
import { removeManyBlocks,upsertManyBlocks } from '@renderer/store/messageBlock' import { removeManyBlocks, upsertManyBlocks } from '@renderer/store/messageBlock'
import { newMessagesActions } from '@renderer/store/newMessage' import { newMessagesActions } from '@renderer/store/newMessage'
import { AgentPersistedMessage, UpdateSessionForm } from '@renderer/types' import { AgentPersistedMessage, UpdateSessionForm } from '@renderer/types'
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession' import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'

View File

@ -17,7 +17,7 @@ import { classNames } from '@renderer/utils'
import { Flex } from 'antd' import { Flex } from 'antd'
import { debounce } from 'lodash' import { debounce } from 'lodash'
import { AnimatePresence, motion } from 'motion/react' import { AnimatePresence, motion } from 'motion/react'
import React, { FC, useMemo,useState } from 'react' import React, { FC, useMemo, useState } from 'react'
import { useHotkeys } from 'react-hotkeys-hook' import { useHotkeys } from 'react-hotkeys-hook'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import styled from 'styled-components' import styled from 'styled-components'

View File

@ -6,7 +6,7 @@ import { useAppSelector } from '@renderer/store'
import { selectMessagesForTopic } from '@renderer/store/newMessage' import { selectMessagesForTopic } from '@renderer/store/newMessage'
import { Topic } from '@renderer/types' import { Topic } from '@renderer/types'
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession' import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
import { memo,useMemo } from 'react' import { memo, useMemo } from 'react'
import styled from 'styled-components' import styled from 'styled-components'
import MessageGroup from './MessageGroup' import MessageGroup from './MessageGroup'

View File

@ -1,4 +1,4 @@
import type {ProviderMetadata} from "ai"; import type { ProviderMetadata } from 'ai'
import type { CompletionUsage } from 'openai/resources' import type { CompletionUsage } from 'openai/resources'
import type { import type {