mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
💄 style: format code with yarn format
This commit is contained in:
parent
b869869e26
commit
a09c52424f
@ -8,9 +8,9 @@
|
|||||||
* This shared module can be used by both main and renderer processes.
|
* This shared module can be used by both main and renderer processes.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import Anthropic from "@anthropic-ai/sdk";
|
import Anthropic from '@anthropic-ai/sdk'
|
||||||
import {TextBlockParam} from "@anthropic-ai/sdk/resources";
|
import { TextBlockParam } from '@anthropic-ai/sdk/resources'
|
||||||
import {Provider} from "@types";
|
import { Provider } from '@types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates and configures an Anthropic SDK client based on the provider configuration.
|
* Creates and configures an Anthropic SDK client based on the provider configuration.
|
||||||
@ -54,7 +54,8 @@ export function getSdkClient(provider: Provider, oauthToken?: string | null): An
|
|||||||
defaultHeaders: {
|
defaultHeaders: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'anthropic-version': '2023-06-01',
|
'anthropic-version': '2023-06-01',
|
||||||
'anthropic-beta': 'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14',
|
'anthropic-beta':
|
||||||
|
'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14',
|
||||||
'anthropic-dangerous-direct-browser-access': 'true',
|
'anthropic-dangerous-direct-browser-access': 'true',
|
||||||
'user-agent': 'claude-cli/1.0.118 (external, sdk-ts)',
|
'user-agent': 'claude-cli/1.0.118 (external, sdk-ts)',
|
||||||
'x-app': 'cli',
|
'x-app': 'cli',
|
||||||
|
|||||||
@ -336,4 +336,4 @@
|
|||||||
"internal": {
|
"internal": {
|
||||||
"indexes": {}
|
"indexes": {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,16 +1,16 @@
|
|||||||
import {loggerService} from '@main/services/LoggerService'
|
import { loggerService } from '@main/services/LoggerService'
|
||||||
import cors from 'cors'
|
import cors from 'cors'
|
||||||
import express from 'express'
|
import express from 'express'
|
||||||
import {v4 as uuidv4} from 'uuid'
|
import { v4 as uuidv4 } from 'uuid'
|
||||||
|
|
||||||
import {authMiddleware} from './middleware/auth'
|
import { authMiddleware } from './middleware/auth'
|
||||||
import {errorHandler} from './middleware/error'
|
import { errorHandler } from './middleware/error'
|
||||||
import {setupOpenAPIDocumentation} from './middleware/openapi'
|
import { setupOpenAPIDocumentation } from './middleware/openapi'
|
||||||
import {agentsRoutes} from './routes/agents'
|
import { agentsRoutes } from './routes/agents'
|
||||||
import {chatRoutes} from './routes/chat'
|
import { chatRoutes } from './routes/chat'
|
||||||
import {mcpRoutes} from './routes/mcp'
|
import { mcpRoutes } from './routes/mcp'
|
||||||
import {messagesProviderRoutes, messagesRoutes} from './routes/messages'
|
import { messagesProviderRoutes, messagesRoutes } from './routes/messages'
|
||||||
import {modelsRoutes} from './routes/models'
|
import { modelsRoutes } from './routes/models'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServer')
|
const logger = loggerService.withContext('ApiServer')
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ app.get('/', (_req, res) => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Provider-specific API routes with auth (must be before /v1 to avoid conflicts)
|
// Provider-specific API routes with auth (must be before /v1 to avoid conflicts)
|
||||||
const providerRouter = express.Router({mergeParams: true})
|
const providerRouter = express.Router({ mergeParams: true })
|
||||||
providerRouter.use(authMiddleware)
|
providerRouter.use(authMiddleware)
|
||||||
providerRouter.use(express.json())
|
providerRouter.use(express.json())
|
||||||
// Mount provider-specific messages route
|
// Mount provider-specific messages route
|
||||||
@ -128,11 +128,10 @@ apiRouter.use('/models', modelsRoutes)
|
|||||||
apiRouter.use('/agents', agentsRoutes)
|
apiRouter.use('/agents', agentsRoutes)
|
||||||
app.use('/v1', apiRouter)
|
app.use('/v1', apiRouter)
|
||||||
|
|
||||||
|
|
||||||
// Setup OpenAPI documentation
|
// Setup OpenAPI documentation
|
||||||
setupOpenAPIDocumentation(app)
|
setupOpenAPIDocumentation(app)
|
||||||
|
|
||||||
// Error handling (must be last)
|
// Error handling (must be last)
|
||||||
app.use(errorHandler)
|
app.use(errorHandler)
|
||||||
|
|
||||||
export {app}
|
export { app }
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import type {NextFunction, Request, Response} from 'express'
|
import type { NextFunction, Request, Response } from 'express'
|
||||||
import {beforeEach, describe, expect, it, vi} from 'vitest'
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
import {config} from '../../config'
|
import { config } from '../../config'
|
||||||
import {authMiddleware} from '../auth'
|
import { authMiddleware } from '../auth'
|
||||||
|
|
||||||
// Mock the config module
|
// Mock the config module
|
||||||
vi.mock('../../config', () => ({
|
vi.mock('../../config', () => ({
|
||||||
@ -31,7 +31,7 @@ describe('authMiddleware', () => {
|
|||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jsonMock = vi.fn()
|
jsonMock = vi.fn()
|
||||||
statusMock = vi.fn(() => ({json: jsonMock}))
|
statusMock = vi.fn(() => ({ json: jsonMock }))
|
||||||
|
|
||||||
req = {
|
req = {
|
||||||
header: vi.fn()
|
header: vi.fn()
|
||||||
@ -51,7 +51,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(401)
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -65,7 +65,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(401)
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -77,12 +77,12 @@ describe('authMiddleware', () => {
|
|||||||
return ''
|
return ''
|
||||||
})
|
})
|
||||||
|
|
||||||
mockConfig.get.mockResolvedValue({apiKey: ''})
|
mockConfig.get.mockResolvedValue({ apiKey: '' })
|
||||||
|
|
||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(403)
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -92,12 +92,12 @@ describe('authMiddleware', () => {
|
|||||||
return ''
|
return ''
|
||||||
})
|
})
|
||||||
|
|
||||||
mockConfig.get.mockResolvedValue({apiKey: null})
|
mockConfig.get.mockResolvedValue({ apiKey: null })
|
||||||
|
|
||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(403)
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -106,7 +106,7 @@ describe('authMiddleware', () => {
|
|||||||
const validApiKey = 'valid-api-key-123'
|
const validApiKey = 'valid-api-key-123'
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should authenticate successfully with valid API key', async () => {
|
it('should authenticate successfully with valid API key', async () => {
|
||||||
@ -130,7 +130,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(403)
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -143,7 +143,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(401)
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: empty x-api-key'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: empty x-api-key' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -182,7 +182,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(403)
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -191,7 +191,7 @@ describe('authMiddleware', () => {
|
|||||||
const validApiKey = 'valid-api-key-123'
|
const validApiKey = 'valid-api-key-123'
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should authenticate successfully with valid Bearer token when no API key', async () => {
|
it('should authenticate successfully with valid Bearer token when no API key', async () => {
|
||||||
@ -215,7 +215,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(403)
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -228,7 +228,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(401)
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -241,20 +241,20 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(401)
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle Bearer token with only trailing spaces (edge case)', async () => {
|
it('should handle Bearer token with only trailing spaces (edge case)', async () => {
|
||||||
;(req.header as any).mockImplementation((header: string) => {
|
;(req.header as any).mockImplementation((header: string) => {
|
||||||
if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check
|
if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check
|
||||||
return ''
|
return ''
|
||||||
})
|
})
|
||||||
|
|
||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(401)
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -287,7 +287,7 @@ describe('authMiddleware', () => {
|
|||||||
const validApiKey = 'valid-api-key-123'
|
const validApiKey = 'valid-api-key-123'
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle config.get() rejection', async () => {
|
it('should handle config.get() rejection', async () => {
|
||||||
@ -310,7 +310,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(403)
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -323,7 +323,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(401)
|
expect(statusMock).toHaveBeenCalledWith(401)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -332,7 +332,7 @@ describe('authMiddleware', () => {
|
|||||||
const validApiKey = 'valid-api-key-123'
|
const validApiKey = 'valid-api-key-123'
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle similar length but different API keys securely', async () => {
|
it('should handle similar length but different API keys securely', async () => {
|
||||||
@ -346,7 +346,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(403)
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -361,7 +361,7 @@ describe('authMiddleware', () => {
|
|||||||
await authMiddleware(req as Request, res as Response, next)
|
await authMiddleware(req as Request, res as Response, next)
|
||||||
|
|
||||||
expect(statusMock).toHaveBeenCalledWith(403)
|
expect(statusMock).toHaveBeenCalledWith(403)
|
||||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||||
expect(next).not.toHaveBeenCalled()
|
expect(next).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
import crypto from 'crypto'
|
import crypto from 'crypto'
|
||||||
import {NextFunction, Request, Response} from 'express'
|
import { NextFunction, Request, Response } from 'express'
|
||||||
|
|
||||||
import {config} from '../config'
|
|
||||||
|
|
||||||
|
import { config } from '../config'
|
||||||
|
|
||||||
const isValidToken = (token: string, apiKey: string): boolean => {
|
const isValidToken = (token: string, apiKey: string): boolean => {
|
||||||
if (token.length !== apiKey.length) {
|
if (token.length !== apiKey.length) {
|
||||||
@ -19,26 +18,26 @@ export const authMiddleware = async (req: Request, res: Response, next: NextFunc
|
|||||||
|
|
||||||
// Fast rejection if neither credential header provided
|
// Fast rejection if neither credential header provided
|
||||||
if (!auth && !xApiKey) {
|
if (!auth && !xApiKey) {
|
||||||
return res.status(401).json({error: 'Unauthorized: missing credentials'})
|
return res.status(401).json({ error: 'Unauthorized: missing credentials' })
|
||||||
}
|
}
|
||||||
|
|
||||||
const {apiKey} = await config.get()
|
const { apiKey } = await config.get()
|
||||||
|
|
||||||
if (!apiKey) {
|
if (!apiKey) {
|
||||||
return res.status(403).json({error: 'Forbidden'})
|
return res.status(403).json({ error: 'Forbidden' })
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check API key first (priority)
|
// Check API key first (priority)
|
||||||
if (xApiKey) {
|
if (xApiKey) {
|
||||||
const trimmedApiKey = xApiKey.trim()
|
const trimmedApiKey = xApiKey.trim()
|
||||||
if (!trimmedApiKey) {
|
if (!trimmedApiKey) {
|
||||||
return res.status(401).json({error: 'Unauthorized: empty x-api-key'})
|
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isValidToken(trimmedApiKey, apiKey)) {
|
if (isValidToken(trimmedApiKey, apiKey)) {
|
||||||
return next()
|
return next()
|
||||||
} else {
|
} else {
|
||||||
return res.status(403).json({error: 'Forbidden'})
|
return res.status(403).json({ error: 'Forbidden' })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,20 +47,20 @@ export const authMiddleware = async (req: Request, res: Response, next: NextFunc
|
|||||||
const bearerPrefix = /^Bearer\s+/i
|
const bearerPrefix = /^Bearer\s+/i
|
||||||
|
|
||||||
if (!bearerPrefix.test(trimmed)) {
|
if (!bearerPrefix.test(trimmed)) {
|
||||||
return res.status(401).json({error: 'Unauthorized: invalid authorization format'})
|
return res.status(401).json({ error: 'Unauthorized: invalid authorization format' })
|
||||||
}
|
}
|
||||||
|
|
||||||
const token = trimmed.replace(bearerPrefix, '').trim()
|
const token = trimmed.replace(bearerPrefix, '').trim()
|
||||||
if (!token) {
|
if (!token) {
|
||||||
return res.status(401).json({error: 'Unauthorized: empty bearer token'})
|
return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isValidToken(token, apiKey)) {
|
if (isValidToken(token, apiKey)) {
|
||||||
return next()
|
return next()
|
||||||
} else {
|
} else {
|
||||||
return res.status(403).json({error: 'Forbidden'})
|
return res.status(403).json({ error: 'Forbidden' })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return res.status(401).json({error: 'Unauthorized: invalid credentials format'})
|
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { AgentModelValidationError, agentService } from '@main/services/agents'
|
import { AgentModelValidationError, agentService } from '@main/services/agents'
|
||||||
import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
import { ListAgentsResponse, type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
||||||
import { Request, Response } from 'express'
|
import { Request, Response } from 'express'
|
||||||
|
|
||||||
import type { ValidationRequest } from '../validators/zodValidator'
|
import type { ValidationRequest } from '../validators/zodValidator'
|
||||||
|
|||||||
@ -1,9 +1,5 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import {
|
import { AgentModelValidationError, sessionMessageService, sessionService } from '@main/services/agents'
|
||||||
AgentModelValidationError,
|
|
||||||
sessionMessageService,
|
|
||||||
sessionService
|
|
||||||
} from '@main/services/agents'
|
|
||||||
import {
|
import {
|
||||||
CreateSessionResponse,
|
CreateSessionResponse,
|
||||||
ListAgentSessionsResponse,
|
ListAgentSessionsResponse,
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import { NextFunction,Request, Response } from 'express'
|
import { NextFunction, Request, Response } from 'express'
|
||||||
import { ZodError, ZodType } from 'zod'
|
import { ZodError, ZodType } from 'zod'
|
||||||
|
|
||||||
export interface ValidationRequest extends Request {
|
export interface ValidationRequest extends Request {
|
||||||
@ -35,7 +35,7 @@ export const createZodValidator = (config: ZodValidationConfig) => {
|
|||||||
type: 'field',
|
type: 'field',
|
||||||
value: err.input,
|
value: err.input,
|
||||||
msg: err.message,
|
msg: err.message,
|
||||||
path: err.path.map(p => String(p)).join('.'),
|
path: err.path.map((p) => String(p)).join('.'),
|
||||||
location: getLocationFromPath(err.path, config)
|
location: getLocationFromPath(err.path, config)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -65,4 +65,4 @@ function getLocationFromPath(path: (string | number | symbol)[], config: ZodVali
|
|||||||
if (config.params && path.length > 0) return 'params'
|
if (config.params && path.length > 0) return 'params'
|
||||||
if (config.query && path.length > 0) return 'query'
|
if (config.query && path.length > 0) return 'query'
|
||||||
return 'unknown'
|
return 'unknown'
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,14 +1,14 @@
|
|||||||
import {MessageCreateParams} from '@anthropic-ai/sdk/resources'
|
import { MessageCreateParams } from '@anthropic-ai/sdk/resources'
|
||||||
import {loggerService} from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import express, {Request, Response} from 'express'
|
import express, { Request, Response } from 'express'
|
||||||
|
|
||||||
import {messagesService} from '../services/messages'
|
import { messagesService } from '../services/messages'
|
||||||
import {getProviderById, validateModelId} from '../utils'
|
import { getProviderById, validateModelId } from '../utils'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
||||||
|
|
||||||
const router = express.Router()
|
const router = express.Router()
|
||||||
const providerRouter = express.Router({mergeParams: true})
|
const providerRouter = express.Router({ mergeParams: true })
|
||||||
|
|
||||||
// Helper functions for shared logic
|
// Helper functions for shared logic
|
||||||
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
|
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
|
||||||
@ -28,7 +28,7 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {valid: true}
|
return { valid: true }
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleStreamingResponse(
|
async function handleStreamingResponse(
|
||||||
@ -318,7 +318,7 @@ router.post('/', async (req: Request, res: Response) => {
|
|||||||
model: request.model,
|
model: request.model,
|
||||||
messageCount: request.messages?.length || 0,
|
messageCount: request.messages?.length || 0,
|
||||||
stream: request.stream,
|
stream: request.stream,
|
||||||
max_tokens: request.max_tokens,
|
max_tokens: request.max_tokens
|
||||||
})
|
})
|
||||||
|
|
||||||
// Validate model ID and get provider
|
// Validate model ID and get provider
|
||||||
@ -522,4 +522,4 @@ providerRouter.post('/', async (req: Request, res: Response) => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
export {providerRouter as messagesProviderRoutes, router as messagesRoutes}
|
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
|
||||||
|
|||||||
@ -38,9 +38,10 @@ export type PrepareRequestResult =
|
|||||||
}
|
}
|
||||||
|
|
||||||
export class ChatCompletionService {
|
export class ChatCompletionService {
|
||||||
async resolveProviderContext(model: string): Promise<
|
async resolveProviderContext(
|
||||||
| { ok: false; error: ModelValidationError }
|
model: string
|
||||||
| { ok: true; provider: Provider; modelId: string; client: OpenAI }
|
): Promise<
|
||||||
|
{ ok: false; error: ModelValidationError } | { ok: true; provider: Provider; modelId: string; client: OpenAI }
|
||||||
> {
|
> {
|
||||||
const modelValidation = await validateModelId(model)
|
const modelValidation = await validateModelId(model)
|
||||||
if (!modelValidation.valid) {
|
if (!modelValidation.valid) {
|
||||||
@ -196,9 +197,7 @@ export class ChatCompletionService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async processStreamingCompletion(
|
async processStreamingCompletion(request: ChatCompletionCreateParams): Promise<{
|
||||||
request: ChatCompletionCreateParams
|
|
||||||
): Promise<{
|
|
||||||
provider: Provider
|
provider: Provider
|
||||||
modelId: string
|
modelId: string
|
||||||
stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
||||||
@ -227,9 +226,9 @@ export class ChatCompletionService {
|
|||||||
})
|
})
|
||||||
|
|
||||||
const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming
|
const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming
|
||||||
const stream = (await client.chat.completions.create(streamRequest)) as AsyncIterable<
|
const stream = (await client.chat.completions.create(
|
||||||
OpenAI.Chat.Completions.ChatCompletionChunk
|
streamRequest
|
||||||
>
|
)) as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
||||||
|
|
||||||
logger.info('Successfully started streaming chat completion')
|
logger.info('Successfully started streaming chat completion')
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -1,10 +1,9 @@
|
|||||||
import Anthropic from "@anthropic-ai/sdk";
|
import Anthropic from '@anthropic-ai/sdk'
|
||||||
import {Message, MessageCreateParams, RawMessageStreamEvent} from '@anthropic-ai/sdk/resources'
|
import { Message, MessageCreateParams, RawMessageStreamEvent } from '@anthropic-ai/sdk/resources'
|
||||||
import {loggerService} from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import anthropicService from "@main/services/AnthropicService";
|
import anthropicService from '@main/services/AnthropicService'
|
||||||
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
||||||
import {Provider} from '@types'
|
import { Provider } from '@types'
|
||||||
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('MessagesService')
|
const logger = loggerService.withContext('MessagesService')
|
||||||
|
|
||||||
@ -46,7 +45,6 @@ export class MessagesService {
|
|||||||
return getSdkClient(provider)
|
return getSdkClient(provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> {
|
async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> {
|
||||||
logger.info('Processing Anthropic message request:', {
|
logger.info('Processing Anthropic message request:', {
|
||||||
model: request.model,
|
model: request.model,
|
||||||
@ -79,7 +77,7 @@ export class MessagesService {
|
|||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
async* processStreamingMessage(
|
async *processStreamingMessage(
|
||||||
request: MessageCreateParams,
|
request: MessageCreateParams,
|
||||||
provider: Provider
|
provider: Provider
|
||||||
): AsyncIterable<RawMessageStreamEvent> {
|
): AsyncIterable<RawMessageStreamEvent> {
|
||||||
|
|||||||
@ -20,4 +20,3 @@ export class AgentModelValidationError extends Error {
|
|||||||
this.detail = detail
|
this.detail = detail
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,16 @@
|
|||||||
import {loggerService} from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import type {AgentSessionMessageEntity, CreateSessionMessageRequest, GetAgentSessionResponse, ListOptions} from '@types'
|
import type {
|
||||||
import {TextStreamPart} from 'ai'
|
AgentSessionMessageEntity,
|
||||||
import {desc, eq} from 'drizzle-orm'
|
CreateSessionMessageRequest,
|
||||||
|
GetAgentSessionResponse,
|
||||||
|
ListOptions
|
||||||
|
} from '@types'
|
||||||
|
import { TextStreamPart } from 'ai'
|
||||||
|
import { desc, eq } from 'drizzle-orm'
|
||||||
|
|
||||||
import {BaseService} from '../BaseService'
|
import { BaseService } from '../BaseService'
|
||||||
import {sessionMessagesTable} from '../database/schema'
|
import { sessionMessagesTable } from '../database/schema'
|
||||||
import {AgentStreamEvent} from '../interfaces/AgentStreamInterface'
|
import { AgentStreamEvent } from '../interfaces/AgentStreamInterface'
|
||||||
import ClaudeCodeService from './claudecode'
|
import ClaudeCodeService from './claudecode'
|
||||||
|
|
||||||
const logger = loggerService.withContext('SessionMessageService')
|
const logger = loggerService.withContext('SessionMessageService')
|
||||||
@ -29,7 +34,7 @@ function serializeError(error: unknown): { message: string; name?: string; stack
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (typeof error === 'string') {
|
if (typeof error === 'string') {
|
||||||
return {message: error}
|
return { message: error }
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -99,7 +104,7 @@ export class SessionMessageService extends BaseService {
|
|||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
const result = await this.database
|
const result = await this.database
|
||||||
.select({id: sessionMessagesTable.id})
|
.select({ id: sessionMessagesTable.id })
|
||||||
.from(sessionMessagesTable)
|
.from(sessionMessagesTable)
|
||||||
.where(eq(sessionMessagesTable.id, id))
|
.where(eq(sessionMessagesTable.id, id))
|
||||||
.limit(1)
|
.limit(1)
|
||||||
@ -129,7 +134,7 @@ export class SessionMessageService extends BaseService {
|
|||||||
|
|
||||||
const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[]
|
const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[]
|
||||||
|
|
||||||
return {messages}
|
return { messages }
|
||||||
}
|
}
|
||||||
|
|
||||||
async createSessionMessage(
|
async createSessionMessage(
|
||||||
@ -148,11 +153,11 @@ export class SessionMessageService extends BaseService {
|
|||||||
abortController: AbortController
|
abortController: AbortController
|
||||||
): Promise<SessionStreamResult> {
|
): Promise<SessionStreamResult> {
|
||||||
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||||
logger.debug('Session Message stream message data:', {message: req, session_id: agentSessionId})
|
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
||||||
|
|
||||||
if (session.agent_type !== 'claude-code') {
|
if (session.agent_type !== 'claude-code') {
|
||||||
// TODO: Implement support for other agent types
|
// TODO: Implement support for other agent types
|
||||||
logger.error('Unsupported agent type for streaming:', {agent_type: session.agent_type})
|
logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type })
|
||||||
throw new Error('Unsupported agent type for streaming')
|
throw new Error('Unsupported agent type for streaming')
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,7 +248,7 @@ export class SessionMessageService extends BaseService {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
return {stream, completion}
|
return { stream, completion }
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
||||||
@ -251,7 +256,7 @@ export class SessionMessageService extends BaseService {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await this.database
|
const result = await this.database
|
||||||
.select({agent_session_id: sessionMessagesTable.agent_session_id})
|
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
||||||
.from(sessionMessagesTable)
|
.from(sessionMessagesTable)
|
||||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||||
.orderBy(desc(sessionMessagesTable.created_at))
|
.orderBy(desc(sessionMessagesTable.created_at))
|
||||||
@ -270,7 +275,7 @@ export class SessionMessageService extends BaseService {
|
|||||||
private deserializeSessionMessage(data: any): AgentSessionMessageEntity {
|
private deserializeSessionMessage(data: any): AgentSessionMessageEntity {
|
||||||
if (!data) return data
|
if (!data) return data
|
||||||
|
|
||||||
const deserialized = {...data}
|
const deserialized = { ...data }
|
||||||
|
|
||||||
// Parse content JSON
|
// Parse content JSON
|
||||||
if (deserialized.content && typeof deserialized.content === 'string') {
|
if (deserialized.content && typeof deserialized.content === 'string') {
|
||||||
|
|||||||
@ -119,7 +119,7 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
|
|||||||
toolCallId: block.tool_use_id,
|
toolCallId: block.tool_use_id,
|
||||||
toolName: '',
|
toolName: '',
|
||||||
input: '',
|
input: '',
|
||||||
output: block.content,
|
output: block.content
|
||||||
})
|
})
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
@ -244,17 +244,18 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }
|
|||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
|
||||||
case 'content_block_stop': {
|
case 'content_block_stop':
|
||||||
const contentBlock = contentBlockState.get(blockKey)
|
{
|
||||||
if (contentBlock?.type === 'text') {
|
const contentBlock = contentBlockState.get(blockKey)
|
||||||
chunks.push({
|
if (contentBlock?.type === 'text') {
|
||||||
type: 'text-end',
|
chunks.push({
|
||||||
id: String(event.index)
|
type: 'text-end',
|
||||||
})
|
id: String(event.index)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
contentBlockState.delete(blockKey)
|
||||||
}
|
}
|
||||||
contentBlockState.delete(blockKey)
|
break
|
||||||
}
|
|
||||||
break
|
|
||||||
case 'message_delta':
|
case 'message_delta':
|
||||||
// Handle usage updates or other message-level deltas
|
// Handle usage updates or other message-level deltas
|
||||||
break
|
break
|
||||||
@ -304,9 +305,7 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
|||||||
usage = {
|
usage = {
|
||||||
inputTokens: message.usage.input_tokens ?? 0,
|
inputTokens: message.usage.input_tokens ?? 0,
|
||||||
outputTokens: message.usage.output_tokens ?? 0,
|
outputTokens: message.usage.output_tokens ?? 0,
|
||||||
totalTokens:
|
totalTokens: (message.usage.input_tokens ?? 0) + (message.usage.output_tokens ?? 0)
|
||||||
(message.usage.input_tokens ?? 0) +
|
|
||||||
(message.usage.output_tokens ?? 0)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (message.subtype === 'success') {
|
if (message.subtype === 'success') {
|
||||||
|
|||||||
@ -22,14 +22,14 @@ import {
|
|||||||
WebSearchToolResultBlockParam,
|
WebSearchToolResultBlockParam,
|
||||||
WebSearchToolResultError
|
WebSearchToolResultError
|
||||||
} from '@anthropic-ai/sdk/resources/messages'
|
} from '@anthropic-ai/sdk/resources/messages'
|
||||||
import {MessageStream} from '@anthropic-ai/sdk/resources/messages/messages'
|
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||||
import {loggerService} from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import {DEFAULT_MAX_TOKENS} from '@renderer/config/constant'
|
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||||
import {findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel} from '@renderer/config/models'
|
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||||
import {getAssistantSettings} from '@renderer/services/AssistantService'
|
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||||
import FileManager from '@renderer/services/FileManager'
|
import FileManager from '@renderer/services/FileManager'
|
||||||
import {estimateTextTokens} from '@renderer/services/TokenService'
|
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||||
import {
|
import {
|
||||||
Assistant,
|
Assistant,
|
||||||
EFFORT_RATIO,
|
EFFORT_RATIO,
|
||||||
@ -53,17 +53,27 @@ import {
|
|||||||
ThinkingDeltaChunk,
|
ThinkingDeltaChunk,
|
||||||
ThinkingStartChunk
|
ThinkingStartChunk
|
||||||
} from '@renderer/types/chunk'
|
} from '@renderer/types/chunk'
|
||||||
import {type Message} from '@renderer/types/newMessage'
|
import { type Message } from '@renderer/types/newMessage'
|
||||||
import {AnthropicSdkMessageParam, AnthropicSdkParams, AnthropicSdkRawChunk, AnthropicSdkRawOutput} from '@renderer/types/sdk'
|
import {
|
||||||
import {addImageFileToContents} from '@renderer/utils/formats'
|
AnthropicSdkMessageParam,
|
||||||
import {anthropicToolUseToMcpTool, isSupportedToolUse, mcpToolCallResponseToAnthropicMessage, mcpToolsToAnthropicTools} from '@renderer/utils/mcp-tools'
|
AnthropicSdkParams,
|
||||||
import {findFileBlocks, findImageBlocks} from '@renderer/utils/messageUtils/find'
|
AnthropicSdkRawChunk,
|
||||||
import {buildClaudeCodeSystemMessage, getSdkClient} from "@shared/anthropic";
|
AnthropicSdkRawOutput
|
||||||
import {t} from 'i18next'
|
} from '@renderer/types/sdk'
|
||||||
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
|
import {
|
||||||
|
anthropicToolUseToMcpTool,
|
||||||
|
isSupportedToolUse,
|
||||||
|
mcpToolCallResponseToAnthropicMessage,
|
||||||
|
mcpToolsToAnthropicTools
|
||||||
|
} from '@renderer/utils/mcp-tools'
|
||||||
|
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||||
|
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
||||||
|
import { t } from 'i18next'
|
||||||
|
|
||||||
import {GenericChunk} from '../../middleware/schemas'
|
import { GenericChunk } from '../../middleware/schemas'
|
||||||
import {BaseApiClient} from '../BaseApiClient'
|
import { BaseApiClient } from '../BaseApiClient'
|
||||||
import {AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer} from '../types'
|
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||||
|
|
||||||
const logger = loggerService.withContext('AnthropicAPIClient')
|
const logger = loggerService.withContext('AnthropicAPIClient')
|
||||||
|
|
||||||
@ -105,7 +115,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
if (payload.stream) {
|
if (payload.stream) {
|
||||||
return sdk.messages.stream(payload, options)
|
return sdk.messages.stream(payload, options)
|
||||||
}
|
}
|
||||||
return sdk.messages.create(payload, options);
|
return sdk.messages.create(payload, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @ts-ignore sdk未提供
|
// @ts-ignore sdk未提供
|
||||||
@ -149,7 +159,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
if (!isReasoningModel(model)) {
|
if (!isReasoningModel(model)) {
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
const {maxTokens} = getAssistantSettings(assistant)
|
const { maxTokens } = getAssistantSettings(assistant)
|
||||||
|
|
||||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||||
|
|
||||||
@ -166,7 +176,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
Math.floor(
|
Math.floor(
|
||||||
Math.min(
|
Math.min(
|
||||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||||
findTokenLimit(model.id)?.min!,
|
findTokenLimit(model.id)?.min!,
|
||||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -188,7 +198,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
* @returns The message parameter
|
* @returns The message parameter
|
||||||
*/
|
*/
|
||||||
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
|
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
|
||||||
const {textContent, imageContents} = await this.getMessageContent(message)
|
const { textContent, imageContents } = await this.getMessageContent(message)
|
||||||
|
|
||||||
const parts: MessageParam['content'] = [
|
const parts: MessageParam['content'] = [
|
||||||
{
|
{
|
||||||
@ -211,7 +221,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
logger.warn('Unsupported image type, ignored.', {mime: base64Data.mime})
|
logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -236,7 +246,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
// Get and process file blocks
|
// Get and process file blocks
|
||||||
const fileBlocks = findFileBlocks(message)
|
const fileBlocks = findFileBlocks(message)
|
||||||
for (const fileBlock of fileBlocks) {
|
for (const fileBlock of fileBlocks) {
|
||||||
const {file} = fileBlock
|
const { file } = fileBlock
|
||||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||||
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
|
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
|
||||||
const base64Data = await FileManager.readBase64File(file)
|
const base64Data = await FileManager.readBase64File(file)
|
||||||
@ -464,25 +474,25 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
messages: AnthropicSdkMessageParam[]
|
messages: AnthropicSdkMessageParam[]
|
||||||
metadata: Record<string, any>
|
metadata: Record<string, any>
|
||||||
}> => {
|
}> => {
|
||||||
const {messages, mcpTools, maxTokens, streamOutput, enableWebSearch} = coreRequest
|
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
|
||||||
// 1. 处理系统消息
|
// 1. 处理系统消息
|
||||||
const systemPrompt = assistant.prompt
|
const systemPrompt = assistant.prompt
|
||||||
|
|
||||||
// 2. 设置工具
|
// 2. 设置工具
|
||||||
const {tools} = this.setupToolsConfig({
|
const { tools } = this.setupToolsConfig({
|
||||||
mcpTools: mcpTools,
|
mcpTools: mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isSupportedToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
const systemMessage: TextBlockParam | undefined = systemPrompt
|
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||||
? {type: 'text', text: systemPrompt}
|
? { type: 'text', text: systemPrompt }
|
||||||
: undefined
|
: undefined
|
||||||
|
|
||||||
// 3. 处理用户消息
|
// 3. 处理用户消息
|
||||||
const sdkMessages: AnthropicSdkMessageParam[] = []
|
const sdkMessages: AnthropicSdkMessageParam[] = []
|
||||||
if (typeof messages === 'string') {
|
if (typeof messages === 'string') {
|
||||||
sdkMessages.push({role: 'user', content: messages})
|
sdkMessages.push({ role: 'user', content: messages })
|
||||||
} else {
|
} else {
|
||||||
const processedMessages = addImageFileToContents(messages)
|
const processedMessages = addImageFileToContents(messages)
|
||||||
for (const message of processedMessages) {
|
for (const message of processedMessages) {
|
||||||
@ -516,7 +526,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
}
|
}
|
||||||
|
|
||||||
const timeout = this.getTimeout(model)
|
const timeout = this.getTimeout(model)
|
||||||
return {payload: commonParams, messages: sdkMessages, metadata: {timeout}}
|
return { payload: commonParams, messages: sdkMessages, metadata: { timeout } }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -531,7 +541,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
try {
|
try {
|
||||||
rawChunk = JSON.parse(rawChunk)
|
rawChunk = JSON.parse(rawChunk)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('invalid chunk', {rawChunk, error})
|
logger.error('invalid chunk', { rawChunk, error })
|
||||||
throw new Error(t('error.chat.chunk.non_json'))
|
throw new Error(t('error.chat.chunk.non_json'))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { useAppDispatch } from '@renderer/store'
|
import { useAppDispatch } from '@renderer/store'
|
||||||
import { removeManyBlocks,upsertManyBlocks } from '@renderer/store/messageBlock'
|
import { removeManyBlocks, upsertManyBlocks } from '@renderer/store/messageBlock'
|
||||||
import { newMessagesActions } from '@renderer/store/newMessage'
|
import { newMessagesActions } from '@renderer/store/newMessage'
|
||||||
import { AgentPersistedMessage, UpdateSessionForm } from '@renderer/types'
|
import { AgentPersistedMessage, UpdateSessionForm } from '@renderer/types'
|
||||||
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import { classNames } from '@renderer/utils'
|
|||||||
import { Flex } from 'antd'
|
import { Flex } from 'antd'
|
||||||
import { debounce } from 'lodash'
|
import { debounce } from 'lodash'
|
||||||
import { AnimatePresence, motion } from 'motion/react'
|
import { AnimatePresence, motion } from 'motion/react'
|
||||||
import React, { FC, useMemo,useState } from 'react'
|
import React, { FC, useMemo, useState } from 'react'
|
||||||
import { useHotkeys } from 'react-hotkeys-hook'
|
import { useHotkeys } from 'react-hotkeys-hook'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import styled from 'styled-components'
|
import styled from 'styled-components'
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import { useAppSelector } from '@renderer/store'
|
|||||||
import { selectMessagesForTopic } from '@renderer/store/newMessage'
|
import { selectMessagesForTopic } from '@renderer/store/newMessage'
|
||||||
import { Topic } from '@renderer/types'
|
import { Topic } from '@renderer/types'
|
||||||
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||||
import { memo,useMemo } from 'react'
|
import { memo, useMemo } from 'react'
|
||||||
import styled from 'styled-components'
|
import styled from 'styled-components'
|
||||||
|
|
||||||
import MessageGroup from './MessageGroup'
|
import MessageGroup from './MessageGroup'
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import type {ProviderMetadata} from "ai";
|
import type { ProviderMetadata } from 'ai'
|
||||||
import type { CompletionUsage } from 'openai/resources'
|
import type { CompletionUsage } from 'openai/resources'
|
||||||
|
|
||||||
import type {
|
import type {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user