diff --git a/src/main/apiServer/services/ProxyStreamService.ts b/src/main/apiServer/services/ProxyStreamService.ts index 5519dc49ef..90149a617e 100644 --- a/src/main/apiServer/services/ProxyStreamService.ts +++ b/src/main/apiServer/services/ProxyStreamService.ts @@ -33,6 +33,8 @@ import type { Response } from 'express' import type { InputFormat, InputParamsMap, IStreamAdapter } from '../adapters' import { MessageConverterFactory, type OutputFormat, StreamAdapterFactory } from '../adapters' +import { LONG_POLL_TIMEOUT_MS } from '../config/timeouts' +import { createStreamAbortController } from '../utils/createStreamAbortController' import { googleReasoningCache, openRouterReasoningCache } from './reasoning-cache' const logger = loggerService.withContext('ProxyStreamService') @@ -96,6 +98,7 @@ interface ExecuteStreamConfig { outputFormat: OutputFormat middlewares?: LanguageModelMiddleware[] plugins?: AiPlugin[] + abortSignal?: AbortSignal } // ============================================================================ @@ -248,7 +251,7 @@ async function executeStream(config: ExecuteStreamConfig): Promise<{ adapter: IStreamAdapter outputStream: ReadableStream }> { - const { provider, modelId, params, inputFormat, outputFormat, middlewares = [], plugins = [] } = config + const { provider, modelId, params, inputFormat, outputFormat, middlewares = [], plugins = [], abortSignal } = config // Convert provider config to AI SDK config let sdkConfig = providerToAiSdkConfig(provider, modelId) @@ -291,7 +294,8 @@ async function executeStream(config: ExecuteStreamConfig): Promise<{ stopWhen: stepCountIs(100), headers: defaultAppHeaders(), tools, - providerOptions + providerOptions, + abortSignal }) // Transform stream using adapter @@ -344,6 +348,21 @@ export async function streamToResponse(config: StreamConfig): Promise { pluginCount: plugins.length }) + // Create abort controller for client disconnect handling + const streamController = createStreamAbortController({ + timeoutMs: LONG_POLL_TIMEOUT_MS + }) + const { abortController, dispose } = streamController + + // Handle client disconnect + const handleDisconnect = () => { + if (abortController.signal.aborted) return + logger.info('Client disconnected, aborting stream', { providerId: provider.id, modelId }) + abortController.abort('Client disconnected') + } + + response.on('close', handleDisconnect) + try { // Set SSE headers response.setHeader('Content-Type', 'text/event-stream') @@ -358,7 +377,8 @@ export async function streamToResponse(config: StreamConfig): Promise { inputFormat, outputFormat, middlewares, - plugins + plugins, + abortSignal: abortController.signal }) // Get formatter for the output format @@ -370,6 +390,7 @@ export async function streamToResponse(config: StreamConfig): Promise { while (true) { const { done, value } = await reader.read() if (done) break + if (response.writableEnded) break response.write(formatter.formatEvent(value)) } } finally { @@ -377,8 +398,10 @@ export async function streamToResponse(config: StreamConfig): Promise { } // Send done marker and end response - response.write(formatter.formatDone()) - response.end() + if (!response.writableEnded) { + response.write(formatter.formatDone()) + response.end() + } logger.info('Proxy stream completed', { providerId: provider.id, modelId }) onComplete?.() @@ -386,6 +409,9 @@ export async function streamToResponse(config: StreamConfig): Promise { logger.error('Error in proxy stream', error as Error, { providerId: provider.id, modelId }) onError?.(error) throw error + } finally { + response.off('close', handleDisconnect) + dispose() } }