Mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2025-12-20 23:22:05 +08:00)
feat(chat): enhance chat completion error handling and streaming support

feat(messages): improve message validation and add streaming support
feat(tests): add HTTP tests for chat and message endpoints

parent 73380d76df
commit 2cf2f04a70
@@ -1,15 +1,105 @@
 import express, { Request, Response } from 'express'
-import OpenAI from 'openai'
 import { ChatCompletionCreateParams } from 'openai/resources'
 
 import { loggerService } from '../../services/LoggerService'
-import { chatCompletionService } from '../services/chat-completion'
-import { validateModelId } from '../utils'
+import {
+  ChatCompletionModelError,
+  chatCompletionService,
+  ChatCompletionValidationError
+} from '../services/chat-completion'
 
 const logger = loggerService.withContext('ApiServerChatRoutes')
 
 const router = express.Router()
 
+interface ErrorResponseBody {
+  error: {
+    message: string
+    type: string
+    code: string
+  }
+}
+
+const mapChatCompletionError = (error: unknown): { status: number; body: ErrorResponseBody } => {
+  if (error instanceof ChatCompletionValidationError) {
+    logger.warn('Chat completion validation error:', {
+      errors: error.errors
+    })
+
+    return {
+      status: 400,
+      body: {
+        error: {
+          message: error.errors.join('; '),
+          type: 'invalid_request_error',
+          code: 'validation_failed'
+        }
+      }
+    }
+  }
+
+  if (error instanceof ChatCompletionModelError) {
+    logger.warn('Chat completion model error:', error.error)
+
+    return {
+      status: 400,
+      body: {
+        error: {
+          message: error.error.message,
+          type: 'invalid_request_error',
+          code: error.error.code
+        }
+      }
+    }
+  }
+
+  if (error instanceof Error) {
+    let statusCode = 500
+    let errorType = 'server_error'
+    let errorCode = 'internal_error'
+
+    if (error.message.includes('API key') || error.message.includes('authentication')) {
+      statusCode = 401
+      errorType = 'authentication_error'
+      errorCode = 'invalid_api_key'
+    } else if (error.message.includes('rate limit') || error.message.includes('quota')) {
+      statusCode = 429
+      errorType = 'rate_limit_error'
+      errorCode = 'rate_limit_exceeded'
+    } else if (error.message.includes('timeout') || error.message.includes('connection')) {
+      statusCode = 502
+      errorType = 'server_error'
+      errorCode = 'upstream_error'
+    }
+
+    logger.error('Chat completion error:', { error })
+
+    return {
+      status: statusCode,
+      body: {
+        error: {
+          message: error.message || 'Internal server error',
+          type: errorType,
+          code: errorCode
+        }
+      }
+    }
+  }
+
+  logger.error('Chat completion unknown error:', { error })
+
+  return {
+    status: 500,
+    body: {
+      error: {
+        message: 'Internal server error',
+        type: 'server_error',
+        code: 'internal_error'
+      }
+    }
+  }
+}
+
 /**
  * @swagger
  * /v1/chat/completions:
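Note: every failure path in mapChatCompletionError above collapses to the same OpenAI-style envelope; only status, type, and code vary. An illustrative TypeScript sketch of the two 400-level bodies (values are made up, not captured from a real response):

const validationFailure: ErrorResponseBody = {
  error: {
    message: 'Messages array is required',
    type: 'invalid_request_error',
    code: 'validation_failed'
  }
}

const modelFailure: ErrorResponseBody = {
  error: {
    message: "Model 'foo:bar' not found", // hypothetical ModelValidationError message
    type: 'invalid_request_error',
    code: 'model_not_found' // hypothetical ModelValidationError code
  }
}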
@@ -60,7 +150,7 @@ const router = express.Router()
 *               type: integer
 *             total_tokens:
 *               type: integer
-*       text/plain:
+*       text/event-stream:
 *         schema:
 *           type: string
 *           description: Server-sent events stream (when stream=true)
@@ -110,63 +200,22 @@ router.post('/completions', async (req: Request, res: Response) => {
       temperature: request.temperature
     })
 
-    // Validate request
-    const validation = chatCompletionService.validateRequest(request)
-    if (!validation.isValid) {
-      return res.status(400).json({
-        error: {
-          message: validation.errors.join('; '),
-          type: 'invalid_request_error',
-          code: 'validation_failed'
-        }
-      })
-    }
-
-    // Validate model ID and get provider
-    const modelValidation = await validateModelId(request.model)
-    if (!modelValidation.valid) {
-      const error = modelValidation.error!
-      logger.warn(`Model validation failed for '${request.model}':`, error)
-      return res.status(400).json({
-        error: {
-          message: error.message,
-          type: 'invalid_request_error',
-          code: error.code
-        }
-      })
-    }
-
-    const provider = modelValidation.provider!
-    const modelId = modelValidation.modelId!
-
-    logger.info('Model validation successful:', {
-      provider: provider.id,
-      providerType: provider.type,
-      modelId: modelId,
-      fullModelId: request.model
-    })
-
-    // Create OpenAI client
-    const client = new OpenAI({
-      baseURL: provider.apiHost,
-      apiKey: provider.apiKey
-    })
-    request.model = modelId
-
-    // Handle streaming
-    if (request.stream) {
-      const streamResponse = await client.chat.completions.create(request)
-
-      res.setHeader('Content-Type', 'text/plain; charset=utf-8')
-      res.setHeader('Cache-Control', 'no-cache')
+    const isStreaming = !!request.stream
+
+    if (isStreaming) {
+      const { stream } = await chatCompletionService.processStreamingCompletion(request)
+
+      res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
+      res.setHeader('Cache-Control', 'no-cache, no-transform')
       res.setHeader('Connection', 'keep-alive')
+      res.setHeader('X-Accel-Buffering', 'no')
+      res.flushHeaders()
 
       try {
-        for await (const chunk of streamResponse as any) {
+        for await (const chunk of stream) {
           res.write(`data: ${JSON.stringify(chunk)}\n\n`)
         }
         res.write('data: [DONE]\n\n')
-        res.end()
       } catch (streamError: any) {
         logger.error('Stream error:', streamError)
         res.write(
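With the Content-Type fixed to text/event-stream, clients can consume the endpoint as standard SSE. A minimal client-side sketch (hedged: host and token mirror the placeholder values from tests/apis/chat.http added below; the parsing assumes one data: line per event, which is what this route emits):

// Runs under Node 18+ or a browser; <token> is a placeholder.
async function streamChat(): Promise<void> {
  const res = await fetch('http://localhost:23333/v1/chat/completions', {
    method: 'POST',
    headers: { Authorization: 'Bearer <token>', 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: 'tokenflux:openai/gpt-5-nano',
      stream: true,
      messages: [{ role: 'user', content: 'Hello' }]
    })
  })
  const reader = res.body!.getReader()
  const decoder = new TextDecoder()
  let buffer = ''
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    buffer += decoder.decode(value, { stream: true })
    let idx: number
    while ((idx = buffer.indexOf('\n\n')) !== -1) {
      // Each SSE event ends with a blank line.
      const data = buffer.slice(0, idx).replace(/^data: /, '')
      buffer = buffer.slice(idx + 2)
      if (data === '[DONE]') return
      process.stdout.write(JSON.parse(data).choices?.[0]?.delta?.content ?? '')
    }
  }
}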
@@ -178,47 +227,17 @@ router.post('/completions', async (req: Request, res: Response) => {
             }
           })}\n\n`
         )
+      } finally {
         res.end()
       }
       return
     }
 
-    // Handle non-streaming
-    const response = await client.chat.completions.create(request)
+    const { response } = await chatCompletionService.processCompletion(request)
     return res.json(response)
-  } catch (error: any) {
-    logger.error('Chat completion error:', error)
-
-    let statusCode = 500
-    let errorType = 'server_error'
-    let errorCode = 'internal_error'
-    let errorMessage = 'Internal server error'
-
-    if (error instanceof Error) {
-      errorMessage = error.message
-
-      if (error.message.includes('API key') || error.message.includes('authentication')) {
-        statusCode = 401
-        errorType = 'authentication_error'
-        errorCode = 'invalid_api_key'
-      } else if (error.message.includes('rate limit') || error.message.includes('quota')) {
-        statusCode = 429
-        errorType = 'rate_limit_error'
-        errorCode = 'rate_limit_exceeded'
-      } else if (error.message.includes('timeout') || error.message.includes('connection')) {
-        statusCode = 502
-        errorType = 'server_error'
-        errorCode = 'upstream_error'
-      }
-    }
-
-    return res.status(statusCode).json({
-      error: {
-        message: errorMessage,
-        type: errorType,
-        code: errorCode
-      }
-    })
+  } catch (error: unknown) {
+    const { status, body } = mapChatCompletionError(error)
+    return res.status(status).json(body)
   }
 })
@@ -104,7 +104,7 @@ const router = express.Router()
 *               type: integer
 *             output_tokens:
 *               type: integer
-*       text/plain:
+*       text/event-stream:
 *         schema:
 *           type: string
 *           description: Server-sent events stream (when stream=true)
@@ -154,18 +154,6 @@ router.post('/', async (req: Request, res: Response) => {
       temperature: request.temperature
     })
 
-    // Validate request
-    const validation = messagesService.validateRequest(request)
-    if (!validation.isValid) {
-      return res.status(400).json({
-        type: 'error',
-        error: {
-          type: 'invalid_request_error',
-          message: validation.errors.join('; ')
-        }
-      })
-    }
-
     // Validate model ID and get provider
     const modelValidation = await validateModelId(request.model)
     if (!modelValidation.valid) {
@@ -203,18 +191,31 @@ router.post('/', async (req: Request, res: Response) => {
       fullModelId: request.model
     })
 
+    // Validate request
+    const validation = messagesService.validateRequest(request)
+    if (!validation.isValid) {
+      return res.status(400).json({
+        type: 'error',
+        error: {
+          type: 'invalid_request_error',
+          message: validation.errors.join('; ')
+        }
+      })
+    }
+
     // Handle streaming
     if (request.stream) {
-      res.setHeader('Content-Type', 'text/plain; charset=utf-8')
-      res.setHeader('Cache-Control', 'no-cache')
+      res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
+      res.setHeader('Cache-Control', 'no-cache, no-transform')
       res.setHeader('Connection', 'keep-alive')
+      res.setHeader('X-Accel-Buffering', 'no')
+      res.flushHeaders()
 
       try {
        for await (const chunk of messagesService.processStreamingMessage(request, provider)) {
           res.write(`data: ${JSON.stringify(chunk)}\n\n`)
         }
         res.write('data: [DONE]\n\n')
-        res.end()
       } catch (streamError: any) {
         logger.error('Stream error:', streamError)
         res.write(
@@ -226,6 +227,7 @@ router.post('/', async (req: Request, res: Response) => {
             }
           })}\n\n`
         )
+      } finally {
         res.end()
       }
       return
@@ -241,9 +243,24 @@ router.post('/', async (req: Request, res: Response) => {
     let errorType = 'api_error'
     let errorMessage = 'Internal server error'
 
-    if (error instanceof Error) {
-      errorMessage = error.message
+    const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
+    const anthropicError = error?.error
+
+    if (anthropicStatus) {
+      statusCode = anthropicStatus
+    }
+
+    if (anthropicError?.type) {
+      errorType = anthropicError.type
+    }
+
+    if (anthropicError?.message) {
+      errorMessage = anthropicError.message
+    } else if (error instanceof Error && error.message) {
+      errorMessage = error.message
+    }
 
+    if (!anthropicStatus && error instanceof Error) {
       if (error.message.includes('API key') || error.message.includes('authentication')) {
         statusCode = 401
         errorType = 'authentication_error'
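The error?.status, error?.error, and (below) error?.request_id probes target the shape of errors thrown by the Anthropic SDK client. A hedged sketch of the fields being read (modelled on @anthropic-ai/sdk's APIError; treat the exact structure as an assumption to verify against the SDK version in use):

// Type-only sketch; not the SDK's actual declaration.
interface AnthropicApiErrorLike {
  status?: number // HTTP status reported by the upstream API
  request_id?: string // surfaced back to the client for debugging
  error?: {
    type?: string // e.g. 'invalid_request_error'
    message?: string
  }
}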
@@ -263,7 +280,8 @@ router.post('/', async (req: Request, res: Response) => {
       type: 'error',
       error: {
         type: errorType,
-        message: errorMessage
+        message: errorMessage,
+        requestId: error?.request_id
       }
     })
   }
@@ -1,8 +1,9 @@
+import { Provider } from '@types'
 import OpenAI from 'openai'
-import { ChatCompletionCreateParams } from 'openai/resources'
+import { ChatCompletionCreateParams, ChatCompletionCreateParamsStreaming } from 'openai/resources'
 
 import { loggerService } from '../../services/LoggerService'
-import { getProviderByModel, getRealProviderModel, validateProvider } from '../utils'
+import { ModelValidationError, validateModelId } from '../utils'
 
 const logger = loggerService.withContext('ChatCompletionService')
 
@@ -11,19 +12,120 @@ export interface ValidationResult {
   errors: string[]
 }
 
+export class ChatCompletionValidationError extends Error {
+  constructor(public readonly errors: string[]) {
+    super(`Request validation failed: ${errors.join('; ')}`)
+    this.name = 'ChatCompletionValidationError'
+  }
+}
+
+export class ChatCompletionModelError extends Error {
+  constructor(public readonly error: ModelValidationError) {
+    super(`Model validation failed: ${error.message}`)
+    this.name = 'ChatCompletionModelError'
+  }
+}
+
+export type PrepareRequestResult =
+  | { status: 'validation_error'; errors: string[] }
+  | { status: 'model_error'; error: ModelValidationError }
+  | {
+      status: 'ok'
+      provider: Provider
+      modelId: string
+      client: OpenAI
+      providerRequest: ChatCompletionCreateParams
+    }
+
 export class ChatCompletionService {
+  async resolveProviderContext(model: string): Promise<
+    | { ok: false; error: ModelValidationError }
+    | { ok: true; provider: Provider; modelId: string; client: OpenAI }
+  > {
+    const modelValidation = await validateModelId(model)
+    if (!modelValidation.valid) {
+      return {
+        ok: false,
+        error: modelValidation.error!
+      }
+    }
+
+    const provider = modelValidation.provider!
+
+    if (provider.type !== 'openai') {
+      return {
+        ok: false,
+        error: {
+          type: 'unsupported_provider_type',
+          message: `Provider '${provider.id}' of type '${provider.type}' is not supported for OpenAI chat completions`,
+          code: 'unsupported_provider_type'
+        }
+      }
+    }
+
+    const modelId = modelValidation.modelId!
+
+    const client = new OpenAI({
+      baseURL: provider.apiHost,
+      apiKey: provider.apiKey
+    })
+
+    return {
+      ok: true,
+      provider,
+      modelId,
+      client
+    }
+  }
+
+  async prepareRequest(request: ChatCompletionCreateParams, stream: boolean): Promise<PrepareRequestResult> {
+    const requestValidation = this.validateRequest(request)
+    if (!requestValidation.isValid) {
+      return {
+        status: 'validation_error',
+        errors: requestValidation.errors
+      }
+    }
+
+    const providerContext = await this.resolveProviderContext(request.model!)
+    if (!providerContext.ok) {
+      return {
+        status: 'model_error',
+        error: providerContext.error
+      }
+    }
+
+    const { provider, modelId, client } = providerContext
+
+    logger.info('Model validation successful:', {
+      provider: provider.id,
+      providerType: provider.type,
+      modelId,
+      fullModelId: request.model
+    })
+
+    return {
+      status: 'ok',
+      provider,
+      modelId,
+      client,
+      providerRequest: stream
+        ? {
+            ...request,
+            model: modelId,
+            stream: true as const
+          }
+        : {
+            ...request,
+            model: modelId,
+            stream: false as const
+          }
+    }
+  }
+
   validateRequest(request: ChatCompletionCreateParams): ValidationResult {
     const errors: string[] = []
 
-    // Validate model
-    if (!request.model) {
-      errors.push('Model is required')
-    } else if (typeof request.model !== 'string') {
-      errors.push('Model must be a string')
-    } else if (!request.model.includes(':')) {
-      errors.push('Model must be in format "provider:model_id"')
-    }
-
     // Validate messages
     if (!request.messages) {
       errors.push('Messages array is required')
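prepareRequest returns a discriminated union rather than throwing, so callers narrow on status and decide how to surface each outcome. A hedged usage sketch (assumes the exports above; this mirrors how processCompletion and processStreamingCompletion consume it further down):

async function demo(service: ChatCompletionService, request: ChatCompletionCreateParams) {
  const prep = await service.prepareRequest(request, false)
  if (prep.status === 'validation_error') {
    throw new ChatCompletionValidationError(prep.errors)
  }
  if (prep.status === 'model_error') {
    throw new ChatCompletionModelError(prep.error)
  }
  // Narrowed to the 'ok' branch: provider, modelId, client and providerRequest are available.
  return prep.client.chat.completions.create(prep.providerRequest)
}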
@@ -51,7 +153,11 @@ export class ChatCompletionService {
     }
   }
 
-  async processCompletion(request: ChatCompletionCreateParams): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+  async processCompletion(request: ChatCompletionCreateParams): Promise<{
+    provider: Provider
+    modelId: string
+    response: OpenAI.Chat.Completions.ChatCompletion
+  }> {
     try {
       logger.info('Processing chat completion request:', {
         model: request.model,
@@ -59,38 +165,16 @@ export class ChatCompletionService {
         stream: request.stream
       })
 
-      // Validate request
-      const validation = this.validateRequest(request)
-      if (!validation.isValid) {
-        throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
-      }
-
-      // Get provider for the model
-      const provider = await getProviderByModel(request.model!)
-      if (!provider) {
-        throw new Error(`Provider not found for model: ${request.model}`)
-      }
-
-      // Validate provider
-      if (!validateProvider(provider)) {
-        throw new Error(`Provider validation failed for: ${provider.id}`)
-      }
-
-      // Extract model ID from the full model string
-      const modelId = getRealProviderModel(request.model)
-
-      // Create OpenAI client for the provider
-      const client = new OpenAI({
-        baseURL: provider.apiHost,
-        apiKey: provider.apiKey
-      })
-
-      // Prepare request with the actual model ID
-      const providerRequest = {
-        ...request,
-        model: modelId,
-        stream: false
-      }
+      const preparation = await this.prepareRequest(request, false)
+      if (preparation.status === 'validation_error') {
+        throw new ChatCompletionValidationError(preparation.errors)
+      }
+      if (preparation.status === 'model_error') {
+        throw new ChatCompletionModelError(preparation.error)
+      }
+
+      const { provider, modelId, client, providerRequest } = preparation
 
       logger.debug('Sending request to provider:', {
         provider: provider.id,
@@ -101,54 +185,40 @@ export class ChatCompletionService {
       const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion
 
       logger.info('Successfully processed chat completion')
-      return response
+      return {
+        provider,
+        modelId,
+        response
+      }
     } catch (error: any) {
       logger.error('Error processing chat completion:', error)
       throw error
     }
   }
 
-  async *processStreamingCompletion(
+  async processStreamingCompletion(
     request: ChatCompletionCreateParams
-  ): AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> {
+  ): Promise<{
+    provider: Provider
+    modelId: string
+    stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
+  }> {
     try {
       logger.info('Processing streaming chat completion request:', {
         model: request.model,
         messageCount: request.messages.length
       })
 
-      // Validate request
-      const validation = this.validateRequest(request)
-      if (!validation.isValid) {
-        throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
-      }
-
-      // Get provider for the model
-      const provider = await getProviderByModel(request.model!)
-      if (!provider) {
-        throw new Error(`Provider not found for model: ${request.model}`)
-      }
-
-      // Validate provider
-      if (!validateProvider(provider)) {
-        throw new Error(`Provider validation failed for: ${provider.id}`)
-      }
-
-      // Extract model ID from the full model string
-      const modelId = getRealProviderModel(request.model)
-
-      // Create OpenAI client for the provider
-      const client = new OpenAI({
-        baseURL: provider.apiHost,
-        apiKey: provider.apiKey
-      })
-
-      // Prepare streaming request
-      const streamingRequest = {
-        ...request,
-        model: modelId,
-        stream: true as const
-      }
+      const preparation = await this.prepareRequest(request, true)
+      if (preparation.status === 'validation_error') {
+        throw new ChatCompletionValidationError(preparation.errors)
+      }
+      if (preparation.status === 'model_error') {
+        throw new ChatCompletionModelError(preparation.error)
+      }
+
+      const { provider, modelId, client, providerRequest } = preparation
 
       logger.debug('Sending streaming request to provider:', {
         provider: provider.id,
@@ -156,13 +226,17 @@ export class ChatCompletionService {
         apiHost: provider.apiHost
       })
 
-      const stream = await client.chat.completions.create(streamingRequest)
+      const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming
+      const stream = (await client.chat.completions.create(streamRequest)) as AsyncIterable<
+        OpenAI.Chat.Completions.ChatCompletionChunk
+      >
 
-      for await (const chunk of stream) {
-        yield chunk
+      logger.info('Successfully started streaming chat completion')
+      return {
+        provider,
+        modelId,
+        stream
       }
-
-      logger.info('Successfully completed streaming chat completion')
     } catch (error: any) {
       logger.error('Error processing streaming chat completion:', error)
       throw error
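Since processStreamingCompletion now returns a stream handle instead of yielding chunks itself, iteration moves to the caller (the route's for await loop above). A minimal consumption sketch (assumes the chatCompletionService singleton exported by this module and a Node environment):

async function printCompletion(request: ChatCompletionCreateParams): Promise<void> {
  const { stream } = await chatCompletionService.processStreamingCompletion(request)
  for await (const chunk of stream) {
    process.stdout.write(chunk.choices[0]?.delta?.content ?? '')
  }
}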
@@ -21,6 +21,14 @@ export class MessagesService {
       errors.push('Model is required')
     }
 
+    if (!request.max_tokens || request.max_tokens < 1) {
+      errors.push('max_tokens is required and must be at least 1')
+    }
+
+    if (!request.messages || !Array.isArray(request.messages) || request.messages.length === 0) {
+      errors.push('messages is required and must be a non-empty array')
+    }
+
     return {
       isValid: errors.length === 0,
       errors
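To illustrate the tightened validation: a request that omits both fields now fails with both new errors (sketch; messagesService is the singleton this module exports, and the cast only keeps the fragment short):

const result = messagesService.validateRequest({
  model: 'anthropic:claude-sonnet-4-20250514'
} as MessageCreateParams)
// result.isValid === false
// result.errors === [
//   'max_tokens is required and must be at least 1',
//   'messages is required and must be a non-empty array'
// ]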
@@ -28,7 +36,6 @@ export class MessagesService {
     }
   }
 
   async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> {
-    try {
     logger.info('Processing Anthropic message request:', {
       model: request.model,
       messageCount: request.messages.length,
@@ -36,12 +43,6 @@ export class MessagesService {
       max_tokens: request.max_tokens
     })
 
-    // Validate request
-    const validation = this.validateRequest(request)
-    if (!validation.isValid) {
-      throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
-    }
-
     // Create Anthropic client for the provider
     const client = new Anthropic({
       baseURL: provider.apiHost,
@@ -63,28 +64,17 @@ export class MessagesService {
 
     logger.info('Successfully processed Anthropic message')
     return response
-    } catch (error: any) {
-      logger.error('Error processing Anthropic message:', error)
-      throw error
-    }
   }
 
   async *processStreamingMessage(
     request: MessageCreateParams,
     provider: Provider
   ): AsyncIterable<RawMessageStreamEvent> {
-    try {
     logger.info('Processing streaming Anthropic message request:', {
       model: request.model,
       messageCount: request.messages.length
     })
 
-    // Validate request
-    const validation = this.validateRequest(request)
-    if (!validation.isValid) {
-      throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
-    }
-
     // Create Anthropic client for the provider
     const client = new Anthropic({
       baseURL: provider.apiHost,
@@ -109,10 +99,6 @@ export class MessagesService {
     }
 
     logger.info('Successfully completed streaming Anthropic message')
-    } catch (error: any) {
-      logger.error('Error processing streaming Anthropic message:', error)
-      throw error
-    }
   }
 }
 
tests/apis/chat.http (new file, 79 lines)
@@ -0,0 +1,79 @@
+@host=http://localhost:23333
+@token=cs-sk-af798ed4-7cf5-4fd7-ae4b-df203b164194
+@agent_id=agent_1758092281575_tn9dxio9k
+
+
+### List All Models
+GET {{host}}/v1/models
+Authorization: Bearer {{token}}
+
+
+### List Models With Filters
+GET {{host}}/v1/models?provider=anthropic&limit=5
+Authorization: Bearer {{token}}
+
+
+### OpenAI Chat Completion
+POST {{host}}/v1/chat/completions
+Authorization: Bearer {{token}}
+Content-Type: application/json
+
+{
+  "model": "tokenflux:openai/gpt-5-nano",
+  "messages": [
+    {
+      "role": "user",
+      "content": "Explain the theory of relativity in simple terms."
+    }
+  ]
+}
+
+### OpenAI Chat Completion with streaming
+POST {{host}}/v1/chat/completions
+Authorization: Bearer {{token}}
+Content-Type: application/json
+
+{
+  "model": "tokenflux:openai/gpt-5-nano",
+  "stream": true,
+  "messages": [
+    {
+      "role": "user",
+      "content": "Explain the theory of relativity in simple terms."
+    }
+  ]
+}
+
+### Anthropic Chat Message
+POST {{host}}/v1/messages
+Authorization: Bearer {{token}}
+Content-Type: application/json
+
+{
+  "model": "anthropic:claude-sonnet-4-20250514",
+  "stream": false,
+  "max_tokens": 1024,
+  "messages": [
+    {
+      "role": "user",
+      "content": "Explain the theory of relativity in simple terms."
+    }
+  ]
+}
+
+### Anthropic Chat Message with streaming
+POST {{host}}/v1/messages
+Authorization: Bearer {{token}}
+Content-Type: application/json
+
+{
+  "model": "anthropic:claude-sonnet-4-20250514",
+  "stream": true,
+  "max_tokens": 1024,
+  "messages": [
+    {
+      "role": "user",
+      "content": "Explain the theory of relativity in simple terms."
+    }
+  ]
+}