mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-11 16:39:15 +08:00
refactor: message processing to unify streaming and non-streaming handling
This commit is contained in:
parent
3b44392e5a
commit
b52afe075f
@ -3,7 +3,7 @@ import express from 'express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import type { ExtendedChatCompletionCreateParams } from '../adapters'
|
||||
import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
|
||||
import { processMessage } from '../services/ProxyStreamService'
|
||||
import { validateModelId } from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerChatRoutes')
|
||||
@ -205,38 +205,15 @@ router.post('/completions', async (req: Request, res: Response) => {
|
||||
|
||||
const provider = modelValidation.provider!
|
||||
const modelId = modelValidation.modelId!
|
||||
const isStreaming = !!request.stream
|
||||
|
||||
if (isStreaming) {
|
||||
try {
|
||||
await streamToResponse({
|
||||
response: res,
|
||||
provider,
|
||||
modelId,
|
||||
params: request,
|
||||
inputFormat: 'openai',
|
||||
outputFormat: 'openai'
|
||||
})
|
||||
} catch (streamError) {
|
||||
logger.error('Stream error', { error: streamError })
|
||||
// If headers weren't sent yet, return JSON error
|
||||
if (!res.headersSent) {
|
||||
const { status, body } = mapChatCompletionError(streamError)
|
||||
return res.status(status).json(body)
|
||||
}
|
||||
// Otherwise the error is already handled by streamToResponse
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const response = await generateMessage({
|
||||
return processMessage({
|
||||
response: res,
|
||||
provider,
|
||||
modelId,
|
||||
params: request,
|
||||
inputFormat: 'openai',
|
||||
outputFormat: 'openai'
|
||||
})
|
||||
return res.json(response)
|
||||
} catch (error: unknown) {
|
||||
const { status, body } = mapChatCompletionError(error)
|
||||
return res.status(status).json(body)
|
||||
|
||||
@ -8,7 +8,7 @@ import express from 'express'
|
||||
import { approximateTokenSize } from 'tokenx'
|
||||
|
||||
import { messagesService } from '../services/messages'
|
||||
import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
|
||||
import { processMessage } from '../services/ProxyStreamService'
|
||||
import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'
|
||||
|
||||
/**
|
||||
@ -321,29 +321,19 @@ async function handleUnifiedProcessing({
|
||||
providerId: provider.id
|
||||
})
|
||||
|
||||
if (request.stream) {
|
||||
await streamToResponse({
|
||||
response: res,
|
||||
provider,
|
||||
modelId: actualModelId,
|
||||
params: request,
|
||||
middlewares,
|
||||
onError: (error) => {
|
||||
logger.error('Stream error', error as Error)
|
||||
},
|
||||
onComplete: () => {
|
||||
logger.debug('Stream completed')
|
||||
}
|
||||
})
|
||||
} else {
|
||||
const response = await generateMessage({
|
||||
provider,
|
||||
modelId: actualModelId,
|
||||
params: request,
|
||||
middlewares
|
||||
})
|
||||
res.json(response)
|
||||
}
|
||||
await processMessage({
|
||||
response: res,
|
||||
provider,
|
||||
modelId: actualModelId,
|
||||
params: request,
|
||||
middlewares,
|
||||
onError: (error) => {
|
||||
logger.error('Message error', error as Error)
|
||||
},
|
||||
onComplete: () => {
|
||||
logger.debug('Message completed')
|
||||
}
|
||||
})
|
||||
} catch (error: any) {
|
||||
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||
res.status(statusCode).json(errorResponse)
|
||||
|
||||
@ -44,10 +44,6 @@ initializeSharedProviders({
|
||||
error: (message, error) => logger.error(message, error)
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Configuration Interfaces
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Middleware type alias
|
||||
*/
|
||||
@ -59,9 +55,9 @@ type LanguageModelMiddleware = LanguageModelV2Middleware
|
||||
type InputParams = InputParamsMap[InputFormat]
|
||||
|
||||
/**
|
||||
* Configuration for streaming message requests
|
||||
* Configuration for message requests (both streaming and non-streaming)
|
||||
*/
|
||||
export interface StreamConfig {
|
||||
export interface MessageConfig {
|
||||
response: Response
|
||||
provider: Provider
|
||||
modelId: string
|
||||
@ -74,19 +70,6 @@ export interface StreamConfig {
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for non-streaming message generation
|
||||
*/
|
||||
export interface GenerateConfig {
|
||||
provider: Provider
|
||||
modelId: string
|
||||
params: InputParams
|
||||
inputFormat?: InputFormat
|
||||
outputFormat?: OutputFormat
|
||||
middlewares?: LanguageModelMiddleware[]
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal configuration for stream execution
|
||||
*/
|
||||
@ -304,27 +287,14 @@ async function executeStream(config: ExecuteStreamConfig): Promise<{
|
||||
return { adapter, outputStream }
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Public API
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Stream a message request and write to HTTP response
|
||||
* Process a message request - handles both streaming and non-streaming
|
||||
*
|
||||
* Uses TransformStream-based adapters for efficient streaming.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* await streamToResponse({
|
||||
* response: res,
|
||||
* provider,
|
||||
* modelId: 'claude-3-opus',
|
||||
* params: messageCreateParams,
|
||||
* outputFormat: 'anthropic'
|
||||
* })
|
||||
* ```
|
||||
* Automatically detects streaming mode from params.stream:
|
||||
* - stream=true: SSE streaming response
|
||||
* - stream=false: JSON response
|
||||
*/
|
||||
export async function streamToResponse(config: StreamConfig): Promise<void> {
|
||||
export async function processMessage(config: MessageConfig): Promise<void> {
|
||||
const {
|
||||
response,
|
||||
provider,
|
||||
@ -338,7 +308,9 @@ export async function streamToResponse(config: StreamConfig): Promise<void> {
|
||||
plugins = []
|
||||
} = config
|
||||
|
||||
logger.info('Starting proxy stream', {
|
||||
const isStreaming = 'stream' in params && params.stream === true
|
||||
|
||||
logger.info(`Starting ${isStreaming ? 'streaming' : 'non-streaming'} message`, {
|
||||
providerId: provider.id,
|
||||
providerType: provider.type,
|
||||
modelId,
|
||||
@ -348,112 +320,21 @@ export async function streamToResponse(config: StreamConfig): Promise<void> {
|
||||
pluginCount: plugins.length
|
||||
})
|
||||
|
||||
// Create abort controller for client disconnect handling
|
||||
const streamController = createStreamAbortController({
|
||||
timeoutMs: LONG_POLL_TIMEOUT_MS
|
||||
})
|
||||
// Create abort controller with timeout
|
||||
const streamController = createStreamAbortController({ timeoutMs: LONG_POLL_TIMEOUT_MS })
|
||||
const { abortController, dispose } = streamController
|
||||
|
||||
// Handle client disconnect
|
||||
const handleDisconnect = () => {
|
||||
if (abortController.signal.aborted) return
|
||||
logger.info('Client disconnected, aborting stream', { providerId: provider.id, modelId })
|
||||
logger.info('Client disconnected, aborting', { providerId: provider.id, modelId })
|
||||
abortController.abort('Client disconnected')
|
||||
}
|
||||
|
||||
response.on('close', handleDisconnect)
|
||||
|
||||
try {
|
||||
// Set SSE headers
|
||||
response.setHeader('Content-Type', 'text/event-stream')
|
||||
response.setHeader('Cache-Control', 'no-cache')
|
||||
response.setHeader('Connection', 'keep-alive')
|
||||
response.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
const { outputStream } = await executeStream({
|
||||
provider,
|
||||
modelId,
|
||||
params,
|
||||
inputFormat,
|
||||
outputFormat,
|
||||
middlewares,
|
||||
plugins,
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
|
||||
// Get formatter for the output format
|
||||
const formatter = StreamAdapterFactory.getFormatter(outputFormat)
|
||||
|
||||
// Stream events to response
|
||||
const reader = outputStream.getReader()
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
if (response.writableEnded) break
|
||||
response.write(formatter.formatEvent(value))
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
|
||||
// Send done marker and end response
|
||||
if (!response.writableEnded) {
|
||||
response.write(formatter.formatDone())
|
||||
response.end()
|
||||
}
|
||||
|
||||
logger.info('Proxy stream completed', { providerId: provider.id, modelId })
|
||||
onComplete?.()
|
||||
} catch (error) {
|
||||
logger.error('Error in proxy stream', error as Error, { providerId: provider.id, modelId })
|
||||
onError?.(error)
|
||||
throw error
|
||||
} finally {
|
||||
response.off('close', handleDisconnect)
|
||||
dispose()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a non-streaming message response
|
||||
*
|
||||
* Uses simulateStreamingMiddleware to reuse the same streaming logic.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const message = await generateMessage({
|
||||
* provider,
|
||||
* modelId: 'claude-3-opus',
|
||||
* params: messageCreateParams,
|
||||
* outputFormat: 'anthropic'
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
export async function generateMessage(config: GenerateConfig): Promise<unknown> {
|
||||
const {
|
||||
provider,
|
||||
modelId,
|
||||
params,
|
||||
inputFormat = 'anthropic',
|
||||
outputFormat = 'anthropic',
|
||||
middlewares = [],
|
||||
plugins = []
|
||||
} = config
|
||||
|
||||
logger.info('Starting message generation', {
|
||||
providerId: provider.id,
|
||||
providerType: provider.type,
|
||||
modelId,
|
||||
inputFormat,
|
||||
outputFormat,
|
||||
middlewareCount: middlewares.length,
|
||||
pluginCount: plugins.length
|
||||
})
|
||||
|
||||
try {
|
||||
// Add simulateStreamingMiddleware to reuse streaming logic
|
||||
const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
|
||||
// For non-streaming, add simulateStreamingMiddleware
|
||||
const allMiddlewares = isStreaming ? middlewares : [simulateStreamingMiddleware(), ...middlewares]
|
||||
|
||||
const { adapter, outputStream } = await executeStream({
|
||||
provider,
|
||||
@ -462,30 +343,60 @@ export async function generateMessage(config: GenerateConfig): Promise<unknown>
|
||||
inputFormat,
|
||||
outputFormat,
|
||||
middlewares: allMiddlewares,
|
||||
plugins
|
||||
plugins,
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
|
||||
// Consume the stream to populate adapter state
|
||||
const reader = outputStream.getReader()
|
||||
while (true) {
|
||||
const { done } = await reader.read()
|
||||
if (done) break
|
||||
if (isStreaming) {
|
||||
// Streaming: Set SSE headers and stream events
|
||||
response.setHeader('Content-Type', 'text/event-stream')
|
||||
response.setHeader('Cache-Control', 'no-cache')
|
||||
response.setHeader('Connection', 'keep-alive')
|
||||
response.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
const formatter = StreamAdapterFactory.getFormatter(outputFormat)
|
||||
const reader = outputStream.getReader()
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
if (response.writableEnded) break
|
||||
response.write(formatter.formatEvent(value))
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
|
||||
if (!response.writableEnded) {
|
||||
response.write(formatter.formatDone())
|
||||
response.end()
|
||||
}
|
||||
} else {
|
||||
// Non-streaming: Consume stream and return JSON
|
||||
const reader = outputStream.getReader()
|
||||
while (true) {
|
||||
const { done } = await reader.read()
|
||||
if (done) break
|
||||
}
|
||||
reader.releaseLock()
|
||||
|
||||
const finalResponse = adapter.buildNonStreamingResponse()
|
||||
response.json(finalResponse)
|
||||
}
|
||||
reader.releaseLock()
|
||||
|
||||
// Build final response from adapter
|
||||
const finalResponse = adapter.buildNonStreamingResponse()
|
||||
|
||||
logger.info('Message generation completed', { providerId: provider.id, modelId })
|
||||
|
||||
return finalResponse
|
||||
logger.info('Message completed', { providerId: provider.id, modelId, streaming: isStreaming })
|
||||
onComplete?.()
|
||||
} catch (error) {
|
||||
logger.error('Error in message generation', error as Error, { providerId: provider.id, modelId })
|
||||
logger.error('Error in message processing', error as Error, { providerId: provider.id, modelId })
|
||||
onError?.(error)
|
||||
throw error
|
||||
} finally {
|
||||
response.off('close', handleDisconnect)
|
||||
dispose()
|
||||
}
|
||||
}
|
||||
|
||||
export default {
|
||||
streamToResponse,
|
||||
generateMessage
|
||||
processMessage
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user