mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-21 07:40:11 +08:00
♻️ refactor: replace ClaudeCodeService child process with SDK query
- Replace process spawning with @anthropic-ai/claude-code SDK query function - Remove complex process management, stdout/stderr parsing, and JSON buffering - Directly iterate over typed SDKMessages from AsyncGenerator - Simplify error handling and completion logic - Maintain full compatibility with existing SessionMessageService interface - Eliminate ~130 lines of process management code - Improve reliability by removing JSON parsing edge cases
This commit is contained in:
parent
be7399b3c4
commit
7abd5da57d
@ -5,7 +5,7 @@ import type {
|
||||
AgentSessionMessageEntity,
|
||||
CreateSessionMessageRequest,
|
||||
GetAgentSessionResponse,
|
||||
ListOptions,
|
||||
ListOptions
|
||||
} from '@types'
|
||||
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
||||
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
||||
@ -17,7 +17,6 @@ import ClaudeCodeService from './claudecode'
|
||||
|
||||
const logger = loggerService.withContext('SessionMessageService')
|
||||
|
||||
|
||||
// Collapse a UIMessageChunk stream into a final UIMessage, then convert to ModelMessage[]
|
||||
export async function chunksToModelMessages(
|
||||
chunkStream: ReadableStream<UIMessageChunk>,
|
||||
@ -68,7 +67,6 @@ interface PersistContext {
|
||||
session: GetAgentSessionResponse
|
||||
accumulator: ChunkAccumulator
|
||||
userMessageId: number
|
||||
sessionStream: EventEmitter
|
||||
}
|
||||
|
||||
// Chunk accumulator class to collect and reconstruct streaming data
|
||||
@ -254,10 +252,7 @@ export class SessionMessageService extends BaseService {
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [saved] = await this.database
|
||||
.insert(sessionMessagesTable)
|
||||
.values(insertData)
|
||||
.returning()
|
||||
const [saved] = await this.database.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
||||
}
|
||||
@ -299,8 +294,8 @@ export class SessionMessageService extends BaseService {
|
||||
|
||||
// Create the streaming agent invocation (using invokeStream for streaming)
|
||||
const claudeStream = this.cc.invoke(req.content, session.accessible_paths[0], session_id, {
|
||||
permissionMode: session.configuration?.permissionMode || 'default',
|
||||
maxTurns: session.configuration?.maxTurns || 10
|
||||
permissionMode: session.configuration?.permission_mode,
|
||||
maxTurns: session.configuration?.max_turns
|
||||
})
|
||||
|
||||
// Use chunk accumulator to manage streaming data
|
||||
@ -345,8 +340,7 @@ export class SessionMessageService extends BaseService {
|
||||
void this.persistSessionMessageAsync({
|
||||
session,
|
||||
accumulator,
|
||||
userMessageId,
|
||||
sessionStream
|
||||
userMessageId
|
||||
})
|
||||
}
|
||||
|
||||
@ -355,6 +349,10 @@ export class SessionMessageService extends BaseService {
|
||||
error: serializeError(underlyingError),
|
||||
persistScheduled
|
||||
})
|
||||
// Always emit a finish chunk at the end
|
||||
sessionStream.emit('data', {
|
||||
type: 'finish'
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
@ -367,18 +365,15 @@ export class SessionMessageService extends BaseService {
|
||||
// Set the agent result in the accumulator
|
||||
accumulator.setAgentResult(event.agentResult)
|
||||
|
||||
// // Emit SSE completion FIRST before persistence
|
||||
// sessionStream.emit('data', {
|
||||
// type: 'complete',
|
||||
// result: accumulator.buildStructuredContent()
|
||||
// })
|
||||
|
||||
// Then handle async persistence
|
||||
void this.persistSessionMessageAsync({
|
||||
session,
|
||||
accumulator,
|
||||
userMessageId,
|
||||
sessionStream
|
||||
userMessageId
|
||||
})
|
||||
// Always emit a finish chunk at the end
|
||||
sessionStream.emit('data', {
|
||||
type: 'finish'
|
||||
})
|
||||
break
|
||||
}
|
||||
@ -399,11 +394,10 @@ export class SessionMessageService extends BaseService {
|
||||
})
|
||||
}
|
||||
|
||||
private async persistSessionMessageAsync({ session, accumulator, userMessageId, sessionStream }: PersistContext) {
|
||||
private async persistSessionMessageAsync({ session, accumulator, userMessageId }: PersistContext) {
|
||||
if (!session?.id) {
|
||||
const missingSessionError = new Error('Missing session_id for persisted message')
|
||||
logger.error(missingSessionError.message, { error: missingSessionError })
|
||||
sessionStream.emit('data', { type: 'persist-error', error: serializeError(missingSessionError) })
|
||||
logger.error('error persisting session message', { error: missingSessionError })
|
||||
return
|
||||
}
|
||||
|
||||
@ -435,13 +429,10 @@ export class SessionMessageService extends BaseService {
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [row] = await this.database.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
const entity = this.deserializeSessionMessage(row) as AgentSessionMessageEntity
|
||||
sessionStream.emit('data', { type: 'persisted', message: entity })
|
||||
await this.database.insert(sessionMessagesTable).values(insertData).returning()
|
||||
logger.debug('Success Persisted session message')
|
||||
} catch (error) {
|
||||
logger.error('Failed to persist session message', { error })
|
||||
sessionStream.emit('data', { type: 'persist-error', error: serializeError(error) })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
// src/main/services/agents/services/claudecode/index.ts
|
||||
import { ChildProcess, spawn } from 'node:child_process'
|
||||
import { EventEmitter } from 'node:events'
|
||||
import { createRequire } from 'node:module'
|
||||
|
||||
import { Options, SDKMessage } from '@anthropic-ai/claude-code'
|
||||
import { Options, query, SDKMessage } from '@anthropic-ai/claude-code'
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
||||
@ -38,210 +37,129 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
invoke(prompt: string, cwd: string, session_id?: string, base?: Options): AgentStream {
|
||||
const aiStream = new ClaudeCodeStream()
|
||||
|
||||
// Spawn process with same parameters as invoke
|
||||
const args: string[] = [this.claudeExecutablePath, '--output-format', 'stream-json', '--verbose']
|
||||
// Build SDK options from parameters
|
||||
const options: Options = {
|
||||
cwd,
|
||||
pathToClaudeCodeExecutable: this.claudeExecutablePath,
|
||||
stderr: (chunk: string) => {
|
||||
logger.info('claude stderr', { chunk })
|
||||
},
|
||||
...base
|
||||
}
|
||||
|
||||
if (session_id) {
|
||||
args.push('--resume', session_id)
|
||||
}
|
||||
if (base?.maxTurns) {
|
||||
args.push('--max-turns', base.maxTurns.toString())
|
||||
}
|
||||
if (base?.permissionMode) {
|
||||
args.push('--permission-mode', base.permissionMode)
|
||||
options.resume = session_id
|
||||
}
|
||||
|
||||
args.push('--print', prompt)
|
||||
|
||||
logger.info('Spawning Claude Code streaming process', { args, cwd })
|
||||
|
||||
const p = spawn(process.execPath, args, {
|
||||
env: { ...process.env, ELECTRON_RUN_AS_NODE: '1' },
|
||||
cwd,
|
||||
stdio: ['pipe', 'pipe', 'pipe'],
|
||||
shell: false,
|
||||
detached: false
|
||||
logger.info('Starting Claude Code SDK query', {
|
||||
prompt,
|
||||
options: { cwd, maxTurns: options.maxTurns, permissionMode: options.permissionMode }
|
||||
})
|
||||
|
||||
logger.info('Streaming process created', { pid: p.pid })
|
||||
|
||||
// Close stdin immediately
|
||||
if (p.stdin) {
|
||||
p.stdin.end()
|
||||
logger.debug('Closed stdin for streaming process')
|
||||
}
|
||||
|
||||
this.setupStreamingHandlers(p, aiStream)
|
||||
// Start async processing
|
||||
this.processSDKQuery(prompt, options, aiStream)
|
||||
|
||||
return aiStream
|
||||
}
|
||||
|
||||
/**
|
||||
* Set up process event handlers for streaming output
|
||||
* Process SDK query and emit stream events
|
||||
*/
|
||||
private setupStreamingHandlers(process: ChildProcess, stream: ClaudeCodeStream): void {
|
||||
let stdoutData = ''
|
||||
let stderrData = ''
|
||||
const jsonOutput: any[] = []
|
||||
private async processSDKQuery(prompt: string, options: Options, stream: ClaudeCodeStream): Promise<void> {
|
||||
const jsonOutput: SDKMessage[] = []
|
||||
let hasCompleted = false
|
||||
let stdoutBuffer = ''
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
const emitChunks = (sdkMessage: SDKMessage) => {
|
||||
jsonOutput.push(sdkMessage)
|
||||
const chunks = transformSDKMessageToUIChunk(sdkMessage)
|
||||
for (const chunk of chunks) {
|
||||
stream.emit('data', {
|
||||
type: 'chunk',
|
||||
chunk,
|
||||
rawAgentMessage: sdkMessage // Store Claude Code specific SDKMessage as generic agent message
|
||||
})
|
||||
try {
|
||||
// Process streaming responses using SDK query
|
||||
for await (const message of query({
|
||||
prompt,
|
||||
options
|
||||
})) {
|
||||
if (hasCompleted) break
|
||||
|
||||
jsonOutput.push(message)
|
||||
logger.silly('claude response', { message })
|
||||
if (message.type === 'assistant' || message.type === 'user') {
|
||||
logger.silly('message content', {
|
||||
message: JSON.stringify({ role: message.message.role, content: message.message.content })
|
||||
})
|
||||
}
|
||||
|
||||
// Transform SDKMessage to UIMessageChunks
|
||||
const chunks = transformSDKMessageToUIChunk(message)
|
||||
for (const chunk of chunks) {
|
||||
stream.emit('data', {
|
||||
type: 'chunk',
|
||||
chunk,
|
||||
rawAgentMessage: message
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle stdout with streaming events
|
||||
if (process.stdout) {
|
||||
process.stdout.setEncoding('utf8')
|
||||
process.stdout.on('data', (data: string) => {
|
||||
stdoutData += data
|
||||
stdoutBuffer += data
|
||||
logger.debug('Streaming stdout chunk:', { length: data.length })
|
||||
|
||||
let newlineIndex = stdoutBuffer.indexOf('\n')
|
||||
while (newlineIndex !== -1) {
|
||||
const line = stdoutBuffer.slice(0, newlineIndex)
|
||||
stdoutBuffer = stdoutBuffer.slice(newlineIndex + 1)
|
||||
const trimmed = line.trim()
|
||||
if (trimmed) {
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed) as SDKMessage
|
||||
emitChunks(parsed)
|
||||
logger.debug('Parsed JSON line', { parsed })
|
||||
} catch (error) {
|
||||
logger.debug('Non-JSON line', { line: trimmed })
|
||||
}
|
||||
}
|
||||
newlineIndex = stdoutBuffer.indexOf('\n')
|
||||
}
|
||||
})
|
||||
|
||||
process.stdout.on('end', () => {
|
||||
const trimmed = stdoutBuffer.trim()
|
||||
if (trimmed) {
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed) as SDKMessage
|
||||
emitChunks(parsed)
|
||||
logger.debug('Parsed JSON line on stream end', { parsed })
|
||||
} catch (error) {
|
||||
logger.debug('Non-JSON remainder on stdout end', { line: trimmed })
|
||||
}
|
||||
}
|
||||
logger.debug('Streaming stdout ended')
|
||||
})
|
||||
}
|
||||
|
||||
// Handle stderr
|
||||
if (process.stderr) {
|
||||
process.stderr.setEncoding('utf8')
|
||||
process.stderr.on('data', (data: string) => {
|
||||
stderrData += data
|
||||
const message = data.trim()
|
||||
if (!message) return
|
||||
logger.warn('Streaming stderr chunk:', { data: message })
|
||||
stream.emit('data', {
|
||||
type: 'error',
|
||||
error: new Error(message)
|
||||
})
|
||||
})
|
||||
|
||||
process.stderr.on('end', () => {
|
||||
logger.debug('Streaming stderr ended')
|
||||
})
|
||||
}
|
||||
|
||||
// Handle process completion
|
||||
const completeProcess = (code: number | null, signal: NodeJS.Signals | null, error?: Error) => {
|
||||
if (hasCompleted) return
|
||||
// Successfully completed
|
||||
hasCompleted = true
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
const success = !error && code === 0
|
||||
|
||||
logger.info('Streaming process completed', {
|
||||
code,
|
||||
signal,
|
||||
success,
|
||||
logger.debug('SDK query completed successfully', {
|
||||
duration,
|
||||
stdoutLength: stdoutData.length,
|
||||
stderrLength: stderrData.length,
|
||||
jsonItems: jsonOutput.length,
|
||||
error: error?.message
|
||||
messageCount: jsonOutput.length
|
||||
})
|
||||
|
||||
const result: ClaudeCodeResult = {
|
||||
success,
|
||||
stdout: stdoutData,
|
||||
stderr: stderrData,
|
||||
success: true,
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
jsonOutput,
|
||||
exitCode: code || undefined,
|
||||
error
|
||||
exitCode: 0
|
||||
}
|
||||
|
||||
// Emit completion event with agent-specific result
|
||||
// Emit completion event
|
||||
stream.emit('data', {
|
||||
type: 'complete',
|
||||
agentResult: {
|
||||
...result,
|
||||
rawSDKMessages: jsonOutput, // Claude Code specific: all collected SDK messages
|
||||
agentType: 'claude-code' // Identify the agent type
|
||||
rawSDKMessages: jsonOutput,
|
||||
agentType: 'claude-code'
|
||||
}
|
||||
})
|
||||
} catch (error) {
|
||||
if (hasCompleted) return
|
||||
hasCompleted = true
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.error('SDK query error:', {
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
duration,
|
||||
messageCount: jsonOutput.length
|
||||
})
|
||||
|
||||
const result: ClaudeCodeResult = {
|
||||
success: false,
|
||||
stdout: '',
|
||||
stderr: error instanceof Error ? error.message : String(error),
|
||||
jsonOutput,
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
exitCode: 1
|
||||
}
|
||||
|
||||
// Emit error event
|
||||
stream.emit('data', {
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error : new Error(String(error))
|
||||
})
|
||||
|
||||
// Emit completion with error result
|
||||
stream.emit('data', {
|
||||
type: 'complete',
|
||||
agentResult: {
|
||||
...result,
|
||||
rawSDKMessages: jsonOutput,
|
||||
agentType: 'claude-code'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Handle process exit
|
||||
process.on('exit', (code, signal) => {
|
||||
completeProcess(code, signal)
|
||||
})
|
||||
|
||||
// Handle process errors
|
||||
process.on('error', (error) => {
|
||||
const duration = Date.now() - startTime
|
||||
logger.error('Streaming process error:', {
|
||||
error: error.message,
|
||||
duration,
|
||||
stdoutLength: stdoutData.length,
|
||||
stderrLength: stderrData.length
|
||||
})
|
||||
|
||||
completeProcess(null, null, error)
|
||||
})
|
||||
|
||||
// Handle close event as a fallback
|
||||
process.on('close', (code, signal) => {
|
||||
logger.debug('Streaming process closed', { code, signal })
|
||||
completeProcess(code, signal)
|
||||
})
|
||||
|
||||
// Set timeout to prevent hanging
|
||||
const timeout = setTimeout(() => {
|
||||
if (!hasCompleted) {
|
||||
logger.error('Streaming process timeout after 600 seconds', {
|
||||
pid: process.pid,
|
||||
stdoutLength: stdoutData.length,
|
||||
stderrLength: stderrData.length,
|
||||
jsonItems: jsonOutput.length
|
||||
})
|
||||
process.kill('SIGTERM')
|
||||
completeProcess(null, null, new Error('Process timeout after 600 seconds'))
|
||||
}
|
||||
}, 600 * 1000)
|
||||
|
||||
// Clear timeout when process ends
|
||||
process.on('exit', () => clearTimeout(timeout))
|
||||
process.on('error', () => clearTimeout(timeout))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
export default ClaudeCodeService
|
||||
|
||||
@ -327,11 +327,6 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Always emit a finish chunk at the end
|
||||
chunks.push({
|
||||
type: 'finish'
|
||||
})
|
||||
return chunks
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user