diff --git a/src/main/apiServer/routes/chat.ts b/src/main/apiServer/routes/chat.ts
index 999ad36312..4756075e9b 100644
--- a/src/main/apiServer/routes/chat.ts
+++ b/src/main/apiServer/routes/chat.ts
@@ -3,7 +3,7 @@ import express from 'express'
 
 import { loggerService } from '../../services/LoggerService'
 import type { ExtendedChatCompletionCreateParams } from '../adapters'
-import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
+import { processMessage } from '../services/ProxyStreamService'
 import { validateModelId } from '../utils'
 
 const logger = loggerService.withContext('ApiServerChatRoutes')
@@ -205,38 +205,15 @@ router.post('/completions', async (req: Request, res: Response) => {
     const provider = modelValidation.provider!
     const modelId = modelValidation.modelId!
 
-    const isStreaming = !!request.stream
-    if (isStreaming) {
-      try {
-        await streamToResponse({
-          response: res,
-          provider,
-          modelId,
-          params: request,
-          inputFormat: 'openai',
-          outputFormat: 'openai'
-        })
-      } catch (streamError) {
-        logger.error('Stream error', { error: streamError })
-        // If headers weren't sent yet, return JSON error
-        if (!res.headersSent) {
-          const { status, body } = mapChatCompletionError(streamError)
-          return res.status(status).json(body)
-        }
-        // Otherwise the error is already handled by streamToResponse
-      }
-      return
-    }
-
-    const response = await generateMessage({
+    return processMessage({
+      response: res,
       provider,
       modelId,
       params: request,
      inputFormat: 'openai',
       outputFormat: 'openai'
     })
-    return res.json(response)
   } catch (error: unknown) {
     const { status, body } = mapChatCompletionError(error)
     return res.status(status).json(body)
diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts
index 652d4f46bb..94171c6b3d 100644
--- a/src/main/apiServer/routes/messages.ts
+++ b/src/main/apiServer/routes/messages.ts
@@ -8,7 +8,7 @@ import express from 'express'
 import { approximateTokenSize } from 'tokenx'
 
 import { messagesService } from '../services/messages'
-import { generateMessage, streamToResponse } from '../services/ProxyStreamService'
+import { processMessage } from '../services/ProxyStreamService'
 import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'
 
 /**
@@ -321,29 +321,19 @@ async function handleUnifiedProcessing({
       providerId: provider.id
     })
 
-    if (request.stream) {
-      await streamToResponse({
-        response: res,
-        provider,
-        modelId: actualModelId,
-        params: request,
-        middlewares,
-        onError: (error) => {
-          logger.error('Stream error', error as Error)
-        },
-        onComplete: () => {
-          logger.debug('Stream completed')
-        }
-      })
-    } else {
-      const response = await generateMessage({
-        provider,
-        modelId: actualModelId,
-        params: request,
-        middlewares
-      })
-      res.json(response)
-    }
+    await processMessage({
+      response: res,
+      provider,
+      modelId: actualModelId,
+      params: request,
+      middlewares,
+      onError: (error) => {
+        logger.error('Message error', error as Error)
+      },
+      onComplete: () => {
+        logger.debug('Message completed')
+      }
+    })
   } catch (error: any) {
     const { statusCode, errorResponse } = messagesService.transformError(error)
     res.status(statusCode).json(errorResponse)
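Both route handlers now delegate the stream/JSON decision to the service layer. A rough sketch of the new call-site shape, using the identifiers from the diff; the `Provider` import path is not shown in the patch and is a placeholder here:

```typescript
import type { Response } from 'express'

import type { ExtendedChatCompletionCreateParams } from '../adapters'
import { processMessage } from '../services/ProxyStreamService'
// NOTE: Provider's real import path is not visible in this diff; placeholder only.
import type { Provider } from '../types'

// A handler no longer branches on request.stream itself: processMessage reads
// the flag from params and picks SSE streaming or a buffered JSON reply.
async function handleCompletion(
  res: Response,
  provider: Provider,
  modelId: string,
  request: ExtendedChatCompletionCreateParams
): Promise<void> {
  await processMessage({
    response: res,
    provider,
    modelId,
    params: request,
    inputFormat: 'openai',
    outputFormat: 'openai'
  })
}
```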
diff --git a/src/main/apiServer/services/ProxyStreamService.ts b/src/main/apiServer/services/ProxyStreamService.ts
index 90149a617e..d5387c9f23 100644
--- a/src/main/apiServer/services/ProxyStreamService.ts
+++ b/src/main/apiServer/services/ProxyStreamService.ts
@@ -44,10 +44,6 @@ initializeSharedProviders({
   error: (message, error) => logger.error(message, error)
 })
 
-// ============================================================================
-// Configuration Interfaces
-// ============================================================================
-
 /**
  * Middleware type alias
  */
@@ -59,9 +55,9 @@ type LanguageModelMiddleware = LanguageModelV2Middleware
 type InputParams = InputParamsMap[InputFormat]
 
 /**
- * Configuration for streaming message requests
+ * Configuration for message requests (both streaming and non-streaming)
  */
-export interface StreamConfig {
+export interface MessageConfig {
   response: Response
   provider: Provider
   modelId: string
@@ -74,19 +70,6 @@ export interface StreamConfig {
   plugins?: AiPlugin[]
 }
 
-/**
- * Configuration for non-streaming message generation
- */
-export interface GenerateConfig {
-  provider: Provider
-  modelId: string
-  params: InputParams
-  inputFormat?: InputFormat
-  outputFormat?: OutputFormat
-  middlewares?: LanguageModelMiddleware[]
-  plugins?: AiPlugin[]
-}
-
 /**
  * Internal configuration for stream execution
  */
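With `StreamConfig` and `GenerateConfig` merged into `MessageConfig`, the only mode signal left is the `stream` flag inside `params`. A minimal sketch of the detection the service now applies; the two interfaces below are stand-ins for the real `InputParamsMap` entries:

```typescript
// Stand-in request shapes; the service's real types come from InputParamsMap.
interface OpenAIChatParams {
  messages: unknown[]
  stream?: boolean
}

interface AnthropicMessageParams {
  messages: unknown[]
  stream?: boolean
}

type InputParams = OpenAIChatParams | AnthropicMessageParams

// Mirrors the check in processMessage: the 'in' guard stays safe for param
// variants that omit the optional flag, and `=== true` rejects truthy
// non-boolean values a caller might pass through.
function isStreamingRequest(params: InputParams): boolean {
  return 'stream' in params && params.stream === true
}

console.log(isStreamingRequest({ messages: [] })) // false
console.log(isStreamingRequest({ messages: [], stream: true })) // true
```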
@@ -304,27 +287,14 @@ async function executeStream(config: ExecuteStreamConfig): Promise<{
   return { adapter, outputStream }
 }
 
-// ============================================================================
-// Public API
-// ============================================================================
-
 /**
- * Stream a message request and write to HTTP response
+ * Process a message request - handles both streaming and non-streaming
  *
- * Uses TransformStream-based adapters for efficient streaming.
- *
- * @example
- * ```typescript
- * await streamToResponse({
- *   response: res,
- *   provider,
- *   modelId: 'claude-3-opus',
- *   params: messageCreateParams,
- *   outputFormat: 'anthropic'
- * })
- * ```
+ * Automatically detects streaming mode from params.stream:
+ * - stream=true: SSE streaming response
+ * - stream=false: JSON response
  */
-export async function streamToResponse(config: StreamConfig): Promise<void> {
+export async function processMessage(config: MessageConfig): Promise<void> {
   const {
     response,
     provider,
@@ -338,7 +308,9 @@ export async function streamToResponse(config: StreamConfig): Promise<void> {
     plugins = []
   } = config
 
-  logger.info('Starting proxy stream', {
+  const isStreaming = 'stream' in params && params.stream === true
+
+  logger.info(`Starting ${isStreaming ? 'streaming' : 'non-streaming'} message`, {
     providerId: provider.id,
     providerType: provider.type,
     modelId,
@@ -348,112 +320,21 @@ export async function streamToResponse(config: StreamConfig): Promise<void> {
     pluginCount: plugins.length
   })
 
-  // Create abort controller for client disconnect handling
-  const streamController = createStreamAbortController({
-    timeoutMs: LONG_POLL_TIMEOUT_MS
-  })
+  // Create abort controller with timeout
+  const streamController = createStreamAbortController({ timeoutMs: LONG_POLL_TIMEOUT_MS })
   const { abortController, dispose } = streamController
 
-  // Handle client disconnect
   const handleDisconnect = () => {
     if (abortController.signal.aborted) return
-    logger.info('Client disconnected, aborting stream', { providerId: provider.id, modelId })
+    logger.info('Client disconnected, aborting', { providerId: provider.id, modelId })
     abortController.abort('Client disconnected')
   }
   response.on('close', handleDisconnect)
 
   try {
-    // Set SSE headers
-    response.setHeader('Content-Type', 'text/event-stream')
-    response.setHeader('Cache-Control', 'no-cache')
-    response.setHeader('Connection', 'keep-alive')
-    response.setHeader('X-Accel-Buffering', 'no')
-
-    const { outputStream } = await executeStream({
-      provider,
-      modelId,
-      params,
-      inputFormat,
-      outputFormat,
-      middlewares,
-      plugins,
-      abortSignal: abortController.signal
-    })
-
-    // Get formatter for the output format
-    const formatter = StreamAdapterFactory.getFormatter(outputFormat)
-
-    // Stream events to response
-    const reader = outputStream.getReader()
-    try {
-      while (true) {
-        const { done, value } = await reader.read()
-        if (done) break
-        if (response.writableEnded) break
-        response.write(formatter.formatEvent(value))
-      }
-    } finally {
-      reader.releaseLock()
-    }
-
-    // Send done marker and end response
-    if (!response.writableEnded) {
-      response.write(formatter.formatDone())
-      response.end()
-    }
-
-    logger.info('Proxy stream completed', { providerId: provider.id, modelId })
-    onComplete?.()
-  } catch (error) {
-    logger.error('Error in proxy stream', error as Error, { providerId: provider.id, modelId })
-    onError?.(error)
-    throw error
-  } finally {
-    response.off('close', handleDisconnect)
-    dispose()
-  }
-}
-
-/**
- * Generate a non-streaming message response
- *
- * Uses simulateStreamingMiddleware to reuse the same streaming logic.
- *
- * @example
- * ```typescript
- * const message = await generateMessage({
- *   provider,
- *   modelId: 'claude-3-opus',
- *   params: messageCreateParams,
- *   outputFormat: 'anthropic'
- * })
- * ```
- */
-export async function generateMessage(config: GenerateConfig): Promise<unknown> {
-  const {
-    provider,
-    modelId,
-    params,
-    inputFormat = 'anthropic',
-    outputFormat = 'anthropic',
-    middlewares = [],
-    plugins = []
-  } = config
-
-  logger.info('Starting message generation', {
-    providerId: provider.id,
-    providerType: provider.type,
-    modelId,
-    inputFormat,
-    outputFormat,
-    middlewareCount: middlewares.length,
-    pluginCount: plugins.length
-  })
-
-  try {
-    // Add simulateStreamingMiddleware to reuse streaming logic
-    const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
+    // For non-streaming, add simulateStreamingMiddleware
+    const allMiddlewares = isStreaming ? middlewares : [simulateStreamingMiddleware(), ...middlewares]
 
     const { adapter, outputStream } = await executeStream({
       provider,
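The non-streaming path keeps riding the streaming pipeline by prepending `simulateStreamingMiddleware`. The service's imports are not visible in this diff; assuming they come from the Vercel AI SDK (which exports identifiers with these exact names, and whose `LanguageModelV2Middleware` type the service aliases), the mechanism looks roughly like this:

```typescript
import { simulateStreamingMiddleware, streamText, wrapLanguageModel } from 'ai'
import { openai } from '@ai-sdk/openai'

// simulateStreamingMiddleware makes doStream() issue one non-streaming
// doGenerate() call and replay the finished result as a stream, so the same
// stream-consuming adapter code can also serve buffered JSON requests.
const model = wrapLanguageModel({
  model: openai('gpt-4o-mini'), // illustrative model id
  middleware: simulateStreamingMiddleware()
})

async function main(): Promise<void> {
  const result = streamText({ model, prompt: 'Say hello.' })

  // Consuming the "stream" here actually drains a single buffered response.
  for await (const chunk of result.textStream) {
    process.stdout.write(chunk)
  }
}

main().catch(console.error)
```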
@@ -462,30 +343,60 @@ export async function generateMessage(config: GenerateConfig): Promise<unknown>
       inputFormat,
       outputFormat,
       middlewares: allMiddlewares,
-      plugins
+      plugins,
+      abortSignal: abortController.signal
     })
 
-    // Consume the stream to populate adapter state
-    const reader = outputStream.getReader()
-    while (true) {
-      const { done } = await reader.read()
-      if (done) break
+    if (isStreaming) {
+      // Streaming: Set SSE headers and stream events
+      response.setHeader('Content-Type', 'text/event-stream')
+      response.setHeader('Cache-Control', 'no-cache')
+      response.setHeader('Connection', 'keep-alive')
+      response.setHeader('X-Accel-Buffering', 'no')
+
+      const formatter = StreamAdapterFactory.getFormatter(outputFormat)
+      const reader = outputStream.getReader()
+
+      try {
+        while (true) {
+          const { done, value } = await reader.read()
+          if (done) break
+          if (response.writableEnded) break
+          response.write(formatter.formatEvent(value))
+        }
+      } finally {
+        reader.releaseLock()
+      }
+
+      if (!response.writableEnded) {
+        response.write(formatter.formatDone())
+        response.end()
+      }
+    } else {
+      // Non-streaming: Consume stream and return JSON
+      const reader = outputStream.getReader()
+      while (true) {
+        const { done } = await reader.read()
+        if (done) break
+      }
+      reader.releaseLock()
+
+      const finalResponse = adapter.buildNonStreamingResponse()
+      response.json(finalResponse)
     }
-    reader.releaseLock()
 
-    // Build final response from adapter
-    const finalResponse = adapter.buildNonStreamingResponse()
-
-    logger.info('Message generation completed', { providerId: provider.id, modelId })
-
-    return finalResponse
+    logger.info('Message completed', { providerId: provider.id, modelId, streaming: isStreaming })
+    onComplete?.()
   } catch (error) {
-    logger.error('Error in message generation', error as Error, { providerId: provider.id, modelId })
+    logger.error('Error in message processing', error as Error, { providerId: provider.id, modelId })
+    onError?.(error)
     throw error
+  } finally {
+    response.off('close', handleDisconnect)
+    dispose()
   }
 }
 
 export default {
-  streamToResponse,
-  generateMessage
+  processMessage
 }
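From a client's perspective the contract is unchanged: the same endpoint answers with SSE frames or a single JSON body depending on the `stream` flag. A rough consumer sketch; the base URL, port, mount path, and model id are placeholders, and the exact done marker depends on the output format's formatter:

```typescript
const BASE_URL = 'http://localhost:3000/v1' // placeholder address and mount path

// stream omitted: processMessage drains the internal stream and replies with JSON.
async function requestJson(): Promise<void> {
  const res = await fetch(`${BASE_URL}/chat/completions`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: 'my-provider:my-model',
      messages: [{ role: 'user', content: 'Hi' }]
    })
  })
  console.log(await res.json())
}

// stream: true — processMessage sets SSE headers and writes formatted events.
async function requestStream(): Promise<void> {
  const res = await fetch(`${BASE_URL}/chat/completions`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: 'my-provider:my-model',
      stream: true,
      messages: [{ role: 'user', content: 'Hi' }]
    })
  })
  const reader = res.body!.getReader()
  const decoder = new TextDecoder()
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    process.stdout.write(decoder.decode(value)) // raw "data: ..." SSE frames
  }
}

requestJson().then(requestStream).catch(console.error)
```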