mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
💄 style: format code with yarn format
This commit is contained in:
parent
b869869e26
commit
a09c52424f
@ -8,9 +8,9 @@
|
||||
* This shared module can be used by both main and renderer processes.
|
||||
*/
|
||||
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import {TextBlockParam} from "@anthropic-ai/sdk/resources";
|
||||
import {Provider} from "@types";
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { TextBlockParam } from '@anthropic-ai/sdk/resources'
|
||||
import { Provider } from '@types'
|
||||
|
||||
/**
|
||||
* Creates and configures an Anthropic SDK client based on the provider configuration.
|
||||
@ -54,7 +54,8 @@ export function getSdkClient(provider: Provider, oauthToken?: string | null): An
|
||||
defaultHeaders: {
|
||||
'Content-Type': 'application/json',
|
||||
'anthropic-version': '2023-06-01',
|
||||
'anthropic-beta': 'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14',
|
||||
'anthropic-beta':
|
||||
'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14',
|
||||
'anthropic-dangerous-direct-browser-access': 'true',
|
||||
'user-agent': 'claude-cli/1.0.118 (external, sdk-ts)',
|
||||
'x-app': 'cli',
|
||||
|
||||
@ -336,4 +336,4 @@
|
||||
"internal": {
|
||||
"indexes": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
import {loggerService} from '@main/services/LoggerService'
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import cors from 'cors'
|
||||
import express from 'express'
|
||||
import {v4 as uuidv4} from 'uuid'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import {authMiddleware} from './middleware/auth'
|
||||
import {errorHandler} from './middleware/error'
|
||||
import {setupOpenAPIDocumentation} from './middleware/openapi'
|
||||
import {agentsRoutes} from './routes/agents'
|
||||
import {chatRoutes} from './routes/chat'
|
||||
import {mcpRoutes} from './routes/mcp'
|
||||
import {messagesProviderRoutes, messagesRoutes} from './routes/messages'
|
||||
import {modelsRoutes} from './routes/models'
|
||||
import { authMiddleware } from './middleware/auth'
|
||||
import { errorHandler } from './middleware/error'
|
||||
import { setupOpenAPIDocumentation } from './middleware/openapi'
|
||||
import { agentsRoutes } from './routes/agents'
|
||||
import { chatRoutes } from './routes/chat'
|
||||
import { mcpRoutes } from './routes/mcp'
|
||||
import { messagesProviderRoutes, messagesRoutes } from './routes/messages'
|
||||
import { modelsRoutes } from './routes/models'
|
||||
|
||||
const logger = loggerService.withContext('ApiServer')
|
||||
|
||||
@ -109,7 +109,7 @@ app.get('/', (_req, res) => {
|
||||
})
|
||||
|
||||
// Provider-specific API routes with auth (must be before /v1 to avoid conflicts)
|
||||
const providerRouter = express.Router({mergeParams: true})
|
||||
const providerRouter = express.Router({ mergeParams: true })
|
||||
providerRouter.use(authMiddleware)
|
||||
providerRouter.use(express.json())
|
||||
// Mount provider-specific messages route
|
||||
@ -128,11 +128,10 @@ apiRouter.use('/models', modelsRoutes)
|
||||
apiRouter.use('/agents', agentsRoutes)
|
||||
app.use('/v1', apiRouter)
|
||||
|
||||
|
||||
// Setup OpenAPI documentation
|
||||
setupOpenAPIDocumentation(app)
|
||||
|
||||
// Error handling (must be last)
|
||||
app.use(errorHandler)
|
||||
|
||||
export {app}
|
||||
export { app }
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import type {NextFunction, Request, Response} from 'express'
|
||||
import {beforeEach, describe, expect, it, vi} from 'vitest'
|
||||
import type { NextFunction, Request, Response } from 'express'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {config} from '../../config'
|
||||
import {authMiddleware} from '../auth'
|
||||
import { config } from '../../config'
|
||||
import { authMiddleware } from '../auth'
|
||||
|
||||
// Mock the config module
|
||||
vi.mock('../../config', () => ({
|
||||
@ -31,7 +31,7 @@ describe('authMiddleware', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
jsonMock = vi.fn()
|
||||
statusMock = vi.fn(() => ({json: jsonMock}))
|
||||
statusMock = vi.fn(() => ({ json: jsonMock }))
|
||||
|
||||
req = {
|
||||
header: vi.fn()
|
||||
@ -51,7 +51,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -65,7 +65,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@ -77,12 +77,12 @@ describe('authMiddleware', () => {
|
||||
return ''
|
||||
})
|
||||
|
||||
mockConfig.get.mockResolvedValue({apiKey: ''})
|
||||
mockConfig.get.mockResolvedValue({ apiKey: '' })
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -92,12 +92,12 @@ describe('authMiddleware', () => {
|
||||
return ''
|
||||
})
|
||||
|
||||
mockConfig.get.mockResolvedValue({apiKey: null})
|
||||
mockConfig.get.mockResolvedValue({ apiKey: null })
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@ -106,7 +106,7 @@ describe('authMiddleware', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
||||
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||
})
|
||||
|
||||
it('should authenticate successfully with valid API key', async () => {
|
||||
@ -130,7 +130,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -143,7 +143,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: empty x-api-key'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: empty x-api-key' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -182,7 +182,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@ -191,7 +191,7 @@ describe('authMiddleware', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
||||
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||
})
|
||||
|
||||
it('should authenticate successfully with valid Bearer token when no API key', async () => {
|
||||
@ -215,7 +215,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -228,7 +228,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -241,20 +241,20 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle Bearer token with only trailing spaces (edge case)', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check
|
||||
if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -287,7 +287,7 @@ describe('authMiddleware', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
||||
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||
})
|
||||
|
||||
it('should handle config.get() rejection', async () => {
|
||||
@ -310,7 +310,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -323,7 +323,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@ -332,7 +332,7 @@ describe('authMiddleware', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
||||
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||
})
|
||||
|
||||
it('should handle similar length but different API keys securely', async () => {
|
||||
@ -346,7 +346,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@ -361,7 +361,7 @@ describe('authMiddleware', () => {
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import crypto from 'crypto'
|
||||
import {NextFunction, Request, Response} from 'express'
|
||||
|
||||
import {config} from '../config'
|
||||
import { NextFunction, Request, Response } from 'express'
|
||||
|
||||
import { config } from '../config'
|
||||
|
||||
const isValidToken = (token: string, apiKey: string): boolean => {
|
||||
if (token.length !== apiKey.length) {
|
||||
@ -19,26 +18,26 @@ export const authMiddleware = async (req: Request, res: Response, next: NextFunc
|
||||
|
||||
// Fast rejection if neither credential header provided
|
||||
if (!auth && !xApiKey) {
|
||||
return res.status(401).json({error: 'Unauthorized: missing credentials'})
|
||||
return res.status(401).json({ error: 'Unauthorized: missing credentials' })
|
||||
}
|
||||
|
||||
const {apiKey} = await config.get()
|
||||
const { apiKey } = await config.get()
|
||||
|
||||
if (!apiKey) {
|
||||
return res.status(403).json({error: 'Forbidden'})
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
|
||||
// Check API key first (priority)
|
||||
if (xApiKey) {
|
||||
const trimmedApiKey = xApiKey.trim()
|
||||
if (!trimmedApiKey) {
|
||||
return res.status(401).json({error: 'Unauthorized: empty x-api-key'})
|
||||
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
|
||||
}
|
||||
|
||||
if (isValidToken(trimmedApiKey, apiKey)) {
|
||||
return next()
|
||||
} else {
|
||||
return res.status(403).json({error: 'Forbidden'})
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,20 +47,20 @@ export const authMiddleware = async (req: Request, res: Response, next: NextFunc
|
||||
const bearerPrefix = /^Bearer\s+/i
|
||||
|
||||
if (!bearerPrefix.test(trimmed)) {
|
||||
return res.status(401).json({error: 'Unauthorized: invalid authorization format'})
|
||||
return res.status(401).json({ error: 'Unauthorized: invalid authorization format' })
|
||||
}
|
||||
|
||||
const token = trimmed.replace(bearerPrefix, '').trim()
|
||||
if (!token) {
|
||||
return res.status(401).json({error: 'Unauthorized: empty bearer token'})
|
||||
return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
|
||||
}
|
||||
|
||||
if (isValidToken(token, apiKey)) {
|
||||
return next()
|
||||
} else {
|
||||
return res.status(403).json({error: 'Forbidden'})
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(401).json({error: 'Unauthorized: invalid credentials format'})
|
||||
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { AgentModelValidationError, agentService } from '@main/services/agents'
|
||||
import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
||||
import { ListAgentsResponse, type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
import type { ValidationRequest } from '../validators/zodValidator'
|
||||
|
||||
@ -1,9 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
import {
|
||||
AgentModelValidationError,
|
||||
sessionMessageService,
|
||||
sessionService
|
||||
} from '@main/services/agents'
|
||||
import { AgentModelValidationError, sessionMessageService, sessionService } from '@main/services/agents'
|
||||
import {
|
||||
CreateSessionResponse,
|
||||
ListAgentSessionsResponse,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { NextFunction,Request, Response } from 'express'
|
||||
import { NextFunction, Request, Response } from 'express'
|
||||
import { ZodError, ZodType } from 'zod'
|
||||
|
||||
export interface ValidationRequest extends Request {
|
||||
@ -35,7 +35,7 @@ export const createZodValidator = (config: ZodValidationConfig) => {
|
||||
type: 'field',
|
||||
value: err.input,
|
||||
msg: err.message,
|
||||
path: err.path.map(p => String(p)).join('.'),
|
||||
path: err.path.map((p) => String(p)).join('.'),
|
||||
location: getLocationFromPath(err.path, config)
|
||||
}))
|
||||
|
||||
@ -65,4 +65,4 @@ function getLocationFromPath(path: (string | number | symbol)[], config: ZodVali
|
||||
if (config.params && path.length > 0) return 'params'
|
||||
if (config.query && path.length > 0) return 'query'
|
||||
return 'unknown'
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
import {MessageCreateParams} from '@anthropic-ai/sdk/resources'
|
||||
import {loggerService} from '@logger'
|
||||
import express, {Request, Response} from 'express'
|
||||
import { MessageCreateParams } from '@anthropic-ai/sdk/resources'
|
||||
import { loggerService } from '@logger'
|
||||
import express, { Request, Response } from 'express'
|
||||
|
||||
import {messagesService} from '../services/messages'
|
||||
import {getProviderById, validateModelId} from '../utils'
|
||||
import { messagesService } from '../services/messages'
|
||||
import { getProviderById, validateModelId } from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
||||
|
||||
const router = express.Router()
|
||||
const providerRouter = express.Router({mergeParams: true})
|
||||
const providerRouter = express.Router({ mergeParams: true })
|
||||
|
||||
// Helper functions for shared logic
|
||||
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
|
||||
@ -28,7 +28,7 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
|
||||
}
|
||||
}
|
||||
|
||||
return {valid: true}
|
||||
return { valid: true }
|
||||
}
|
||||
|
||||
async function handleStreamingResponse(
|
||||
@ -318,7 +318,7 @@ router.post('/', async (req: Request, res: Response) => {
|
||||
model: request.model,
|
||||
messageCount: request.messages?.length || 0,
|
||||
stream: request.stream,
|
||||
max_tokens: request.max_tokens,
|
||||
max_tokens: request.max_tokens
|
||||
})
|
||||
|
||||
// Validate model ID and get provider
|
||||
@ -522,4 +522,4 @@ providerRouter.post('/', async (req: Request, res: Response) => {
|
||||
}
|
||||
})
|
||||
|
||||
export {providerRouter as messagesProviderRoutes, router as messagesRoutes}
|
||||
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
|
||||
|
||||
@ -38,9 +38,10 @@ export type PrepareRequestResult =
|
||||
}
|
||||
|
||||
export class ChatCompletionService {
|
||||
async resolveProviderContext(model: string): Promise<
|
||||
| { ok: false; error: ModelValidationError }
|
||||
| { ok: true; provider: Provider; modelId: string; client: OpenAI }
|
||||
async resolveProviderContext(
|
||||
model: string
|
||||
): Promise<
|
||||
{ ok: false; error: ModelValidationError } | { ok: true; provider: Provider; modelId: string; client: OpenAI }
|
||||
> {
|
||||
const modelValidation = await validateModelId(model)
|
||||
if (!modelValidation.valid) {
|
||||
@ -196,9 +197,7 @@ export class ChatCompletionService {
|
||||
}
|
||||
}
|
||||
|
||||
async processStreamingCompletion(
|
||||
request: ChatCompletionCreateParams
|
||||
): Promise<{
|
||||
async processStreamingCompletion(request: ChatCompletionCreateParams): Promise<{
|
||||
provider: Provider
|
||||
modelId: string
|
||||
stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
||||
@ -227,9 +226,9 @@ export class ChatCompletionService {
|
||||
})
|
||||
|
||||
const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming
|
||||
const stream = (await client.chat.completions.create(streamRequest)) as AsyncIterable<
|
||||
OpenAI.Chat.Completions.ChatCompletionChunk
|
||||
>
|
||||
const stream = (await client.chat.completions.create(
|
||||
streamRequest
|
||||
)) as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
||||
|
||||
logger.info('Successfully started streaming chat completion')
|
||||
return {
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import {Message, MessageCreateParams, RawMessageStreamEvent} from '@anthropic-ai/sdk/resources'
|
||||
import {loggerService} from '@logger'
|
||||
import anthropicService from "@main/services/AnthropicService";
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { Message, MessageCreateParams, RawMessageStreamEvent } from '@anthropic-ai/sdk/resources'
|
||||
import { loggerService } from '@logger'
|
||||
import anthropicService from '@main/services/AnthropicService'
|
||||
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
||||
import {Provider} from '@types'
|
||||
|
||||
import { Provider } from '@types'
|
||||
|
||||
const logger = loggerService.withContext('MessagesService')
|
||||
|
||||
@ -46,7 +45,6 @@ export class MessagesService {
|
||||
return getSdkClient(provider)
|
||||
}
|
||||
|
||||
|
||||
async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> {
|
||||
logger.info('Processing Anthropic message request:', {
|
||||
model: request.model,
|
||||
@ -79,7 +77,7 @@ export class MessagesService {
|
||||
return response
|
||||
}
|
||||
|
||||
async* processStreamingMessage(
|
||||
async *processStreamingMessage(
|
||||
request: MessageCreateParams,
|
||||
provider: Provider
|
||||
): AsyncIterable<RawMessageStreamEvent> {
|
||||
|
||||
@ -20,4 +20,3 @@ export class AgentModelValidationError extends Error {
|
||||
this.detail = detail
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
import {loggerService} from '@logger'
|
||||
import type {AgentSessionMessageEntity, CreateSessionMessageRequest, GetAgentSessionResponse, ListOptions} from '@types'
|
||||
import {TextStreamPart} from 'ai'
|
||||
import {desc, eq} from 'drizzle-orm'
|
||||
import { loggerService } from '@logger'
|
||||
import type {
|
||||
AgentSessionMessageEntity,
|
||||
CreateSessionMessageRequest,
|
||||
GetAgentSessionResponse,
|
||||
ListOptions
|
||||
} from '@types'
|
||||
import { TextStreamPart } from 'ai'
|
||||
import { desc, eq } from 'drizzle-orm'
|
||||
|
||||
import {BaseService} from '../BaseService'
|
||||
import {sessionMessagesTable} from '../database/schema'
|
||||
import {AgentStreamEvent} from '../interfaces/AgentStreamInterface'
|
||||
import { BaseService } from '../BaseService'
|
||||
import { sessionMessagesTable } from '../database/schema'
|
||||
import { AgentStreamEvent } from '../interfaces/AgentStreamInterface'
|
||||
import ClaudeCodeService from './claudecode'
|
||||
|
||||
const logger = loggerService.withContext('SessionMessageService')
|
||||
@ -29,7 +34,7 @@ function serializeError(error: unknown): { message: string; name?: string; stack
|
||||
}
|
||||
|
||||
if (typeof error === 'string') {
|
||||
return {message: error}
|
||||
return { message: error }
|
||||
}
|
||||
|
||||
return {
|
||||
@ -99,7 +104,7 @@ export class SessionMessageService extends BaseService {
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.select({id: sessionMessagesTable.id})
|
||||
.select({ id: sessionMessagesTable.id })
|
||||
.from(sessionMessagesTable)
|
||||
.where(eq(sessionMessagesTable.id, id))
|
||||
.limit(1)
|
||||
@ -129,7 +134,7 @@ export class SessionMessageService extends BaseService {
|
||||
|
||||
const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[]
|
||||
|
||||
return {messages}
|
||||
return { messages }
|
||||
}
|
||||
|
||||
async createSessionMessage(
|
||||
@ -148,11 +153,11 @@ export class SessionMessageService extends BaseService {
|
||||
abortController: AbortController
|
||||
): Promise<SessionStreamResult> {
|
||||
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||
logger.debug('Session Message stream message data:', {message: req, session_id: agentSessionId})
|
||||
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
||||
|
||||
if (session.agent_type !== 'claude-code') {
|
||||
// TODO: Implement support for other agent types
|
||||
logger.error('Unsupported agent type for streaming:', {agent_type: session.agent_type})
|
||||
logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type })
|
||||
throw new Error('Unsupported agent type for streaming')
|
||||
}
|
||||
|
||||
@ -243,7 +248,7 @@ export class SessionMessageService extends BaseService {
|
||||
}
|
||||
})
|
||||
|
||||
return {stream, completion}
|
||||
return { stream, completion }
|
||||
}
|
||||
|
||||
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
||||
@ -251,7 +256,7 @@ export class SessionMessageService extends BaseService {
|
||||
|
||||
try {
|
||||
const result = await this.database
|
||||
.select({agent_session_id: sessionMessagesTable.agent_session_id})
|
||||
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
||||
.from(sessionMessagesTable)
|
||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||
.orderBy(desc(sessionMessagesTable.created_at))
|
||||
@ -270,7 +275,7 @@ export class SessionMessageService extends BaseService {
|
||||
private deserializeSessionMessage(data: any): AgentSessionMessageEntity {
|
||||
if (!data) return data
|
||||
|
||||
const deserialized = {...data}
|
||||
const deserialized = { ...data }
|
||||
|
||||
// Parse content JSON
|
||||
if (deserialized.content && typeof deserialized.content === 'string') {
|
||||
|
||||
@ -119,7 +119,7 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
|
||||
toolCallId: block.tool_use_id,
|
||||
toolName: '',
|
||||
input: '',
|
||||
output: block.content,
|
||||
output: block.content
|
||||
})
|
||||
break
|
||||
default:
|
||||
@ -244,17 +244,18 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }
|
||||
}
|
||||
break
|
||||
|
||||
case 'content_block_stop': {
|
||||
const contentBlock = contentBlockState.get(blockKey)
|
||||
if (contentBlock?.type === 'text') {
|
||||
chunks.push({
|
||||
type: 'text-end',
|
||||
id: String(event.index)
|
||||
})
|
||||
case 'content_block_stop':
|
||||
{
|
||||
const contentBlock = contentBlockState.get(blockKey)
|
||||
if (contentBlock?.type === 'text') {
|
||||
chunks.push({
|
||||
type: 'text-end',
|
||||
id: String(event.index)
|
||||
})
|
||||
}
|
||||
contentBlockState.delete(blockKey)
|
||||
}
|
||||
contentBlockState.delete(blockKey)
|
||||
}
|
||||
break
|
||||
break
|
||||
case 'message_delta':
|
||||
// Handle usage updates or other message-level deltas
|
||||
break
|
||||
@ -304,9 +305,7 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
||||
usage = {
|
||||
inputTokens: message.usage.input_tokens ?? 0,
|
||||
outputTokens: message.usage.output_tokens ?? 0,
|
||||
totalTokens:
|
||||
(message.usage.input_tokens ?? 0) +
|
||||
(message.usage.output_tokens ?? 0)
|
||||
totalTokens: (message.usage.input_tokens ?? 0) + (message.usage.output_tokens ?? 0)
|
||||
}
|
||||
}
|
||||
if (message.subtype === 'success') {
|
||||
|
||||
@ -22,14 +22,14 @@ import {
|
||||
WebSearchToolResultBlockParam,
|
||||
WebSearchToolResultError
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
import {MessageStream} from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import {loggerService} from '@logger'
|
||||
import {DEFAULT_MAX_TOKENS} from '@renderer/config/constant'
|
||||
import {findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel} from '@renderer/config/models'
|
||||
import {getAssistantSettings} from '@renderer/services/AssistantService'
|
||||
import { loggerService } from '@logger'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import {estimateTextTokens} from '@renderer/services/TokenService'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
@ -53,17 +53,27 @@ import {
|
||||
ThinkingDeltaChunk,
|
||||
ThinkingStartChunk
|
||||
} from '@renderer/types/chunk'
|
||||
import {type Message} from '@renderer/types/newMessage'
|
||||
import {AnthropicSdkMessageParam, AnthropicSdkParams, AnthropicSdkRawChunk, AnthropicSdkRawOutput} from '@renderer/types/sdk'
|
||||
import {addImageFileToContents} from '@renderer/utils/formats'
|
||||
import {anthropicToolUseToMcpTool, isSupportedToolUse, mcpToolCallResponseToAnthropicMessage, mcpToolsToAnthropicTools} from '@renderer/utils/mcp-tools'
|
||||
import {findFileBlocks, findImageBlocks} from '@renderer/utils/messageUtils/find'
|
||||
import {buildClaudeCodeSystemMessage, getSdkClient} from "@shared/anthropic";
|
||||
import {t} from 'i18next'
|
||||
import { type Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
AnthropicSdkMessageParam,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkRawOutput
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
||||
import { t } from 'i18next'
|
||||
|
||||
import {GenericChunk} from '../../middleware/schemas'
|
||||
import {BaseApiClient} from '../BaseApiClient'
|
||||
import {AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer} from '../types'
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
const logger = loggerService.withContext('AnthropicAPIClient')
|
||||
|
||||
@ -105,7 +115,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
if (payload.stream) {
|
||||
return sdk.messages.stream(payload, options)
|
||||
}
|
||||
return sdk.messages.create(payload, options);
|
||||
return sdk.messages.create(payload, options)
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
@ -149,7 +159,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
if (!isReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
const {maxTokens} = getAssistantSettings(assistant)
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
@ -166,7 +176,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
@ -188,7 +198,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
* @returns The message parameter
|
||||
*/
|
||||
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
|
||||
const {textContent, imageContents} = await this.getMessageContent(message)
|
||||
const { textContent, imageContents } = await this.getMessageContent(message)
|
||||
|
||||
const parts: MessageParam['content'] = [
|
||||
{
|
||||
@ -211,7 +221,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
}
|
||||
})
|
||||
} else {
|
||||
logger.warn('Unsupported image type, ignored.', {mime: base64Data.mime})
|
||||
logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime })
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -236,7 +246,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
// Get and process file blocks
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const {file} = fileBlock
|
||||
const { file } = fileBlock
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
|
||||
const base64Data = await FileManager.readBase64File(file)
|
||||
@ -464,25 +474,25 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
messages: AnthropicSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const {messages, mcpTools, maxTokens, streamOutput, enableWebSearch} = coreRequest
|
||||
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
const systemPrompt = assistant.prompt
|
||||
|
||||
// 2. 设置工具
|
||||
const {tools} = this.setupToolsConfig({
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||
? {type: 'text', text: systemPrompt}
|
||||
? { type: 'text', text: systemPrompt }
|
||||
: undefined
|
||||
|
||||
// 3. 处理用户消息
|
||||
const sdkMessages: AnthropicSdkMessageParam[] = []
|
||||
if (typeof messages === 'string') {
|
||||
sdkMessages.push({role: 'user', content: messages})
|
||||
sdkMessages.push({ role: 'user', content: messages })
|
||||
} else {
|
||||
const processedMessages = addImageFileToContents(messages)
|
||||
for (const message of processedMessages) {
|
||||
@ -516,7 +526,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
}
|
||||
|
||||
const timeout = this.getTimeout(model)
|
||||
return {payload: commonParams, messages: sdkMessages, metadata: {timeout}}
|
||||
return { payload: commonParams, messages: sdkMessages, metadata: { timeout } }
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -531,7 +541,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
try {
|
||||
rawChunk = JSON.parse(rawChunk)
|
||||
} catch (error) {
|
||||
logger.error('invalid chunk', {rawChunk, error})
|
||||
logger.error('invalid chunk', { rawChunk, error })
|
||||
throw new Error(t('error.chat.chunk.non_json'))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { useAppDispatch } from '@renderer/store'
|
||||
import { removeManyBlocks,upsertManyBlocks } from '@renderer/store/messageBlock'
|
||||
import { removeManyBlocks, upsertManyBlocks } from '@renderer/store/messageBlock'
|
||||
import { newMessagesActions } from '@renderer/store/newMessage'
|
||||
import { AgentPersistedMessage, UpdateSessionForm } from '@renderer/types'
|
||||
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||
|
||||
@ -17,7 +17,7 @@ import { classNames } from '@renderer/utils'
|
||||
import { Flex } from 'antd'
|
||||
import { debounce } from 'lodash'
|
||||
import { AnimatePresence, motion } from 'motion/react'
|
||||
import React, { FC, useMemo,useState } from 'react'
|
||||
import React, { FC, useMemo, useState } from 'react'
|
||||
import { useHotkeys } from 'react-hotkeys-hook'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
@ -6,7 +6,7 @@ import { useAppSelector } from '@renderer/store'
|
||||
import { selectMessagesForTopic } from '@renderer/store/newMessage'
|
||||
import { Topic } from '@renderer/types'
|
||||
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||
import { memo,useMemo } from 'react'
|
||||
import { memo, useMemo } from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import MessageGroup from './MessageGroup'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type {ProviderMetadata} from "ai";
|
||||
import type { ProviderMetadata } from 'ai'
|
||||
import type { CompletionUsage } from 'openai/resources'
|
||||
|
||||
import type {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user