mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-21 07:40:11 +08:00
♻️ refactor: replace ClaudeCodeService child process with SDK query
- Replace process spawning with @anthropic-ai/claude-code SDK query function - Remove complex process management, stdout/stderr parsing, and JSON buffering - Directly iterate over typed SDKMessages from AsyncGenerator - Simplify error handling and completion logic - Maintain full compatibility with existing SessionMessageService interface - Eliminate ~130 lines of process management code - Improve reliability by removing JSON parsing edge cases
This commit is contained in:
parent
be7399b3c4
commit
7abd5da57d
@ -5,7 +5,7 @@ import type {
|
|||||||
AgentSessionMessageEntity,
|
AgentSessionMessageEntity,
|
||||||
CreateSessionMessageRequest,
|
CreateSessionMessageRequest,
|
||||||
GetAgentSessionResponse,
|
GetAgentSessionResponse,
|
||||||
ListOptions,
|
ListOptions
|
||||||
} from '@types'
|
} from '@types'
|
||||||
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
||||||
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
||||||
@ -17,7 +17,6 @@ import ClaudeCodeService from './claudecode'
|
|||||||
|
|
||||||
const logger = loggerService.withContext('SessionMessageService')
|
const logger = loggerService.withContext('SessionMessageService')
|
||||||
|
|
||||||
|
|
||||||
// Collapse a UIMessageChunk stream into a final UIMessage, then convert to ModelMessage[]
|
// Collapse a UIMessageChunk stream into a final UIMessage, then convert to ModelMessage[]
|
||||||
export async function chunksToModelMessages(
|
export async function chunksToModelMessages(
|
||||||
chunkStream: ReadableStream<UIMessageChunk>,
|
chunkStream: ReadableStream<UIMessageChunk>,
|
||||||
@ -68,7 +67,6 @@ interface PersistContext {
|
|||||||
session: GetAgentSessionResponse
|
session: GetAgentSessionResponse
|
||||||
accumulator: ChunkAccumulator
|
accumulator: ChunkAccumulator
|
||||||
userMessageId: number
|
userMessageId: number
|
||||||
sessionStream: EventEmitter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Chunk accumulator class to collect and reconstruct streaming data
|
// Chunk accumulator class to collect and reconstruct streaming data
|
||||||
@ -254,10 +252,7 @@ export class SessionMessageService extends BaseService {
|
|||||||
updated_at: now
|
updated_at: now
|
||||||
}
|
}
|
||||||
|
|
||||||
const [saved] = await this.database
|
const [saved] = await this.database.insert(sessionMessagesTable).values(insertData).returning()
|
||||||
.insert(sessionMessagesTable)
|
|
||||||
.values(insertData)
|
|
||||||
.returning()
|
|
||||||
|
|
||||||
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
||||||
}
|
}
|
||||||
@ -299,8 +294,8 @@ export class SessionMessageService extends BaseService {
|
|||||||
|
|
||||||
// Create the streaming agent invocation (using invokeStream for streaming)
|
// Create the streaming agent invocation (using invokeStream for streaming)
|
||||||
const claudeStream = this.cc.invoke(req.content, session.accessible_paths[0], session_id, {
|
const claudeStream = this.cc.invoke(req.content, session.accessible_paths[0], session_id, {
|
||||||
permissionMode: session.configuration?.permissionMode || 'default',
|
permissionMode: session.configuration?.permission_mode,
|
||||||
maxTurns: session.configuration?.maxTurns || 10
|
maxTurns: session.configuration?.max_turns
|
||||||
})
|
})
|
||||||
|
|
||||||
// Use chunk accumulator to manage streaming data
|
// Use chunk accumulator to manage streaming data
|
||||||
@ -345,8 +340,7 @@ export class SessionMessageService extends BaseService {
|
|||||||
void this.persistSessionMessageAsync({
|
void this.persistSessionMessageAsync({
|
||||||
session,
|
session,
|
||||||
accumulator,
|
accumulator,
|
||||||
userMessageId,
|
userMessageId
|
||||||
sessionStream
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,6 +349,10 @@ export class SessionMessageService extends BaseService {
|
|||||||
error: serializeError(underlyingError),
|
error: serializeError(underlyingError),
|
||||||
persistScheduled
|
persistScheduled
|
||||||
})
|
})
|
||||||
|
// Always emit a finish chunk at the end
|
||||||
|
sessionStream.emit('data', {
|
||||||
|
type: 'finish'
|
||||||
|
})
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -367,18 +365,15 @@ export class SessionMessageService extends BaseService {
|
|||||||
// Set the agent result in the accumulator
|
// Set the agent result in the accumulator
|
||||||
accumulator.setAgentResult(event.agentResult)
|
accumulator.setAgentResult(event.agentResult)
|
||||||
|
|
||||||
// // Emit SSE completion FIRST before persistence
|
|
||||||
// sessionStream.emit('data', {
|
|
||||||
// type: 'complete',
|
|
||||||
// result: accumulator.buildStructuredContent()
|
|
||||||
// })
|
|
||||||
|
|
||||||
// Then handle async persistence
|
// Then handle async persistence
|
||||||
void this.persistSessionMessageAsync({
|
void this.persistSessionMessageAsync({
|
||||||
session,
|
session,
|
||||||
accumulator,
|
accumulator,
|
||||||
userMessageId,
|
userMessageId
|
||||||
sessionStream
|
})
|
||||||
|
// Always emit a finish chunk at the end
|
||||||
|
sessionStream.emit('data', {
|
||||||
|
type: 'finish'
|
||||||
})
|
})
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -399,11 +394,10 @@ export class SessionMessageService extends BaseService {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
private async persistSessionMessageAsync({ session, accumulator, userMessageId, sessionStream }: PersistContext) {
|
private async persistSessionMessageAsync({ session, accumulator, userMessageId }: PersistContext) {
|
||||||
if (!session?.id) {
|
if (!session?.id) {
|
||||||
const missingSessionError = new Error('Missing session_id for persisted message')
|
const missingSessionError = new Error('Missing session_id for persisted message')
|
||||||
logger.error(missingSessionError.message, { error: missingSessionError })
|
logger.error('error persisting session message', { error: missingSessionError })
|
||||||
sessionStream.emit('data', { type: 'persist-error', error: serializeError(missingSessionError) })
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -435,13 +429,10 @@ export class SessionMessageService extends BaseService {
|
|||||||
updated_at: now
|
updated_at: now
|
||||||
}
|
}
|
||||||
|
|
||||||
const [row] = await this.database.insert(sessionMessagesTable).values(insertData).returning()
|
await this.database.insert(sessionMessagesTable).values(insertData).returning()
|
||||||
|
logger.debug('Success Persisted session message')
|
||||||
const entity = this.deserializeSessionMessage(row) as AgentSessionMessageEntity
|
|
||||||
sessionStream.emit('data', { type: 'persisted', message: entity })
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Failed to persist session message', { error })
|
logger.error('Failed to persist session message', { error })
|
||||||
sessionStream.emit('data', { type: 'persist-error', error: serializeError(error) })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
// src/main/services/agents/services/claudecode/index.ts
|
// src/main/services/agents/services/claudecode/index.ts
|
||||||
import { ChildProcess, spawn } from 'node:child_process'
|
|
||||||
import { EventEmitter } from 'node:events'
|
import { EventEmitter } from 'node:events'
|
||||||
import { createRequire } from 'node:module'
|
import { createRequire } from 'node:module'
|
||||||
|
|
||||||
import { Options, SDKMessage } from '@anthropic-ai/claude-code'
|
import { Options, query, SDKMessage } from '@anthropic-ai/claude-code'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
|
|
||||||
import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
||||||
@ -38,210 +37,129 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
invoke(prompt: string, cwd: string, session_id?: string, base?: Options): AgentStream {
|
invoke(prompt: string, cwd: string, session_id?: string, base?: Options): AgentStream {
|
||||||
const aiStream = new ClaudeCodeStream()
|
const aiStream = new ClaudeCodeStream()
|
||||||
|
|
||||||
// Spawn process with same parameters as invoke
|
// Build SDK options from parameters
|
||||||
const args: string[] = [this.claudeExecutablePath, '--output-format', 'stream-json', '--verbose']
|
const options: Options = {
|
||||||
|
cwd,
|
||||||
|
pathToClaudeCodeExecutable: this.claudeExecutablePath,
|
||||||
|
stderr: (chunk: string) => {
|
||||||
|
logger.info('claude stderr', { chunk })
|
||||||
|
},
|
||||||
|
...base
|
||||||
|
}
|
||||||
|
|
||||||
if (session_id) {
|
if (session_id) {
|
||||||
args.push('--resume', session_id)
|
options.resume = session_id
|
||||||
}
|
|
||||||
if (base?.maxTurns) {
|
|
||||||
args.push('--max-turns', base.maxTurns.toString())
|
|
||||||
}
|
|
||||||
if (base?.permissionMode) {
|
|
||||||
args.push('--permission-mode', base.permissionMode)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
args.push('--print', prompt)
|
logger.info('Starting Claude Code SDK query', {
|
||||||
|
prompt,
|
||||||
logger.info('Spawning Claude Code streaming process', { args, cwd })
|
options: { cwd, maxTurns: options.maxTurns, permissionMode: options.permissionMode }
|
||||||
|
|
||||||
const p = spawn(process.execPath, args, {
|
|
||||||
env: { ...process.env, ELECTRON_RUN_AS_NODE: '1' },
|
|
||||||
cwd,
|
|
||||||
stdio: ['pipe', 'pipe', 'pipe'],
|
|
||||||
shell: false,
|
|
||||||
detached: false
|
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.info('Streaming process created', { pid: p.pid })
|
// Start async processing
|
||||||
|
this.processSDKQuery(prompt, options, aiStream)
|
||||||
// Close stdin immediately
|
|
||||||
if (p.stdin) {
|
|
||||||
p.stdin.end()
|
|
||||||
logger.debug('Closed stdin for streaming process')
|
|
||||||
}
|
|
||||||
|
|
||||||
this.setupStreamingHandlers(p, aiStream)
|
|
||||||
|
|
||||||
return aiStream
|
return aiStream
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set up process event handlers for streaming output
|
* Process SDK query and emit stream events
|
||||||
*/
|
*/
|
||||||
private setupStreamingHandlers(process: ChildProcess, stream: ClaudeCodeStream): void {
|
private async processSDKQuery(prompt: string, options: Options, stream: ClaudeCodeStream): Promise<void> {
|
||||||
let stdoutData = ''
|
const jsonOutput: SDKMessage[] = []
|
||||||
let stderrData = ''
|
|
||||||
const jsonOutput: any[] = []
|
|
||||||
let hasCompleted = false
|
let hasCompleted = false
|
||||||
let stdoutBuffer = ''
|
|
||||||
|
|
||||||
const startTime = Date.now()
|
const startTime = Date.now()
|
||||||
|
|
||||||
const emitChunks = (sdkMessage: SDKMessage) => {
|
try {
|
||||||
jsonOutput.push(sdkMessage)
|
// Process streaming responses using SDK query
|
||||||
const chunks = transformSDKMessageToUIChunk(sdkMessage)
|
for await (const message of query({
|
||||||
for (const chunk of chunks) {
|
prompt,
|
||||||
stream.emit('data', {
|
options
|
||||||
type: 'chunk',
|
})) {
|
||||||
chunk,
|
if (hasCompleted) break
|
||||||
rawAgentMessage: sdkMessage // Store Claude Code specific SDKMessage as generic agent message
|
|
||||||
})
|
jsonOutput.push(message)
|
||||||
|
logger.silly('claude response', { message })
|
||||||
|
if (message.type === 'assistant' || message.type === 'user') {
|
||||||
|
logger.silly('message content', {
|
||||||
|
message: JSON.stringify({ role: message.message.role, content: message.message.content })
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transform SDKMessage to UIMessageChunks
|
||||||
|
const chunks = transformSDKMessageToUIChunk(message)
|
||||||
|
for (const chunk of chunks) {
|
||||||
|
stream.emit('data', {
|
||||||
|
type: 'chunk',
|
||||||
|
chunk,
|
||||||
|
rawAgentMessage: message
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Handle stdout with streaming events
|
// Successfully completed
|
||||||
if (process.stdout) {
|
|
||||||
process.stdout.setEncoding('utf8')
|
|
||||||
process.stdout.on('data', (data: string) => {
|
|
||||||
stdoutData += data
|
|
||||||
stdoutBuffer += data
|
|
||||||
logger.debug('Streaming stdout chunk:', { length: data.length })
|
|
||||||
|
|
||||||
let newlineIndex = stdoutBuffer.indexOf('\n')
|
|
||||||
while (newlineIndex !== -1) {
|
|
||||||
const line = stdoutBuffer.slice(0, newlineIndex)
|
|
||||||
stdoutBuffer = stdoutBuffer.slice(newlineIndex + 1)
|
|
||||||
const trimmed = line.trim()
|
|
||||||
if (trimmed) {
|
|
||||||
try {
|
|
||||||
const parsed = JSON.parse(trimmed) as SDKMessage
|
|
||||||
emitChunks(parsed)
|
|
||||||
logger.debug('Parsed JSON line', { parsed })
|
|
||||||
} catch (error) {
|
|
||||||
logger.debug('Non-JSON line', { line: trimmed })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
newlineIndex = stdoutBuffer.indexOf('\n')
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
process.stdout.on('end', () => {
|
|
||||||
const trimmed = stdoutBuffer.trim()
|
|
||||||
if (trimmed) {
|
|
||||||
try {
|
|
||||||
const parsed = JSON.parse(trimmed) as SDKMessage
|
|
||||||
emitChunks(parsed)
|
|
||||||
logger.debug('Parsed JSON line on stream end', { parsed })
|
|
||||||
} catch (error) {
|
|
||||||
logger.debug('Non-JSON remainder on stdout end', { line: trimmed })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.debug('Streaming stdout ended')
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle stderr
|
|
||||||
if (process.stderr) {
|
|
||||||
process.stderr.setEncoding('utf8')
|
|
||||||
process.stderr.on('data', (data: string) => {
|
|
||||||
stderrData += data
|
|
||||||
const message = data.trim()
|
|
||||||
if (!message) return
|
|
||||||
logger.warn('Streaming stderr chunk:', { data: message })
|
|
||||||
stream.emit('data', {
|
|
||||||
type: 'error',
|
|
||||||
error: new Error(message)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
process.stderr.on('end', () => {
|
|
||||||
logger.debug('Streaming stderr ended')
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle process completion
|
|
||||||
const completeProcess = (code: number | null, signal: NodeJS.Signals | null, error?: Error) => {
|
|
||||||
if (hasCompleted) return
|
|
||||||
hasCompleted = true
|
hasCompleted = true
|
||||||
|
|
||||||
const duration = Date.now() - startTime
|
const duration = Date.now() - startTime
|
||||||
const success = !error && code === 0
|
|
||||||
|
|
||||||
logger.info('Streaming process completed', {
|
logger.debug('SDK query completed successfully', {
|
||||||
code,
|
|
||||||
signal,
|
|
||||||
success,
|
|
||||||
duration,
|
duration,
|
||||||
stdoutLength: stdoutData.length,
|
messageCount: jsonOutput.length
|
||||||
stderrLength: stderrData.length,
|
|
||||||
jsonItems: jsonOutput.length,
|
|
||||||
error: error?.message
|
|
||||||
})
|
})
|
||||||
|
|
||||||
const result: ClaudeCodeResult = {
|
const result: ClaudeCodeResult = {
|
||||||
success,
|
success: true,
|
||||||
stdout: stdoutData,
|
stdout: '',
|
||||||
stderr: stderrData,
|
stderr: '',
|
||||||
jsonOutput,
|
jsonOutput,
|
||||||
exitCode: code || undefined,
|
exitCode: 0
|
||||||
error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emit completion event with agent-specific result
|
// Emit completion event
|
||||||
stream.emit('data', {
|
stream.emit('data', {
|
||||||
type: 'complete',
|
type: 'complete',
|
||||||
agentResult: {
|
agentResult: {
|
||||||
...result,
|
...result,
|
||||||
rawSDKMessages: jsonOutput, // Claude Code specific: all collected SDK messages
|
rawSDKMessages: jsonOutput,
|
||||||
agentType: 'claude-code' // Identify the agent type
|
agentType: 'claude-code'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
if (hasCompleted) return
|
||||||
|
hasCompleted = true
|
||||||
|
|
||||||
|
const duration = Date.now() - startTime
|
||||||
|
logger.error('SDK query error:', {
|
||||||
|
error: error instanceof Error ? error.message : String(error),
|
||||||
|
duration,
|
||||||
|
messageCount: jsonOutput.length
|
||||||
|
})
|
||||||
|
|
||||||
|
const result: ClaudeCodeResult = {
|
||||||
|
success: false,
|
||||||
|
stdout: '',
|
||||||
|
stderr: error instanceof Error ? error.message : String(error),
|
||||||
|
jsonOutput,
|
||||||
|
error: error instanceof Error ? error : new Error(String(error)),
|
||||||
|
exitCode: 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit error event
|
||||||
|
stream.emit('data', {
|
||||||
|
type: 'error',
|
||||||
|
error: error instanceof Error ? error : new Error(String(error))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Emit completion with error result
|
||||||
|
stream.emit('data', {
|
||||||
|
type: 'complete',
|
||||||
|
agentResult: {
|
||||||
|
...result,
|
||||||
|
rawSDKMessages: jsonOutput,
|
||||||
|
agentType: 'claude-code'
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle process exit
|
|
||||||
process.on('exit', (code, signal) => {
|
|
||||||
completeProcess(code, signal)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Handle process errors
|
|
||||||
process.on('error', (error) => {
|
|
||||||
const duration = Date.now() - startTime
|
|
||||||
logger.error('Streaming process error:', {
|
|
||||||
error: error.message,
|
|
||||||
duration,
|
|
||||||
stdoutLength: stdoutData.length,
|
|
||||||
stderrLength: stderrData.length
|
|
||||||
})
|
|
||||||
|
|
||||||
completeProcess(null, null, error)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Handle close event as a fallback
|
|
||||||
process.on('close', (code, signal) => {
|
|
||||||
logger.debug('Streaming process closed', { code, signal })
|
|
||||||
completeProcess(code, signal)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Set timeout to prevent hanging
|
|
||||||
const timeout = setTimeout(() => {
|
|
||||||
if (!hasCompleted) {
|
|
||||||
logger.error('Streaming process timeout after 600 seconds', {
|
|
||||||
pid: process.pid,
|
|
||||||
stdoutLength: stdoutData.length,
|
|
||||||
stderrLength: stderrData.length,
|
|
||||||
jsonItems: jsonOutput.length
|
|
||||||
})
|
|
||||||
process.kill('SIGTERM')
|
|
||||||
completeProcess(null, null, new Error('Process timeout after 600 seconds'))
|
|
||||||
}
|
|
||||||
}, 600 * 1000)
|
|
||||||
|
|
||||||
// Clear timeout when process ends
|
|
||||||
process.on('exit', () => clearTimeout(timeout))
|
|
||||||
process.on('error', () => clearTimeout(timeout))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export default ClaudeCodeService
|
export default ClaudeCodeService
|
||||||
|
|||||||
@ -327,11 +327,6 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Always emit a finish chunk at the end
|
|
||||||
chunks.push({
|
|
||||||
type: 'finish'
|
|
||||||
})
|
|
||||||
return chunks
|
return chunks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user