mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-10 07:19:02 +08:00
Merge branch 'feat/agents-new' of github.com:CherryHQ/cherry-studio into feat/agents-new
This commit is contained in:
commit
fcacc50fdc
@ -1,4 +1,5 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
|
import { AgentStreamEvent } from '@main/services/agents/interfaces/AgentStreamInterface'
|
||||||
import { Request, Response } from 'express'
|
import { Request, Response } from 'express'
|
||||||
|
|
||||||
import { agentService, sessionMessageService, sessionService } from '../../../../services/agents'
|
import { agentService, sessionMessageService, sessionService } from '../../../../services/agents'
|
||||||
@ -42,13 +43,12 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
res.setHeader('Access-Control-Allow-Origin', '*')
|
res.setHeader('Access-Control-Allow-Origin', '*')
|
||||||
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
|
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
|
||||||
|
|
||||||
const messageStream = sessionMessageService.createSessionMessage(session, messageData)
|
const abortController = new AbortController()
|
||||||
|
const messageStream = sessionMessageService.createSessionMessage(session, messageData, abortController)
|
||||||
|
|
||||||
// Track stream lifecycle so we keep the SSE connection open until persistence finishes
|
// Track stream lifecycle so we keep the SSE connection open until persistence finishes
|
||||||
let responseEnded = false
|
let responseEnded = false
|
||||||
let streamFinished = false
|
let streamFinished = false
|
||||||
let awaitingPersistence = false
|
|
||||||
let persistenceResolved = false
|
|
||||||
|
|
||||||
const finalizeResponse = () => {
|
const finalizeResponse = () => {
|
||||||
if (responseEnded) {
|
if (responseEnded) {
|
||||||
@ -59,10 +59,6 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (awaitingPersistence && !persistenceResolved) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
responseEnded = true
|
responseEnded = true
|
||||||
try {
|
try {
|
||||||
res.write('data: {"type":"finish"}\n\n')
|
res.write('data: {"type":"finish"}\n\n')
|
||||||
@ -73,15 +69,39 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
res.end()
|
res.end()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle client disconnect
|
/**
|
||||||
req.on('close', () => {
|
* Client Disconnect Detection for Server-Sent Events (SSE)
|
||||||
|
*
|
||||||
|
* We monitor multiple HTTP events to reliably detect when a client disconnects
|
||||||
|
* from the streaming response. This is crucial for:
|
||||||
|
* - Aborting long-running Claude Code processes
|
||||||
|
* - Cleaning up resources and preventing memory leaks
|
||||||
|
* - Avoiding orphaned processes
|
||||||
|
*
|
||||||
|
* Event Priority & Behavior:
|
||||||
|
* 1. res.on('close') - Most common for SSE client disconnects (browser tab close, curl Ctrl+C)
|
||||||
|
* 2. req.on('aborted') - Explicit request abortion
|
||||||
|
* 3. req.on('close') - Request object closure (less common with SSE)
|
||||||
|
*
|
||||||
|
* When any disconnect event fires, we:
|
||||||
|
* - Abort the Claude Code SDK process via abortController
|
||||||
|
* - Clean up event listeners to prevent memory leaks
|
||||||
|
* - Mark the response as ended to prevent further writes
|
||||||
|
*/
|
||||||
|
const handleDisconnect = () => {
|
||||||
|
if (responseEnded) return
|
||||||
logger.info(`Client disconnected from streaming message for session: ${sessionId}`)
|
logger.info(`Client disconnected from streaming message for session: ${sessionId}`)
|
||||||
responseEnded = true
|
responseEnded = true
|
||||||
messageStream.removeAllListeners()
|
messageStream.removeAllListeners()
|
||||||
})
|
abortController.abort('Client disconnected')
|
||||||
|
}
|
||||||
|
|
||||||
|
req.on('close', handleDisconnect)
|
||||||
|
req.on('aborted', handleDisconnect)
|
||||||
|
res.on('close', handleDisconnect)
|
||||||
|
|
||||||
// Handle stream events
|
// Handle stream events
|
||||||
messageStream.on('data', (event: any) => {
|
messageStream.on('data', (event: AgentStreamEvent) => {
|
||||||
if (responseEnded) return
|
if (responseEnded) return
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -101,12 +121,6 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
logger.error(`Streaming message error for session: ${sessionId}:`, event.error)
|
logger.error(`Streaming message error for session: ${sessionId}:`, event.error)
|
||||||
|
|
||||||
streamFinished = true
|
streamFinished = true
|
||||||
awaitingPersistence = Boolean(event.persistScheduled)
|
|
||||||
|
|
||||||
if (!awaitingPersistence) {
|
|
||||||
persistenceResolved = true
|
|
||||||
}
|
|
||||||
|
|
||||||
finalizeResponse()
|
finalizeResponse()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -116,32 +130,22 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
// res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`)
|
// res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`)
|
||||||
|
|
||||||
streamFinished = true
|
streamFinished = true
|
||||||
awaitingPersistence = true
|
|
||||||
finalizeResponse()
|
finalizeResponse()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
case 'persisted':
|
case 'cancelled': {
|
||||||
// Send persistence success event
|
logger.info(`Streaming message cancelled for session: ${sessionId}`)
|
||||||
// res.write(`data: ${JSON.stringify(event)}\n\n`)
|
// res.write(`data: ${JSON.stringify({ type: 'cancelled' })}\n\n`)
|
||||||
logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id })
|
streamFinished = true
|
||||||
|
|
||||||
persistenceResolved = true
|
|
||||||
finalizeResponse()
|
|
||||||
break
|
|
||||||
|
|
||||||
case 'persist-error':
|
|
||||||
// Send persistence error event
|
|
||||||
// res.write(`data: ${JSON.stringify(event)}\n\n`)
|
|
||||||
logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error)
|
|
||||||
|
|
||||||
persistenceResolved = true
|
|
||||||
finalizeResponse()
|
finalizeResponse()
|
||||||
break
|
break
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// Handle other event types as generic data
|
// Handle other event types as generic data
|
||||||
res.write(`data: ${JSON.stringify(event)}\n\n`)
|
logger.info(`Streaming message event for session: ${sessionId}:`, { event })
|
||||||
|
// res.write(`data: ${JSON.stringify(event)}\n\n`)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
} catch (writeError) {
|
} catch (writeError) {
|
||||||
@ -199,8 +203,8 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
|||||||
res.end()
|
res.end()
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
5 * 60 * 1000
|
10 * 60 * 1000
|
||||||
) // 5 minutes timeout
|
) // 10 minutes timeout
|
||||||
|
|
||||||
// Clear timeout when response ends
|
// Clear timeout when response ends
|
||||||
res.on('close', () => clearTimeout(timeout))
|
res.on('close', () => clearTimeout(timeout))
|
||||||
|
|||||||
@ -13,8 +13,7 @@ import { Request, Response } from 'express'
|
|||||||
import { IncomingMessage, ServerResponse } from 'http'
|
import { IncomingMessage, ServerResponse } from 'http'
|
||||||
|
|
||||||
import { loggerService } from '../../services/LoggerService'
|
import { loggerService } from '../../services/LoggerService'
|
||||||
import { reduxService } from '../../services/ReduxService'
|
import { getMcpServerById, getMCPServersFromRedux } from '../utils/mcp'
|
||||||
import { getMcpServerById } from '../utils/mcp'
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('MCPApiService')
|
const logger = loggerService.withContext('MCPApiService')
|
||||||
const transports: Record<string, StreamableHTTPServerTransport> = {}
|
const transports: Record<string, StreamableHTTPServerTransport> = {}
|
||||||
@ -57,34 +56,10 @@ class MCPApiService extends EventEmitter {
|
|||||||
this.transport.onmessage = this.onMessage
|
this.transport.onmessage = this.onMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Get servers directly from Redux store
|
|
||||||
*/
|
|
||||||
private async getServersFromRedux(): Promise<MCPServer[]> {
|
|
||||||
try {
|
|
||||||
logger.silly('Getting servers from Redux store')
|
|
||||||
|
|
||||||
// Try to get from cache first (faster)
|
|
||||||
const cachedServers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
|
|
||||||
if (cachedServers && Array.isArray(cachedServers)) {
|
|
||||||
logger.silly(`Found ${cachedServers.length} servers in Redux cache`)
|
|
||||||
return cachedServers
|
|
||||||
}
|
|
||||||
|
|
||||||
// If cache is not available, get fresh data
|
|
||||||
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
|
|
||||||
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
|
|
||||||
return servers || []
|
|
||||||
} catch (error: any) {
|
|
||||||
logger.error('Failed to get servers from Redux:', error)
|
|
||||||
return []
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// get all activated servers
|
// get all activated servers
|
||||||
async getAllServers(req: Request): Promise<McpServersResp> {
|
async getAllServers(req: Request): Promise<McpServersResp> {
|
||||||
try {
|
try {
|
||||||
const servers = await this.getServersFromRedux()
|
const servers = await getMCPServersFromRedux()
|
||||||
logger.silly(`Returning ${servers.length} servers`)
|
logger.silly(`Returning ${servers.length} servers`)
|
||||||
const resp: McpServersResp = {
|
const resp: McpServersResp = {
|
||||||
servers: {}
|
servers: {}
|
||||||
@ -111,7 +86,7 @@ class MCPApiService extends EventEmitter {
|
|||||||
async getServerById(id: string): Promise<MCPServer | null> {
|
async getServerById(id: string): Promise<MCPServer | null> {
|
||||||
try {
|
try {
|
||||||
logger.silly(`getServerById called with id: ${id}`)
|
logger.silly(`getServerById called with id: ${id}`)
|
||||||
const servers = await this.getServersFromRedux()
|
const servers = await getMCPServersFromRedux()
|
||||||
const server = servers.find((s) => s.id === id)
|
const server = servers.find((s) => s.id === id)
|
||||||
if (!server) {
|
if (!server) {
|
||||||
logger.warn(`Server with id ${id} not found`)
|
logger.warn(`Server with id ${id} not found`)
|
||||||
|
|||||||
@ -1,12 +1,24 @@
|
|||||||
|
import { CacheService } from '@main/services/CacheService'
|
||||||
import { loggerService } from '@main/services/LoggerService'
|
import { loggerService } from '@main/services/LoggerService'
|
||||||
import { reduxService } from '@main/services/ReduxService'
|
import { reduxService } from '@main/services/ReduxService'
|
||||||
import { ApiModel, Model, Provider } from '@types'
|
import { ApiModel, Model, Provider } from '@types'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServerUtils')
|
const logger = loggerService.withContext('ApiServerUtils')
|
||||||
|
|
||||||
|
// Cache configuration
|
||||||
|
const PROVIDERS_CACHE_KEY = 'api-server:providers'
|
||||||
|
const PROVIDERS_CACHE_TTL = 5 * 60 * 1000 // 5 minutes
|
||||||
|
|
||||||
export async function getAvailableProviders(): Promise<Provider[]> {
|
export async function getAvailableProviders(): Promise<Provider[]> {
|
||||||
try {
|
try {
|
||||||
// Wait for store to be ready before accessing providers
|
// Try to get from cache first (faster)
|
||||||
|
const cachedSupportedProviders = CacheService.get<Provider[]>(PROVIDERS_CACHE_KEY)
|
||||||
|
if (cachedSupportedProviders) {
|
||||||
|
logger.debug(`Found ${cachedSupportedProviders.length} supported providers (from cache)`)
|
||||||
|
return cachedSupportedProviders
|
||||||
|
}
|
||||||
|
|
||||||
|
// If cache is not available, get fresh data from Redux
|
||||||
const providers = await reduxService.select('state.llm.providers')
|
const providers = await reduxService.select('state.llm.providers')
|
||||||
if (!providers || !Array.isArray(providers)) {
|
if (!providers || !Array.isArray(providers)) {
|
||||||
logger.warn('No providers found in Redux store, returning empty array')
|
logger.warn('No providers found in Redux store, returning empty array')
|
||||||
@ -18,6 +30,9 @@ export async function getAvailableProviders(): Promise<Provider[]> {
|
|||||||
(p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic')
|
(p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Cache the filtered results
|
||||||
|
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
|
||||||
|
|
||||||
logger.info(`Filtered to ${supportedProviders.length} supported providers from ${providers.length} total providers`)
|
logger.info(`Filtered to ${supportedProviders.length} supported providers from ${providers.length} total providers`)
|
||||||
|
|
||||||
return supportedProviders
|
return supportedProviders
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import { CacheService } from '@main/services/CacheService'
|
||||||
import mcpService from '@main/services/MCPService'
|
import mcpService from '@main/services/MCPService'
|
||||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||||
import { CallToolRequestSchema, ListToolsRequestSchema, ListToolsResult } from '@modelcontextprotocol/sdk/types.js'
|
import { CallToolRequestSchema, ListToolsRequestSchema, ListToolsResult } from '@modelcontextprotocol/sdk/types.js'
|
||||||
@ -8,6 +9,10 @@ import { reduxService } from '../../services/ReduxService'
|
|||||||
|
|
||||||
const logger = loggerService.withContext('MCPApiService')
|
const logger = loggerService.withContext('MCPApiService')
|
||||||
|
|
||||||
|
// Cache configuration
|
||||||
|
const MCP_SERVERS_CACHE_KEY = 'api-server:mcp-servers'
|
||||||
|
const MCP_SERVERS_CACHE_TTL = 5 * 60 * 1000 // 5 minutes
|
||||||
|
|
||||||
const cachedServers: Record<string, Server> = {}
|
const cachedServers: Record<string, Server> = {}
|
||||||
|
|
||||||
async function handleListToolsRequest(request: any, extra: any): Promise<ListToolsResult> {
|
async function handleListToolsRequest(request: any, extra: any): Promise<ListToolsResult> {
|
||||||
@ -33,18 +38,33 @@ async function handleCallToolRequest(request: any, extra: any): Promise<any> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function getMcpServerConfigById(id: string): Promise<MCPServer | undefined> {
|
async function getMcpServerConfigById(id: string): Promise<MCPServer | undefined> {
|
||||||
const servers = await getServersFromRedux()
|
const servers = await getMCPServersFromRedux()
|
||||||
return servers.find((s) => s.id === id || s.name === id)
|
return servers.find((s) => s.id === id || s.name === id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get servers directly from Redux store
|
* Get servers directly from Redux store
|
||||||
*/
|
*/
|
||||||
async function getServersFromRedux(): Promise<MCPServer[]> {
|
export async function getMCPServersFromRedux(): Promise<MCPServer[]> {
|
||||||
try {
|
try {
|
||||||
|
logger.silly('Getting servers from Redux store')
|
||||||
|
|
||||||
|
// Try to get from cache first (faster)
|
||||||
|
const cachedServers = CacheService.get<MCPServer[]>(MCP_SERVERS_CACHE_KEY)
|
||||||
|
if (cachedServers) {
|
||||||
|
logger.silly(`Found ${cachedServers.length} servers (from cache)`)
|
||||||
|
return cachedServers
|
||||||
|
}
|
||||||
|
|
||||||
|
// If cache is not available, get fresh data from Redux
|
||||||
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
|
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
|
||||||
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
|
const serverList = servers || []
|
||||||
return servers || []
|
|
||||||
|
// Cache the results
|
||||||
|
CacheService.set(MCP_SERVERS_CACHE_KEY, serverList, MCP_SERVERS_CACHE_TTL)
|
||||||
|
|
||||||
|
logger.silly(`Fetched ${serverList.length} servers from Redux store`)
|
||||||
|
return serverList
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Failed to get servers from Redux:', error)
|
logger.error('Failed to get servers from Redux:', error)
|
||||||
return []
|
return []
|
||||||
@ -54,7 +74,7 @@ async function getServersFromRedux(): Promise<MCPServer[]> {
|
|||||||
export async function getMcpServerById(id: string): Promise<Server> {
|
export async function getMcpServerById(id: string): Promise<Server> {
|
||||||
const server = cachedServers[id]
|
const server = cachedServers[id]
|
||||||
if (!server) {
|
if (!server) {
|
||||||
const servers = await getServersFromRedux()
|
const servers = await getMCPServersFromRedux()
|
||||||
const mcpServer = servers.find((s) => s.id === id || s.name === id)
|
const mcpServer = servers.find((s) => s.id === id || s.name === id)
|
||||||
if (!mcpServer) {
|
if (!mcpServer) {
|
||||||
throw new Error(`Server not found: ${id}`)
|
throw new Error(`Server not found: ${id}`)
|
||||||
|
|||||||
@ -8,11 +8,9 @@ import { UIMessageChunk } from 'ai'
|
|||||||
|
|
||||||
// Generic agent stream event that works with any agent type
|
// Generic agent stream event that works with any agent type
|
||||||
export interface AgentStreamEvent {
|
export interface AgentStreamEvent {
|
||||||
type: 'chunk' | 'error' | 'complete'
|
type: 'chunk' | 'error' | 'complete' | 'cancelled'
|
||||||
chunk?: UIMessageChunk // Standard AI SDK chunk for UI consumption
|
chunk?: UIMessageChunk // Standard AI SDK chunk for UI consumption
|
||||||
rawAgentMessage?: any // Agent-specific raw message (SDKMessage for Claude Code, different for other agents)
|
|
||||||
error?: Error
|
error?: Error
|
||||||
agentResult?: any // Agent-specific result data
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Agent stream interface that all agents should implement
|
// Agent stream interface that all agents should implement
|
||||||
@ -24,5 +22,10 @@ export interface AgentStream extends EventEmitter {
|
|||||||
|
|
||||||
// Base agent service interface
|
// Base agent service interface
|
||||||
export interface AgentServiceInterface {
|
export interface AgentServiceInterface {
|
||||||
invoke(prompt: string, session: GetAgentSessionResponse, lastAgentSessionId?: string): Promise<AgentStream>
|
invoke(
|
||||||
|
prompt: string,
|
||||||
|
session: GetAgentSessionResponse,
|
||||||
|
abortController: AbortController,
|
||||||
|
lastAgentSessionId?: string
|
||||||
|
): Promise<AgentStream>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -170,14 +170,18 @@ export class SessionMessageService extends BaseService {
|
|||||||
return { messages }
|
return { messages }
|
||||||
}
|
}
|
||||||
|
|
||||||
createSessionMessage(session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest): EventEmitter {
|
createSessionMessage(
|
||||||
|
session: GetAgentSessionResponse,
|
||||||
|
messageData: CreateSessionMessageRequest,
|
||||||
|
abortController: AbortController
|
||||||
|
): EventEmitter {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
// Create a new EventEmitter to manage the session message lifecycle
|
// Create a new EventEmitter to manage the session message lifecycle
|
||||||
const sessionStream = new EventEmitter()
|
const sessionStream = new EventEmitter()
|
||||||
|
|
||||||
// No parent validation needed, start immediately
|
// No parent validation needed, start immediately
|
||||||
this.startSessionMessageStream(session, messageData, sessionStream)
|
this.startSessionMessageStream(session, messageData, sessionStream, abortController)
|
||||||
|
|
||||||
return sessionStream
|
return sessionStream
|
||||||
}
|
}
|
||||||
@ -185,7 +189,8 @@ export class SessionMessageService extends BaseService {
|
|||||||
private async startSessionMessageStream(
|
private async startSessionMessageStream(
|
||||||
session: GetAgentSessionResponse,
|
session: GetAgentSessionResponse,
|
||||||
req: CreateSessionMessageRequest,
|
req: CreateSessionMessageRequest,
|
||||||
sessionStream: EventEmitter
|
sessionStream: EventEmitter,
|
||||||
|
abortController: AbortController
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||||
let newAgentSessionId = ''
|
let newAgentSessionId = ''
|
||||||
@ -198,7 +203,7 @@ export class SessionMessageService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create the streaming agent invocation (using invokeStream for streaming)
|
// Create the streaming agent invocation (using invokeStream for streaming)
|
||||||
const claudeStream = await this.cc.invoke(req.content, session, agentSessionId)
|
const claudeStream = await this.cc.invoke(req.content, session, abortController, agentSessionId)
|
||||||
|
|
||||||
// Use chunk accumulator to manage streaming data
|
// Use chunk accumulator to manage streaming data
|
||||||
const accumulator = new ChunkAccumulator()
|
const accumulator = new ChunkAccumulator()
|
||||||
@ -233,22 +238,15 @@ export class SessionMessageService extends BaseService {
|
|||||||
error: serializeError(underlyingError),
|
error: serializeError(underlyingError),
|
||||||
persistScheduled: false
|
persistScheduled: false
|
||||||
})
|
})
|
||||||
// Always emit a finish chunk at the end
|
// Always emit a complete chunk at the end
|
||||||
sessionStream.emit('data', {
|
sessionStream.emit('data', {
|
||||||
type: 'finish',
|
type: 'complete',
|
||||||
persistScheduled: false
|
persistScheduled: false
|
||||||
})
|
})
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
case 'complete': {
|
case 'complete': {
|
||||||
const completionPayload = event.result ?? accumulator.toModelMessage('assistant')
|
|
||||||
|
|
||||||
sessionStream.emit('data', {
|
|
||||||
type: 'complete',
|
|
||||||
result: completionPayload
|
|
||||||
})
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const persisted = await this.database.transaction(async (tx) => {
|
const persisted = await this.database.transaction(async (tx) => {
|
||||||
const userMessage = await this.persistUserMessage(tx, session.id, req.content, newAgentSessionId)
|
const userMessage = await this.persistUserMessage(tx, session.id, req.content, newAgentSessionId)
|
||||||
@ -273,9 +271,9 @@ export class SessionMessageService extends BaseService {
|
|||||||
error: serializeError(persistError)
|
error: serializeError(persistError)
|
||||||
})
|
})
|
||||||
} finally {
|
} finally {
|
||||||
// Always emit a finish chunk at the end
|
// Always emit a complete chunk at the end
|
||||||
sessionStream.emit('data', {
|
sessionStream.emit('data', {
|
||||||
type: 'finish',
|
type: 'complete',
|
||||||
persistScheduled: true
|
persistScheduled: true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,15 +14,6 @@ import { transformSDKMessageToUIChunk } from './transform'
|
|||||||
const require_ = createRequire(import.meta.url)
|
const require_ = createRequire(import.meta.url)
|
||||||
const logger = loggerService.withContext('ClaudeCodeService')
|
const logger = loggerService.withContext('ClaudeCodeService')
|
||||||
|
|
||||||
interface ClaudeCodeResult {
|
|
||||||
success: boolean
|
|
||||||
stdout: string
|
|
||||||
stderr: string
|
|
||||||
jsonOutput: any[]
|
|
||||||
error?: Error
|
|
||||||
exitCode?: number
|
|
||||||
}
|
|
||||||
|
|
||||||
class ClaudeCodeStream extends EventEmitter implements AgentStream {
|
class ClaudeCodeStream extends EventEmitter implements AgentStream {
|
||||||
declare emit: (event: 'data', data: AgentStreamEvent) => boolean
|
declare emit: (event: 'data', data: AgentStreamEvent) => boolean
|
||||||
declare on: (event: 'data', listener: (data: AgentStreamEvent) => void) => this
|
declare on: (event: 'data', listener: (data: AgentStreamEvent) => void) => this
|
||||||
@ -37,7 +28,12 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
this.claudeExecutablePath = require_.resolve('@anthropic-ai/claude-code/cli.js')
|
this.claudeExecutablePath = require_.resolve('@anthropic-ai/claude-code/cli.js')
|
||||||
}
|
}
|
||||||
|
|
||||||
async invoke(prompt: string, session: GetAgentSessionResponse, lastAgentSessionId?: string): Promise<AgentStream> {
|
async invoke(
|
||||||
|
prompt: string,
|
||||||
|
session: GetAgentSessionResponse,
|
||||||
|
abortController: AbortController,
|
||||||
|
lastAgentSessionId?: string
|
||||||
|
): Promise<AgentStream> {
|
||||||
const aiStream = new ClaudeCodeStream()
|
const aiStream = new ClaudeCodeStream()
|
||||||
|
|
||||||
// Validate session accessible paths and make sure it exists as a directory
|
// Validate session accessible paths and make sure it exists as a directory
|
||||||
@ -76,6 +72,7 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
|
|
||||||
// Build SDK options from parameters
|
// Build SDK options from parameters
|
||||||
const options: Options = {
|
const options: Options = {
|
||||||
|
abortController,
|
||||||
cwd,
|
cwd,
|
||||||
pathToClaudeCodeExecutable: this.claudeExecutablePath,
|
pathToClaudeCodeExecutable: this.claudeExecutablePath,
|
||||||
stderr: (chunk: string) => {
|
stderr: (chunk: string) => {
|
||||||
@ -164,8 +161,7 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
for (const chunk of chunks) {
|
for (const chunk of chunks) {
|
||||||
stream.emit('data', {
|
stream.emit('data', {
|
||||||
type: 'chunk',
|
type: 'chunk',
|
||||||
chunk,
|
chunk
|
||||||
rawAgentMessage: message
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -179,57 +175,44 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
messageCount: jsonOutput.length
|
messageCount: jsonOutput.length
|
||||||
})
|
})
|
||||||
|
|
||||||
const result: ClaudeCodeResult = {
|
|
||||||
success: true,
|
|
||||||
stdout: '',
|
|
||||||
stderr: '',
|
|
||||||
jsonOutput,
|
|
||||||
exitCode: 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit completion event
|
// Emit completion event
|
||||||
stream.emit('data', {
|
stream.emit('data', {
|
||||||
type: 'complete',
|
type: 'complete'
|
||||||
agentResult: {
|
|
||||||
...result,
|
|
||||||
rawSDKMessages: jsonOutput,
|
|
||||||
agentType: 'claude-code'
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (hasCompleted) return
|
if (hasCompleted) return
|
||||||
hasCompleted = true
|
hasCompleted = true
|
||||||
|
|
||||||
const duration = Date.now() - startTime
|
const duration = Date.now() - startTime
|
||||||
|
|
||||||
|
// Check if this is an abort error
|
||||||
|
const errorObj = error as any
|
||||||
|
const isAborted =
|
||||||
|
errorObj?.name === 'AbortError' ||
|
||||||
|
errorObj?.message?.includes('aborted') ||
|
||||||
|
options.abortController?.signal.aborted
|
||||||
|
|
||||||
|
if (isAborted) {
|
||||||
|
logger.info('SDK query aborted by client disconnect', { duration })
|
||||||
|
// Simply cleanup and return - don't emit error events
|
||||||
|
stream.emit('data', {
|
||||||
|
type: 'cancelled',
|
||||||
|
error: new Error('Request aborted by client')
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original error handling for non-abort errors
|
||||||
logger.error('SDK query error:', {
|
logger.error('SDK query error:', {
|
||||||
error: error instanceof Error ? error.message : String(error),
|
error: errorObj instanceof Error ? errorObj.message : String(errorObj),
|
||||||
duration,
|
duration,
|
||||||
messageCount: jsonOutput.length
|
messageCount: jsonOutput.length
|
||||||
})
|
})
|
||||||
|
|
||||||
const result: ClaudeCodeResult = {
|
|
||||||
success: false,
|
|
||||||
stdout: '',
|
|
||||||
stderr: error instanceof Error ? error.message : String(error),
|
|
||||||
jsonOutput,
|
|
||||||
error: error instanceof Error ? error : new Error(String(error)),
|
|
||||||
exitCode: 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit error event
|
// Emit error event
|
||||||
stream.emit('data', {
|
stream.emit('data', {
|
||||||
type: 'error',
|
type: 'error',
|
||||||
error: error instanceof Error ? error : new Error(String(error))
|
error: errorObj instanceof Error ? errorObj : new Error(String(errorObj))
|
||||||
})
|
|
||||||
|
|
||||||
// Emit completion with error result
|
|
||||||
stream.emit('data', {
|
|
||||||
type: 'complete',
|
|
||||||
agentResult: {
|
|
||||||
...result,
|
|
||||||
rawSDKMessages: jsonOutput,
|
|
||||||
agentType: 'claude-code'
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user