mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-26 11:44:28 +08:00
Merge branch 'feat/agents-new' of github.com:CherryHQ/cherry-studio into feat/agents-new
This commit is contained in:
commit
fcacc50fdc
@ -1,4 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { AgentStreamEvent } from '@main/services/agents/interfaces/AgentStreamInterface'
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
import { agentService, sessionMessageService, sessionService } from '../../../../services/agents'
|
||||
@ -42,13 +43,12 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
res.setHeader('Access-Control-Allow-Origin', '*')
|
||||
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
|
||||
|
||||
const messageStream = sessionMessageService.createSessionMessage(session, messageData)
|
||||
const abortController = new AbortController()
|
||||
const messageStream = sessionMessageService.createSessionMessage(session, messageData, abortController)
|
||||
|
||||
// Track stream lifecycle so we keep the SSE connection open until persistence finishes
|
||||
let responseEnded = false
|
||||
let streamFinished = false
|
||||
let awaitingPersistence = false
|
||||
let persistenceResolved = false
|
||||
|
||||
const finalizeResponse = () => {
|
||||
if (responseEnded) {
|
||||
@ -59,10 +59,6 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
return
|
||||
}
|
||||
|
||||
if (awaitingPersistence && !persistenceResolved) {
|
||||
return
|
||||
}
|
||||
|
||||
responseEnded = true
|
||||
try {
|
||||
res.write('data: {"type":"finish"}\n\n')
|
||||
@ -73,15 +69,39 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
res.end()
|
||||
}
|
||||
|
||||
// Handle client disconnect
|
||||
req.on('close', () => {
|
||||
/**
|
||||
* Client Disconnect Detection for Server-Sent Events (SSE)
|
||||
*
|
||||
* We monitor multiple HTTP events to reliably detect when a client disconnects
|
||||
* from the streaming response. This is crucial for:
|
||||
* - Aborting long-running Claude Code processes
|
||||
* - Cleaning up resources and preventing memory leaks
|
||||
* - Avoiding orphaned processes
|
||||
*
|
||||
* Event Priority & Behavior:
|
||||
* 1. res.on('close') - Most common for SSE client disconnects (browser tab close, curl Ctrl+C)
|
||||
* 2. req.on('aborted') - Explicit request abortion
|
||||
* 3. req.on('close') - Request object closure (less common with SSE)
|
||||
*
|
||||
* When any disconnect event fires, we:
|
||||
* - Abort the Claude Code SDK process via abortController
|
||||
* - Clean up event listeners to prevent memory leaks
|
||||
* - Mark the response as ended to prevent further writes
|
||||
*/
|
||||
const handleDisconnect = () => {
|
||||
if (responseEnded) return
|
||||
logger.info(`Client disconnected from streaming message for session: ${sessionId}`)
|
||||
responseEnded = true
|
||||
messageStream.removeAllListeners()
|
||||
})
|
||||
abortController.abort('Client disconnected')
|
||||
}
|
||||
|
||||
req.on('close', handleDisconnect)
|
||||
req.on('aborted', handleDisconnect)
|
||||
res.on('close', handleDisconnect)
|
||||
|
||||
// Handle stream events
|
||||
messageStream.on('data', (event: any) => {
|
||||
messageStream.on('data', (event: AgentStreamEvent) => {
|
||||
if (responseEnded) return
|
||||
|
||||
try {
|
||||
@ -101,12 +121,6 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
logger.error(`Streaming message error for session: ${sessionId}:`, event.error)
|
||||
|
||||
streamFinished = true
|
||||
awaitingPersistence = Boolean(event.persistScheduled)
|
||||
|
||||
if (!awaitingPersistence) {
|
||||
persistenceResolved = true
|
||||
}
|
||||
|
||||
finalizeResponse()
|
||||
break
|
||||
}
|
||||
@ -116,32 +130,22 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
// res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`)
|
||||
|
||||
streamFinished = true
|
||||
awaitingPersistence = true
|
||||
finalizeResponse()
|
||||
break
|
||||
}
|
||||
|
||||
case 'persisted':
|
||||
// Send persistence success event
|
||||
// res.write(`data: ${JSON.stringify(event)}\n\n`)
|
||||
logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id })
|
||||
|
||||
persistenceResolved = true
|
||||
finalizeResponse()
|
||||
break
|
||||
|
||||
case 'persist-error':
|
||||
// Send persistence error event
|
||||
// res.write(`data: ${JSON.stringify(event)}\n\n`)
|
||||
logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error)
|
||||
|
||||
persistenceResolved = true
|
||||
case 'cancelled': {
|
||||
logger.info(`Streaming message cancelled for session: ${sessionId}`)
|
||||
// res.write(`data: ${JSON.stringify({ type: 'cancelled' })}\n\n`)
|
||||
streamFinished = true
|
||||
finalizeResponse()
|
||||
break
|
||||
}
|
||||
|
||||
default:
|
||||
// Handle other event types as generic data
|
||||
res.write(`data: ${JSON.stringify(event)}\n\n`)
|
||||
logger.info(`Streaming message event for session: ${sessionId}:`, { event })
|
||||
// res.write(`data: ${JSON.stringify(event)}\n\n`)
|
||||
break
|
||||
}
|
||||
} catch (writeError) {
|
||||
@ -199,8 +203,8 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
res.end()
|
||||
}
|
||||
},
|
||||
5 * 60 * 1000
|
||||
) // 5 minutes timeout
|
||||
10 * 60 * 1000
|
||||
) // 10 minutes timeout
|
||||
|
||||
// Clear timeout when response ends
|
||||
res.on('close', () => clearTimeout(timeout))
|
||||
|
||||
@ -13,8 +13,7 @@ import { Request, Response } from 'express'
|
||||
import { IncomingMessage, ServerResponse } from 'http'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { reduxService } from '../../services/ReduxService'
|
||||
import { getMcpServerById } from '../utils/mcp'
|
||||
import { getMcpServerById, getMCPServersFromRedux } from '../utils/mcp'
|
||||
|
||||
const logger = loggerService.withContext('MCPApiService')
|
||||
const transports: Record<string, StreamableHTTPServerTransport> = {}
|
||||
@ -57,34 +56,10 @@ class MCPApiService extends EventEmitter {
|
||||
this.transport.onmessage = this.onMessage
|
||||
}
|
||||
|
||||
/**
|
||||
* Get servers directly from Redux store
|
||||
*/
|
||||
private async getServersFromRedux(): Promise<MCPServer[]> {
|
||||
try {
|
||||
logger.silly('Getting servers from Redux store')
|
||||
|
||||
// Try to get from cache first (faster)
|
||||
const cachedServers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
|
||||
if (cachedServers && Array.isArray(cachedServers)) {
|
||||
logger.silly(`Found ${cachedServers.length} servers in Redux cache`)
|
||||
return cachedServers
|
||||
}
|
||||
|
||||
// If cache is not available, get fresh data
|
||||
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
|
||||
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
|
||||
return servers || []
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get servers from Redux:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
// get all activated servers
|
||||
async getAllServers(req: Request): Promise<McpServersResp> {
|
||||
try {
|
||||
const servers = await this.getServersFromRedux()
|
||||
const servers = await getMCPServersFromRedux()
|
||||
logger.silly(`Returning ${servers.length} servers`)
|
||||
const resp: McpServersResp = {
|
||||
servers: {}
|
||||
@ -111,7 +86,7 @@ class MCPApiService extends EventEmitter {
|
||||
async getServerById(id: string): Promise<MCPServer | null> {
|
||||
try {
|
||||
logger.silly(`getServerById called with id: ${id}`)
|
||||
const servers = await this.getServersFromRedux()
|
||||
const servers = await getMCPServersFromRedux()
|
||||
const server = servers.find((s) => s.id === id)
|
||||
if (!server) {
|
||||
logger.warn(`Server with id ${id} not found`)
|
||||
|
||||
@ -1,12 +1,24 @@
|
||||
import { CacheService } from '@main/services/CacheService'
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import { reduxService } from '@main/services/ReduxService'
|
||||
import { ApiModel, Model, Provider } from '@types'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerUtils')
|
||||
|
||||
// Cache configuration
|
||||
const PROVIDERS_CACHE_KEY = 'api-server:providers'
|
||||
const PROVIDERS_CACHE_TTL = 5 * 60 * 1000 // 5 minutes
|
||||
|
||||
export async function getAvailableProviders(): Promise<Provider[]> {
|
||||
try {
|
||||
// Wait for store to be ready before accessing providers
|
||||
// Try to get from cache first (faster)
|
||||
const cachedSupportedProviders = CacheService.get<Provider[]>(PROVIDERS_CACHE_KEY)
|
||||
if (cachedSupportedProviders) {
|
||||
logger.debug(`Found ${cachedSupportedProviders.length} supported providers (from cache)`)
|
||||
return cachedSupportedProviders
|
||||
}
|
||||
|
||||
// If cache is not available, get fresh data from Redux
|
||||
const providers = await reduxService.select('state.llm.providers')
|
||||
if (!providers || !Array.isArray(providers)) {
|
||||
logger.warn('No providers found in Redux store, returning empty array')
|
||||
@ -18,6 +30,9 @@ export async function getAvailableProviders(): Promise<Provider[]> {
|
||||
(p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic')
|
||||
)
|
||||
|
||||
// Cache the filtered results
|
||||
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
|
||||
|
||||
logger.info(`Filtered to ${supportedProviders.length} supported providers from ${providers.length} total providers`)
|
||||
|
||||
return supportedProviders
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { CacheService } from '@main/services/CacheService'
|
||||
import mcpService from '@main/services/MCPService'
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { CallToolRequestSchema, ListToolsRequestSchema, ListToolsResult } from '@modelcontextprotocol/sdk/types.js'
|
||||
@ -8,6 +9,10 @@ import { reduxService } from '../../services/ReduxService'
|
||||
|
||||
const logger = loggerService.withContext('MCPApiService')
|
||||
|
||||
// Cache configuration
|
||||
const MCP_SERVERS_CACHE_KEY = 'api-server:mcp-servers'
|
||||
const MCP_SERVERS_CACHE_TTL = 5 * 60 * 1000 // 5 minutes
|
||||
|
||||
const cachedServers: Record<string, Server> = {}
|
||||
|
||||
async function handleListToolsRequest(request: any, extra: any): Promise<ListToolsResult> {
|
||||
@ -33,18 +38,33 @@ async function handleCallToolRequest(request: any, extra: any): Promise<any> {
|
||||
}
|
||||
|
||||
async function getMcpServerConfigById(id: string): Promise<MCPServer | undefined> {
|
||||
const servers = await getServersFromRedux()
|
||||
const servers = await getMCPServersFromRedux()
|
||||
return servers.find((s) => s.id === id || s.name === id)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get servers directly from Redux store
|
||||
*/
|
||||
async function getServersFromRedux(): Promise<MCPServer[]> {
|
||||
export async function getMCPServersFromRedux(): Promise<MCPServer[]> {
|
||||
try {
|
||||
logger.silly('Getting servers from Redux store')
|
||||
|
||||
// Try to get from cache first (faster)
|
||||
const cachedServers = CacheService.get<MCPServer[]>(MCP_SERVERS_CACHE_KEY)
|
||||
if (cachedServers) {
|
||||
logger.silly(`Found ${cachedServers.length} servers (from cache)`)
|
||||
return cachedServers
|
||||
}
|
||||
|
||||
// If cache is not available, get fresh data from Redux
|
||||
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
|
||||
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
|
||||
return servers || []
|
||||
const serverList = servers || []
|
||||
|
||||
// Cache the results
|
||||
CacheService.set(MCP_SERVERS_CACHE_KEY, serverList, MCP_SERVERS_CACHE_TTL)
|
||||
|
||||
logger.silly(`Fetched ${serverList.length} servers from Redux store`)
|
||||
return serverList
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get servers from Redux:', error)
|
||||
return []
|
||||
@ -54,7 +74,7 @@ async function getServersFromRedux(): Promise<MCPServer[]> {
|
||||
export async function getMcpServerById(id: string): Promise<Server> {
|
||||
const server = cachedServers[id]
|
||||
if (!server) {
|
||||
const servers = await getServersFromRedux()
|
||||
const servers = await getMCPServersFromRedux()
|
||||
const mcpServer = servers.find((s) => s.id === id || s.name === id)
|
||||
if (!mcpServer) {
|
||||
throw new Error(`Server not found: ${id}`)
|
||||
|
||||
@ -8,11 +8,9 @@ import { UIMessageChunk } from 'ai'
|
||||
|
||||
// Generic agent stream event that works with any agent type
|
||||
export interface AgentStreamEvent {
|
||||
type: 'chunk' | 'error' | 'complete'
|
||||
type: 'chunk' | 'error' | 'complete' | 'cancelled'
|
||||
chunk?: UIMessageChunk // Standard AI SDK chunk for UI consumption
|
||||
rawAgentMessage?: any // Agent-specific raw message (SDKMessage for Claude Code, different for other agents)
|
||||
error?: Error
|
||||
agentResult?: any // Agent-specific result data
|
||||
}
|
||||
|
||||
// Agent stream interface that all agents should implement
|
||||
@ -24,5 +22,10 @@ export interface AgentStream extends EventEmitter {
|
||||
|
||||
// Base agent service interface
|
||||
export interface AgentServiceInterface {
|
||||
invoke(prompt: string, session: GetAgentSessionResponse, lastAgentSessionId?: string): Promise<AgentStream>
|
||||
invoke(
|
||||
prompt: string,
|
||||
session: GetAgentSessionResponse,
|
||||
abortController: AbortController,
|
||||
lastAgentSessionId?: string
|
||||
): Promise<AgentStream>
|
||||
}
|
||||
|
||||
@ -170,14 +170,18 @@ export class SessionMessageService extends BaseService {
|
||||
return { messages }
|
||||
}
|
||||
|
||||
createSessionMessage(session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest): EventEmitter {
|
||||
createSessionMessage(
|
||||
session: GetAgentSessionResponse,
|
||||
messageData: CreateSessionMessageRequest,
|
||||
abortController: AbortController
|
||||
): EventEmitter {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Create a new EventEmitter to manage the session message lifecycle
|
||||
const sessionStream = new EventEmitter()
|
||||
|
||||
// No parent validation needed, start immediately
|
||||
this.startSessionMessageStream(session, messageData, sessionStream)
|
||||
this.startSessionMessageStream(session, messageData, sessionStream, abortController)
|
||||
|
||||
return sessionStream
|
||||
}
|
||||
@ -185,7 +189,8 @@ export class SessionMessageService extends BaseService {
|
||||
private async startSessionMessageStream(
|
||||
session: GetAgentSessionResponse,
|
||||
req: CreateSessionMessageRequest,
|
||||
sessionStream: EventEmitter
|
||||
sessionStream: EventEmitter,
|
||||
abortController: AbortController
|
||||
): Promise<void> {
|
||||
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||
let newAgentSessionId = ''
|
||||
@ -198,7 +203,7 @@ export class SessionMessageService extends BaseService {
|
||||
}
|
||||
|
||||
// Create the streaming agent invocation (using invokeStream for streaming)
|
||||
const claudeStream = await this.cc.invoke(req.content, session, agentSessionId)
|
||||
const claudeStream = await this.cc.invoke(req.content, session, abortController, agentSessionId)
|
||||
|
||||
// Use chunk accumulator to manage streaming data
|
||||
const accumulator = new ChunkAccumulator()
|
||||
@ -233,22 +238,15 @@ export class SessionMessageService extends BaseService {
|
||||
error: serializeError(underlyingError),
|
||||
persistScheduled: false
|
||||
})
|
||||
// Always emit a finish chunk at the end
|
||||
// Always emit a complete chunk at the end
|
||||
sessionStream.emit('data', {
|
||||
type: 'finish',
|
||||
type: 'complete',
|
||||
persistScheduled: false
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'complete': {
|
||||
const completionPayload = event.result ?? accumulator.toModelMessage('assistant')
|
||||
|
||||
sessionStream.emit('data', {
|
||||
type: 'complete',
|
||||
result: completionPayload
|
||||
})
|
||||
|
||||
try {
|
||||
const persisted = await this.database.transaction(async (tx) => {
|
||||
const userMessage = await this.persistUserMessage(tx, session.id, req.content, newAgentSessionId)
|
||||
@ -273,9 +271,9 @@ export class SessionMessageService extends BaseService {
|
||||
error: serializeError(persistError)
|
||||
})
|
||||
} finally {
|
||||
// Always emit a finish chunk at the end
|
||||
// Always emit a complete chunk at the end
|
||||
sessionStream.emit('data', {
|
||||
type: 'finish',
|
||||
type: 'complete',
|
||||
persistScheduled: true
|
||||
})
|
||||
}
|
||||
|
||||
@ -14,15 +14,6 @@ import { transformSDKMessageToUIChunk } from './transform'
|
||||
const require_ = createRequire(import.meta.url)
|
||||
const logger = loggerService.withContext('ClaudeCodeService')
|
||||
|
||||
interface ClaudeCodeResult {
|
||||
success: boolean
|
||||
stdout: string
|
||||
stderr: string
|
||||
jsonOutput: any[]
|
||||
error?: Error
|
||||
exitCode?: number
|
||||
}
|
||||
|
||||
class ClaudeCodeStream extends EventEmitter implements AgentStream {
|
||||
declare emit: (event: 'data', data: AgentStreamEvent) => boolean
|
||||
declare on: (event: 'data', listener: (data: AgentStreamEvent) => void) => this
|
||||
@ -37,7 +28,12 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
this.claudeExecutablePath = require_.resolve('@anthropic-ai/claude-code/cli.js')
|
||||
}
|
||||
|
||||
async invoke(prompt: string, session: GetAgentSessionResponse, lastAgentSessionId?: string): Promise<AgentStream> {
|
||||
async invoke(
|
||||
prompt: string,
|
||||
session: GetAgentSessionResponse,
|
||||
abortController: AbortController,
|
||||
lastAgentSessionId?: string
|
||||
): Promise<AgentStream> {
|
||||
const aiStream = new ClaudeCodeStream()
|
||||
|
||||
// Validate session accessible paths and make sure it exists as a directory
|
||||
@ -76,6 +72,7 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
|
||||
// Build SDK options from parameters
|
||||
const options: Options = {
|
||||
abortController,
|
||||
cwd,
|
||||
pathToClaudeCodeExecutable: this.claudeExecutablePath,
|
||||
stderr: (chunk: string) => {
|
||||
@ -164,8 +161,7 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
for (const chunk of chunks) {
|
||||
stream.emit('data', {
|
||||
type: 'chunk',
|
||||
chunk,
|
||||
rawAgentMessage: message
|
||||
chunk
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -179,57 +175,44 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
messageCount: jsonOutput.length
|
||||
})
|
||||
|
||||
const result: ClaudeCodeResult = {
|
||||
success: true,
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
jsonOutput,
|
||||
exitCode: 0
|
||||
}
|
||||
|
||||
// Emit completion event
|
||||
stream.emit('data', {
|
||||
type: 'complete',
|
||||
agentResult: {
|
||||
...result,
|
||||
rawSDKMessages: jsonOutput,
|
||||
agentType: 'claude-code'
|
||||
}
|
||||
type: 'complete'
|
||||
})
|
||||
} catch (error) {
|
||||
if (hasCompleted) return
|
||||
hasCompleted = true
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
|
||||
// Check if this is an abort error
|
||||
const errorObj = error as any
|
||||
const isAborted =
|
||||
errorObj?.name === 'AbortError' ||
|
||||
errorObj?.message?.includes('aborted') ||
|
||||
options.abortController?.signal.aborted
|
||||
|
||||
if (isAborted) {
|
||||
logger.info('SDK query aborted by client disconnect', { duration })
|
||||
// Simply cleanup and return - don't emit error events
|
||||
stream.emit('data', {
|
||||
type: 'cancelled',
|
||||
error: new Error('Request aborted by client')
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Original error handling for non-abort errors
|
||||
logger.error('SDK query error:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
error: errorObj instanceof Error ? errorObj.message : String(errorObj),
|
||||
duration,
|
||||
messageCount: jsonOutput.length
|
||||
})
|
||||
|
||||
const result: ClaudeCodeResult = {
|
||||
success: false,
|
||||
stdout: '',
|
||||
stderr: error instanceof Error ? error.message : String(error),
|
||||
jsonOutput,
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
exitCode: 1
|
||||
}
|
||||
|
||||
// Emit error event
|
||||
stream.emit('data', {
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error : new Error(String(error))
|
||||
})
|
||||
|
||||
// Emit completion with error result
|
||||
stream.emit('data', {
|
||||
type: 'complete',
|
||||
agentResult: {
|
||||
...result,
|
||||
rawSDKMessages: jsonOutput,
|
||||
agentType: 'claude-code'
|
||||
}
|
||||
error: errorObj instanceof Error ? errorObj : new Error(String(errorObj))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user