mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-29 23:12:38 +08:00
Refactor agent streaming from EventEmitter to ReadableStream
Replaced EventEmitter-based agent streaming with ReadableStream for better compatibility with AI SDK patterns. Modified SessionMessageService to return stream/completion pair instead of event emitter, updated HTTP handlers to use stream pumping, and added IPC contract for renderer-side message persistence.
This commit is contained in:
parent
fcacc50fdc
commit
1fd44a68b0
@ -89,6 +89,9 @@ export enum IpcChannel {
|
||||
// Python
|
||||
Python_Execute = 'python:execute',
|
||||
|
||||
// agent messages
|
||||
AgentMessage_PersistExchange = 'agent-message:persist-exchange',
|
||||
|
||||
//copilot
|
||||
Copilot_GetAuthMessage = 'copilot:get-auth-message',
|
||||
Copilot_GetCopilotToken = 'copilot:get-copilot-token',
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { AgentStreamEvent } from '@main/services/agents/interfaces/AgentStreamInterface'
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
import { agentService, sessionMessageService, sessionService } from '../../../../services/agents'
|
||||
@ -44,7 +43,12 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
|
||||
|
||||
const abortController = new AbortController()
|
||||
const messageStream = sessionMessageService.createSessionMessage(session, messageData, abortController)
|
||||
const { stream, completion } = await sessionMessageService.createSessionMessage(
|
||||
session,
|
||||
messageData,
|
||||
abortController
|
||||
)
|
||||
const reader = stream.getReader()
|
||||
|
||||
// Track stream lifecycle so we keep the SSE connection open until persistence finishes
|
||||
let responseEnded = false
|
||||
@ -61,7 +65,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
|
||||
responseEnded = true
|
||||
try {
|
||||
res.write('data: {"type":"finish"}\n\n')
|
||||
// res.write('data: {"type":"finish"}\n\n')
|
||||
res.write('data: [DONE]\n\n')
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing final sentinel to SSE stream:', { error: writeError as Error })
|
||||
@ -92,93 +96,78 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
if (responseEnded) return
|
||||
logger.info(`Client disconnected from streaming message for session: ${sessionId}`)
|
||||
responseEnded = true
|
||||
messageStream.removeAllListeners()
|
||||
abortController.abort('Client disconnected')
|
||||
reader.cancel('Client disconnected').catch(() => {})
|
||||
}
|
||||
|
||||
req.on('close', handleDisconnect)
|
||||
req.on('aborted', handleDisconnect)
|
||||
res.on('close', handleDisconnect)
|
||||
|
||||
// Handle stream events
|
||||
messageStream.on('data', (event: AgentStreamEvent) => {
|
||||
if (responseEnded) return
|
||||
|
||||
const pumpStream = async () => {
|
||||
try {
|
||||
switch (event.type) {
|
||||
case 'chunk':
|
||||
// Format UIMessageChunk as SSE event following AI SDK protocol
|
||||
res.write(`data: ${JSON.stringify(event.chunk)}\n\n`)
|
||||
while (!responseEnded) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
break
|
||||
}
|
||||
|
||||
case 'error': {
|
||||
// Send error as AI SDK error chunk
|
||||
const errorChunk = {
|
||||
res.write(`data: ${JSON.stringify(value)}\n\n`)
|
||||
}
|
||||
|
||||
streamFinished = true
|
||||
finalizeResponse()
|
||||
} catch (error) {
|
||||
if (responseEnded) return
|
||||
logger.error('Error reading agent stream:', { error })
|
||||
try {
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
type: 'error',
|
||||
errorText: event.error?.message || 'Stream processing error'
|
||||
}
|
||||
res.write(`data: ${JSON.stringify(errorChunk)}\n\n`)
|
||||
logger.error(`Streaming message error for session: ${sessionId}:`, event.error)
|
||||
|
||||
streamFinished = true
|
||||
finalizeResponse()
|
||||
break
|
||||
}
|
||||
|
||||
case 'complete': {
|
||||
logger.info(`Streaming message completed for session: ${sessionId}`)
|
||||
// res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`)
|
||||
|
||||
streamFinished = true
|
||||
finalizeResponse()
|
||||
break
|
||||
}
|
||||
|
||||
case 'cancelled': {
|
||||
logger.info(`Streaming message cancelled for session: ${sessionId}`)
|
||||
// res.write(`data: ${JSON.stringify({ type: 'cancelled' })}\n\n`)
|
||||
streamFinished = true
|
||||
finalizeResponse()
|
||||
break
|
||||
}
|
||||
|
||||
default:
|
||||
// Handle other event types as generic data
|
||||
logger.info(`Streaming message event for session: ${sessionId}:`, { event })
|
||||
// res.write(`data: ${JSON.stringify(event)}\n\n`)
|
||||
break
|
||||
}
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing to SSE stream:', { error: writeError })
|
||||
if (!responseEnded) {
|
||||
responseEnded = true
|
||||
res.end()
|
||||
error: {
|
||||
message: (error as Error).message || 'Stream processing error',
|
||||
type: 'stream_error',
|
||||
code: 'stream_processing_failed'
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing stream error to SSE:', { error: writeError })
|
||||
}
|
||||
responseEnded = true
|
||||
res.end()
|
||||
}
|
||||
}
|
||||
|
||||
pumpStream().catch((error) => {
|
||||
logger.error('Pump stream failure:', { error })
|
||||
})
|
||||
|
||||
// Handle stream errors
|
||||
messageStream.on('error', (error: Error) => {
|
||||
if (responseEnded) return
|
||||
|
||||
logger.error(`Stream error for session: ${sessionId}:`, { error })
|
||||
try {
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: error.message || 'Stream processing error',
|
||||
type: 'stream_error',
|
||||
code: 'stream_processing_failed'
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing error to SSE stream:', { error: writeError })
|
||||
}
|
||||
responseEnded = true
|
||||
res.end()
|
||||
})
|
||||
completion
|
||||
.then(() => {
|
||||
streamFinished = true
|
||||
finalizeResponse()
|
||||
})
|
||||
.catch((error) => {
|
||||
if (responseEnded) return
|
||||
logger.error(`Streaming message error for session: ${sessionId}:`, error)
|
||||
try {
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: (error as { message?: string })?.message || 'Stream processing error',
|
||||
type: 'stream_error',
|
||||
code: 'stream_processing_failed'
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing completion error to SSE stream:', { error: writeError })
|
||||
}
|
||||
responseEnded = true
|
||||
res.end()
|
||||
})
|
||||
|
||||
// Set a timeout to prevent hanging indefinitely
|
||||
const timeout = setTimeout(
|
||||
@ -199,6 +188,8 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing timeout to SSE stream:', { error: writeError })
|
||||
}
|
||||
abortController.abort('stream timeout')
|
||||
reader.cancel('stream timeout').catch(() => {})
|
||||
responseEnded = true
|
||||
res.end()
|
||||
}
|
||||
|
||||
@ -16,6 +16,7 @@ import checkDiskSpace from 'check-disk-space'
|
||||
import { BrowserWindow, dialog, ipcMain, ProxyConfig, session, shell, systemPreferences, webContents } from 'electron'
|
||||
import fontList from 'font-list'
|
||||
|
||||
import { agentMessageRepository } from './services/agents/database'
|
||||
import { apiServerService } from './services/ApiServerService'
|
||||
import appService from './services/AppService'
|
||||
import AppUpdater from './services/AppUpdater'
|
||||
@ -199,6 +200,15 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.AgentMessage_PersistExchange, async (_event, payload) => {
|
||||
try {
|
||||
return await agentMessageRepository.persistExchange(payload)
|
||||
} catch (error) {
|
||||
logger.error('Failed to persist agent session messages', error as Error)
|
||||
throw error
|
||||
}
|
||||
})
|
||||
|
||||
//only for mac
|
||||
if (isMac) {
|
||||
ipcMain.handle(IpcChannel.App_MacIsProcessTrusted, (): boolean => {
|
||||
|
||||
35
src/main/services/agents/TODO.md
Normal file
35
src/main/services/agents/TODO.md
Normal file
@ -0,0 +1,35 @@
|
||||
# Agents Service Refactor TODO (interface-level)
|
||||
|
||||
- [x] **SessionMessageService.createSessionMessage**
|
||||
- Replace the current `EventEmitter` that emits `UIMessageChunk` with a readable stream of `TextStreamPart` objects (same shape produced by `/api/messages` in `messageThunk`).
|
||||
- Update `startSessionMessageStream` to call a new adapter (`claudeToTextStreamPart(chunk)`) that maps Claude Code chunk payloads to `{ type: 'text-delta' | 'tool-call' | ... }` parts used by `AiSdkToChunkAdapter`.
|
||||
- Add a secondary return value (promise) resolving to the persisted `ModelMessage[]` once streaming completes, so the renderer thunk can await save confirmation.
|
||||
|
||||
- [x] **main -> renderer transport**
|
||||
- Update the existing SSE handler in `src/main/apiServer/routes/agents/handlers/messages.ts` (e.g., `createMessage`) to forward the new `TextStreamPart` stream over HTTP, preserving the current agent endpoint contract.
|
||||
- Keep abort handling compatible with the current HTTP server (honor `AbortController` on the request to terminate the stream).
|
||||
|
||||
- [x] **renderer thunk integration**
|
||||
- Introduce a thin IPC contract (e.g., `AgentMessagePersistence`) surfaced by `src/main/services/agents/database/index.ts` so the renderer thunk can request session-message writes without going through `SessionMessageService`.
|
||||
- Define explicit entry points on the main side:
|
||||
- `persistUserMessage({ sessionId, agentSessionId, payload, createdAt?, metadata? })`
|
||||
- `persistAssistantMessage({ sessionId, agentSessionId, payload, createdAt?, metadata? })`
|
||||
- `persistExchange({ sessionId, agentSessionId, user, assistant })` which runs the above in a single transaction and returns both records.
|
||||
- Export these helpers via an `agentMessageRepository` object so both IPC handlers and legacy services share the same persistence path.
|
||||
- Normalize persisted payloads to `{ message, blocks }` matching the renderer schema instead of AI-SDK `ModelMessage` chunks.
|
||||
- Extend `messageThunk.sendMessage` to call the agent transport when the topic corresponds to a session, pipe chunks through `createStreamProcessor` + `AiSdkToChunkAdapter`, and invoke the new persistence interface once streaming resolves.
|
||||
- Replace `useSession().createSessionMessage` optimistic insert with dispatching the thunk so Redux/Dexie persistence happens via the shared save helpers.
|
||||
|
||||
- [x] **persistence alignment**
|
||||
- Remove `persistUserMessage` / `persistAssistantMessage` calls from `SessionMessageService`; instead expose a `SessionMessageRepository` in `main` that the thunk invokes via existing Dexie helpers.
|
||||
- On renderer side, persist agent exchanges via IPC after streaming completes, storing `{ message, blocks }` payloads while skipping Dexie writes for agent sessions so the single source of truth remains `session_messages`.
|
||||
|
||||
- [x] **Blocks renderer**
|
||||
- Replace `AgentSessionMessages` simple `<div>` render with the shared `Blocks` component (`src/renderer/src/pages/home/Messages/Blocks`) wired to the Redux store.
|
||||
- Adjust `useSession` to only fetch metadata (e.g., session info) and rely on store selectors for message list.
|
||||
|
||||
- [x] **API client clean-up**
|
||||
- Remove `AgentApiClient.createMessage` direct POST once thunk is in place; calls should go through renderer thunk -> stream -> final persistence.
|
||||
|
||||
- [ ] **Regression tests**
|
||||
- Add integration test to assert agent sessions render incremental text the same way as standard assistant messages.
|
||||
@ -9,3 +9,6 @@
|
||||
|
||||
// Drizzle ORM schemas
|
||||
export * from './schema'
|
||||
|
||||
// Repository helpers
|
||||
export * from './sessionMessageRepository'
|
||||
|
||||
181
src/main/services/agents/database/sessionMessageRepository.ts
Normal file
181
src/main/services/agents/database/sessionMessageRepository.ts
Normal file
@ -0,0 +1,181 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type {
|
||||
AgentMessageAssistantPersistPayload,
|
||||
AgentMessagePersistExchangePayload,
|
||||
AgentMessagePersistExchangeResult,
|
||||
AgentMessageUserPersistPayload,
|
||||
AgentPersistedMessage,
|
||||
AgentSessionMessageEntity
|
||||
} from '@types'
|
||||
|
||||
import { BaseService } from '../BaseService'
|
||||
import type { InsertSessionMessageRow } from './schema'
|
||||
import { sessionMessagesTable } from './schema'
|
||||
|
||||
const logger = loggerService.withContext('AgentMessageRepository')
|
||||
|
||||
type TxClient = any
|
||||
|
||||
export type PersistUserMessageParams = AgentMessageUserPersistPayload & {
|
||||
sessionId: string
|
||||
agentSessionId?: string
|
||||
tx?: TxClient
|
||||
}
|
||||
|
||||
export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & {
|
||||
sessionId: string
|
||||
agentSessionId: string
|
||||
tx?: TxClient
|
||||
}
|
||||
|
||||
type PersistExchangeParams = AgentMessagePersistExchangePayload & {
|
||||
tx?: TxClient
|
||||
}
|
||||
|
||||
type PersistExchangeResult = AgentMessagePersistExchangeResult
|
||||
|
||||
class AgentMessageRepository extends BaseService {
|
||||
private static instance: AgentMessageRepository | null = null
|
||||
|
||||
static getInstance(): AgentMessageRepository {
|
||||
if (!AgentMessageRepository.instance) {
|
||||
AgentMessageRepository.instance = new AgentMessageRepository()
|
||||
}
|
||||
|
||||
return AgentMessageRepository.instance
|
||||
}
|
||||
|
||||
private serializeMessage(payload: AgentPersistedMessage): string {
|
||||
return JSON.stringify(payload)
|
||||
}
|
||||
|
||||
private serializeMetadata(metadata?: Record<string, unknown>): string | undefined {
|
||||
if (!metadata) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.stringify(metadata)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to serialize session message metadata', error as Error)
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
private deserialize(row: any): AgentSessionMessageEntity {
|
||||
if (!row) return row
|
||||
|
||||
const deserialized = { ...row }
|
||||
|
||||
if (typeof deserialized.content === 'string') {
|
||||
try {
|
||||
deserialized.content = JSON.parse(deserialized.content)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to parse session message content JSON', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
if (typeof deserialized.metadata === 'string') {
|
||||
try {
|
||||
deserialized.metadata = JSON.parse(deserialized.metadata)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to parse session message metadata JSON', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
return deserialized
|
||||
}
|
||||
|
||||
private getWriter(tx?: TxClient): TxClient {
|
||||
return tx ?? this.database
|
||||
}
|
||||
|
||||
async persistUserMessage(params: PersistUserMessageParams): Promise<AgentSessionMessageEntity> {
|
||||
await AgentMessageRepository.initialize()
|
||||
this.ensureInitialized()
|
||||
|
||||
const writer = this.getWriter(params.tx)
|
||||
const now = params.createdAt ?? params.payload.message.createdAt ?? new Date().toISOString()
|
||||
|
||||
const insertData: InsertSessionMessageRow = {
|
||||
session_id: params.sessionId,
|
||||
role: params.payload.message.role,
|
||||
content: this.serializeMessage(params.payload),
|
||||
agent_session_id: params.agentSessionId ?? '',
|
||||
metadata: this.serializeMetadata(params.metadata),
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
return this.deserialize(saved)
|
||||
}
|
||||
|
||||
async persistAssistantMessage(params: PersistAssistantMessageParams): Promise<AgentSessionMessageEntity> {
|
||||
await AgentMessageRepository.initialize()
|
||||
this.ensureInitialized()
|
||||
|
||||
const writer = this.getWriter(params.tx)
|
||||
const now = params.createdAt ?? params.payload.message.createdAt ?? new Date().toISOString()
|
||||
|
||||
const insertData: InsertSessionMessageRow = {
|
||||
session_id: params.sessionId,
|
||||
role: params.payload.message.role,
|
||||
content: this.serializeMessage(params.payload),
|
||||
agent_session_id: params.agentSessionId,
|
||||
metadata: this.serializeMetadata(params.metadata),
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
return this.deserialize(saved)
|
||||
}
|
||||
|
||||
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
|
||||
await AgentMessageRepository.initialize()
|
||||
this.ensureInitialized()
|
||||
|
||||
const { sessionId, agentSessionId, user, assistant } = params
|
||||
|
||||
const result = await this.database.transaction(async (tx) => {
|
||||
const exchangeResult: PersistExchangeResult = {}
|
||||
|
||||
if (user?.payload) {
|
||||
if (!user.payload.message?.role) {
|
||||
throw new Error('User message payload missing role')
|
||||
}
|
||||
exchangeResult.userMessage = await this.persistUserMessage({
|
||||
sessionId,
|
||||
agentSessionId,
|
||||
payload: user.payload,
|
||||
metadata: user.metadata,
|
||||
createdAt: user.createdAt,
|
||||
tx
|
||||
})
|
||||
}
|
||||
|
||||
if (assistant?.payload) {
|
||||
if (!assistant.payload.message?.role) {
|
||||
throw new Error('Assistant message payload missing role')
|
||||
}
|
||||
exchangeResult.assistantMessage = await this.persistAssistantMessage({
|
||||
sessionId,
|
||||
agentSessionId,
|
||||
payload: assistant.payload,
|
||||
metadata: assistant.metadata,
|
||||
createdAt: assistant.createdAt,
|
||||
tx
|
||||
})
|
||||
}
|
||||
|
||||
return exchangeResult
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
export const agentMessageRepository = AgentMessageRepository.getInstance()
|
||||
@ -4,12 +4,12 @@
|
||||
import { EventEmitter } from 'node:events'
|
||||
|
||||
import { GetAgentSessionResponse } from '@types'
|
||||
import { UIMessageChunk } from 'ai'
|
||||
import type { TextStreamPart } from 'ai'
|
||||
|
||||
// Generic agent stream event that works with any agent type
|
||||
export interface AgentStreamEvent {
|
||||
type: 'chunk' | 'error' | 'complete' | 'cancelled'
|
||||
chunk?: UIMessageChunk // Standard AI SDK chunk for UI consumption
|
||||
chunk?: TextStreamPart<any> // Standard AI SDK chunk for UI consumption
|
||||
error?: Error
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import { EventEmitter } from 'node:events'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type {
|
||||
AgentSessionMessageEntity,
|
||||
@ -7,29 +5,22 @@ import type {
|
||||
GetAgentSessionResponse,
|
||||
ListOptions
|
||||
} from '@types'
|
||||
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
||||
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
||||
import { ModelMessage, TextStreamPart } from 'ai'
|
||||
import { desc, eq } from 'drizzle-orm'
|
||||
|
||||
import { BaseService } from '../BaseService'
|
||||
import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema'
|
||||
import { sessionMessagesTable } from '../database/schema'
|
||||
import { AgentStreamEvent } from '../interfaces/AgentStreamInterface'
|
||||
import ClaudeCodeService from './claudecode'
|
||||
|
||||
const logger = loggerService.withContext('SessionMessageService')
|
||||
|
||||
// Collapse a UIMessageChunk stream into a final UIMessage, then convert to ModelMessage[]
|
||||
export async function chunksToModelMessages(
|
||||
chunkStream: ReadableStream<UIMessageChunk>,
|
||||
priorUiHistory: UIMessage[] = []
|
||||
): Promise<ModelMessage[]> {
|
||||
let latest: UIMessage | undefined
|
||||
|
||||
for await (const uiMsg of readUIMessageStream({ stream: chunkStream })) {
|
||||
latest = uiMsg // each yield is a newer state; keep the last one
|
||||
}
|
||||
|
||||
const uiMessages = latest ? [...priorUiHistory, latest] : priorUiHistory
|
||||
return convertToModelMessages(uiMessages) // -> ModelMessage[]
|
||||
type SessionStreamResult = {
|
||||
stream: ReadableStream<TextStreamPart<Record<string, any>>>
|
||||
completion: Promise<{
|
||||
userMessage?: AgentSessionMessageEntity
|
||||
assistantMessage?: AgentSessionMessageEntity
|
||||
}>
|
||||
}
|
||||
|
||||
// Ensure errors emitted through SSE are serializable
|
||||
@ -51,71 +42,69 @@ function serializeError(error: unknown): { message: string; name?: string; stack
|
||||
}
|
||||
}
|
||||
|
||||
// Chunk accumulator class to collect and reconstruct streaming data
|
||||
class ChunkAccumulator {
|
||||
private streamedChunks: UIMessageChunk[] = []
|
||||
private agentType: string = 'unknown'
|
||||
class TextStreamAccumulator {
|
||||
private textBuffer = ''
|
||||
private totalText = ''
|
||||
private readonly toolCalls = new Map<string, { toolName?: string; input?: unknown }>()
|
||||
private readonly toolResults = new Map<string, unknown>()
|
||||
|
||||
addChunk(chunk: UIMessageChunk): void {
|
||||
this.streamedChunks.push(chunk)
|
||||
}
|
||||
|
||||
// Create a ReadableStream from accumulated chunks
|
||||
createChunkStream(): ReadableStream<UIMessageChunk> {
|
||||
const chunks = [...this.streamedChunks]
|
||||
|
||||
return new ReadableStream<UIMessageChunk>({
|
||||
start(controller) {
|
||||
// Enqueue all chunks
|
||||
for (const chunk of chunks) {
|
||||
controller.enqueue(chunk)
|
||||
add(part: TextStreamPart<Record<string, any>>): void {
|
||||
switch (part.type) {
|
||||
case 'text-start':
|
||||
this.textBuffer = ''
|
||||
break
|
||||
case 'text-delta':
|
||||
if (part.text) {
|
||||
this.textBuffer += part.text
|
||||
}
|
||||
controller.close()
|
||||
break
|
||||
case 'text-end': {
|
||||
const blockText = (part.providerMetadata?.text?.value as string | undefined) ?? this.textBuffer
|
||||
if (blockText) {
|
||||
this.totalText += blockText
|
||||
}
|
||||
this.textBuffer = ''
|
||||
break
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Convert accumulated chunks to ModelMessages using chunksToModelMessages
|
||||
async toModelMessages(priorUiHistory: UIMessage[] = []): Promise<ModelMessage[]> {
|
||||
const chunkStream = this.createChunkStream()
|
||||
return await chunksToModelMessages(chunkStream, priorUiHistory)
|
||||
case 'tool-call':
|
||||
if (part.toolCallId) {
|
||||
this.toolCalls.set(part.toolCallId, {
|
||||
toolName: part.toolName,
|
||||
input: part.input ?? part.args ?? part.providerMetadata?.raw?.input
|
||||
})
|
||||
}
|
||||
break
|
||||
case 'tool-result':
|
||||
if (part.toolCallId) {
|
||||
this.toolResults.set(part.toolCallId, part.output ?? part.result ?? part.providerMetadata?.raw)
|
||||
}
|
||||
break
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
toModelMessage(role: ModelMessage['role'] = 'assistant'): ModelMessage {
|
||||
// Reconstruct the content from chunks
|
||||
let textContent = ''
|
||||
const toolCalls: any[] = []
|
||||
const content = this.totalText || this.textBuffer || ''
|
||||
|
||||
for (const chunk of this.streamedChunks) {
|
||||
if (chunk.type === 'text-delta' && 'delta' in chunk) {
|
||||
textContent += chunk.delta
|
||||
} else if (chunk.type === 'tool-input-available' && 'toolCallId' in chunk && 'toolName' in chunk) {
|
||||
// Handle tool calls - use tool-input-available chunks
|
||||
const toolCall = {
|
||||
toolCallId: chunk.toolCallId,
|
||||
toolName: chunk.toolName,
|
||||
args: chunk.input || {}
|
||||
}
|
||||
toolCalls.push(toolCall)
|
||||
}
|
||||
}
|
||||
const toolInvocations = Array.from(this.toolCalls.entries()).map(([toolCallId, info]) => ({
|
||||
toolCallId,
|
||||
toolName: info.toolName,
|
||||
args: info.input,
|
||||
result: this.toolResults.get(toolCallId)
|
||||
}))
|
||||
|
||||
const message: any = {
|
||||
const message: Record<string, unknown> = {
|
||||
role,
|
||||
content: textContent
|
||||
content
|
||||
}
|
||||
|
||||
// Add tool invocations if any
|
||||
if (toolCalls.length > 0) {
|
||||
message.toolInvocations = toolCalls
|
||||
if (toolInvocations.length > 0) {
|
||||
message.toolInvocations = toolInvocations
|
||||
}
|
||||
|
||||
return message as ModelMessage
|
||||
}
|
||||
|
||||
getAgentType(): string {
|
||||
return this.agentType
|
||||
}
|
||||
}
|
||||
|
||||
export class SessionMessageService extends BaseService {
|
||||
@ -170,28 +159,21 @@ export class SessionMessageService extends BaseService {
|
||||
return { messages }
|
||||
}
|
||||
|
||||
createSessionMessage(
|
||||
async createSessionMessage(
|
||||
session: GetAgentSessionResponse,
|
||||
messageData: CreateSessionMessageRequest,
|
||||
abortController: AbortController
|
||||
): EventEmitter {
|
||||
): Promise<SessionStreamResult> {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Create a new EventEmitter to manage the session message lifecycle
|
||||
const sessionStream = new EventEmitter()
|
||||
|
||||
// No parent validation needed, start immediately
|
||||
this.startSessionMessageStream(session, messageData, sessionStream, abortController)
|
||||
|
||||
return sessionStream
|
||||
return await this.startSessionMessageStream(session, messageData, abortController)
|
||||
}
|
||||
|
||||
private async startSessionMessageStream(
|
||||
session: GetAgentSessionResponse,
|
||||
req: CreateSessionMessageRequest,
|
||||
sessionStream: EventEmitter,
|
||||
abortController: AbortController
|
||||
): Promise<void> {
|
||||
): Promise<SessionStreamResult> {
|
||||
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||
let newAgentSessionId = ''
|
||||
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
||||
@ -202,98 +184,98 @@ export class SessionMessageService extends BaseService {
|
||||
throw new Error('Unsupported agent type for streaming')
|
||||
}
|
||||
|
||||
// Create the streaming agent invocation (using invokeStream for streaming)
|
||||
const claudeStream = await this.cc.invoke(req.content, session, abortController, agentSessionId)
|
||||
const accumulator = new TextStreamAccumulator()
|
||||
|
||||
// Use chunk accumulator to manage streaming data
|
||||
const accumulator = new ChunkAccumulator()
|
||||
let resolveCompletion!: (value: {
|
||||
userMessage?: AgentSessionMessageEntity
|
||||
assistantMessage?: AgentSessionMessageEntity
|
||||
}) => void
|
||||
let rejectCompletion!: (reason?: unknown) => void
|
||||
|
||||
// Handle agent stream events (agent-agnostic)
|
||||
claudeStream.on('data', async (event: any) => {
|
||||
try {
|
||||
switch (event.type) {
|
||||
case 'chunk':
|
||||
// Forward UIMessageChunk directly and collect raw agent messages
|
||||
if (event.chunk) {
|
||||
const chunk = event.chunk as UIMessageChunk
|
||||
if (chunk.type === 'start' && chunk.messageId) {
|
||||
newAgentSessionId = chunk.messageId
|
||||
const completion = new Promise<{
|
||||
userMessage?: AgentSessionMessageEntity
|
||||
assistantMessage?: AgentSessionMessageEntity
|
||||
}>((resolve, reject) => {
|
||||
resolveCompletion = resolve
|
||||
rejectCompletion = reject
|
||||
})
|
||||
|
||||
let finished = false
|
||||
|
||||
const cleanup = () => {
|
||||
if (finished) return
|
||||
finished = true
|
||||
claudeStream.removeAllListeners()
|
||||
}
|
||||
|
||||
const stream = new ReadableStream<TextStreamPart<Record<string, any>>>({
|
||||
start: (controller) => {
|
||||
claudeStream.on('data', async (event: AgentStreamEvent) => {
|
||||
if (finished) return
|
||||
try {
|
||||
switch (event.type) {
|
||||
case 'chunk': {
|
||||
const chunk = event.chunk as TextStreamPart<Record<string, any>> | undefined
|
||||
if (!chunk) {
|
||||
logger.warn('Received agent chunk event without chunk payload')
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'start' && chunk.messageId) {
|
||||
newAgentSessionId = chunk.messageId
|
||||
}
|
||||
|
||||
accumulator.add(chunk)
|
||||
controller.enqueue(chunk)
|
||||
break
|
||||
}
|
||||
accumulator.addChunk(chunk)
|
||||
|
||||
sessionStream.emit('data', {
|
||||
type: 'chunk',
|
||||
chunk
|
||||
})
|
||||
} else {
|
||||
logger.warn('Received agent chunk event without chunk payload')
|
||||
}
|
||||
break
|
||||
case 'error': {
|
||||
const stderrMessage = (event as any)?.data?.stderr as string | undefined
|
||||
const underlyingError = event.error ?? (stderrMessage ? new Error(stderrMessage) : undefined)
|
||||
cleanup()
|
||||
const streamError = underlyingError ?? new Error('Stream error')
|
||||
controller.error(streamError)
|
||||
rejectCompletion(serializeError(streamError))
|
||||
break
|
||||
}
|
||||
|
||||
case 'error': {
|
||||
const underlyingError = event.error || (event.data?.stderr ? new Error(event.data.stderr) : undefined)
|
||||
case 'complete': {
|
||||
cleanup()
|
||||
controller.close()
|
||||
resolveCompletion({})
|
||||
break
|
||||
}
|
||||
|
||||
sessionStream.emit('data', {
|
||||
type: 'error',
|
||||
error: serializeError(underlyingError),
|
||||
persistScheduled: false
|
||||
})
|
||||
// Always emit a complete chunk at the end
|
||||
sessionStream.emit('data', {
|
||||
type: 'complete',
|
||||
persistScheduled: false
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'cancelled': {
|
||||
cleanup()
|
||||
controller.close()
|
||||
resolveCompletion({})
|
||||
break
|
||||
}
|
||||
|
||||
case 'complete': {
|
||||
try {
|
||||
const persisted = await this.database.transaction(async (tx) => {
|
||||
const userMessage = await this.persistUserMessage(tx, session.id, req.content, newAgentSessionId)
|
||||
const assistantMessage = await this.persistAssistantMessage({
|
||||
tx,
|
||||
session,
|
||||
accumulator,
|
||||
agentSessionId: newAgentSessionId
|
||||
default:
|
||||
logger.warn('Unknown event type from Claude Code service:', {
|
||||
type: event.type
|
||||
})
|
||||
|
||||
return { userMessage, assistantMessage }
|
||||
})
|
||||
|
||||
sessionStream.emit('data', {
|
||||
type: 'persisted',
|
||||
message: persisted.assistantMessage,
|
||||
userMessage: persisted.userMessage
|
||||
})
|
||||
} catch (persistError) {
|
||||
sessionStream.emit('data', {
|
||||
type: 'persist-error',
|
||||
error: serializeError(persistError)
|
||||
})
|
||||
} finally {
|
||||
// Always emit a complete chunk at the end
|
||||
sessionStream.emit('data', {
|
||||
type: 'complete',
|
||||
persistScheduled: true
|
||||
})
|
||||
break
|
||||
}
|
||||
break
|
||||
} catch (error) {
|
||||
cleanup()
|
||||
controller.error(error)
|
||||
rejectCompletion(serializeError(error))
|
||||
}
|
||||
|
||||
default:
|
||||
logger.warn('Unknown event type from Claude Code service:', {
|
||||
type: event.type
|
||||
})
|
||||
break
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error handling Claude Code stream event:', { error })
|
||||
sessionStream.emit('data', {
|
||||
type: 'error',
|
||||
error: serializeError(error)
|
||||
})
|
||||
},
|
||||
cancel: (reason) => {
|
||||
cleanup()
|
||||
abortController.abort(typeof reason === 'string' ? reason : 'stream cancelled')
|
||||
resolveCompletion({})
|
||||
}
|
||||
})
|
||||
|
||||
return { stream, completion }
|
||||
}
|
||||
|
||||
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
||||
@ -317,75 +299,6 @@ export class SessionMessageService extends BaseService {
|
||||
}
|
||||
}
|
||||
|
||||
async persistUserMessage(
|
||||
tx: any,
|
||||
sessionId: string,
|
||||
prompt: string,
|
||||
agentSessionId: string
|
||||
): Promise<AgentSessionMessageEntity> {
|
||||
this.ensureInitialized()
|
||||
|
||||
const now = new Date().toISOString()
|
||||
const insertData: InsertSessionMessageRow = {
|
||||
session_id: sessionId,
|
||||
role: 'user',
|
||||
content: JSON.stringify({ role: 'user', content: prompt }),
|
||||
agent_session_id: agentSessionId,
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [saved] = await tx.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
||||
}
|
||||
|
||||
private async persistAssistantMessage({
|
||||
tx,
|
||||
session,
|
||||
accumulator,
|
||||
agentSessionId
|
||||
}: {
|
||||
tx: any
|
||||
session: GetAgentSessionResponse
|
||||
accumulator: ChunkAccumulator
|
||||
agentSessionId: string
|
||||
}): Promise<AgentSessionMessageEntity> {
|
||||
if (!session?.id) {
|
||||
const missingSessionError = new Error('Missing session_id for persisted message')
|
||||
logger.error('error persisting session message', { error: missingSessionError })
|
||||
throw missingSessionError
|
||||
}
|
||||
|
||||
const sessionId = session.id
|
||||
const now = new Date().toISOString()
|
||||
|
||||
try {
|
||||
// Use chunksToModelMessages to convert chunks to ModelMessages
|
||||
const modelMessages = await accumulator.toModelMessages()
|
||||
// Get the last message (should be the assistant's response)
|
||||
const modelMessage =
|
||||
modelMessages.length > 0 ? modelMessages[modelMessages.length - 1] : accumulator.toModelMessage('assistant')
|
||||
|
||||
const insertData: InsertSessionMessageRow = {
|
||||
session_id: sessionId,
|
||||
role: 'assistant',
|
||||
content: JSON.stringify(modelMessage),
|
||||
agent_session_id: agentSessionId,
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [saved] = await tx.insert(sessionMessagesTable).values(insertData).returning()
|
||||
logger.debug('Success Persisted session message')
|
||||
|
||||
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
||||
} catch (error) {
|
||||
logger.error('Failed to persist session message', { error })
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
private deserializeSessionMessage(data: any): AgentSessionMessageEntity {
|
||||
if (!data) return data
|
||||
|
||||
|
||||
@ -1,384 +0,0 @@
|
||||
AI SDK UI functions such as `useChat` and `useCompletion` support both text streams and data streams. The stream protocol defines how the data is streamed to the frontend on top of the HTTP protocol.
|
||||
|
||||
This page describes both protocols and how to use them in the backend and frontend.
|
||||
|
||||
You can use this information to develop custom backends and frontends for your use case, e.g., to provide compatible API endpoints that are implemented in a different language such as Python.
|
||||
|
||||
For instance, here's an example using [FastAPI](https://github.com/vercel/ai/tree/main/examples/next-fastapi) as a backend.
|
||||
|
||||
## Text Stream Protocol
|
||||
|
||||
A text stream contains chunks in plain text, that are streamed to the frontend. Each chunk is then appended together to form a full text response.
|
||||
|
||||
Text streams are supported by `useChat`, `useCompletion`, and `useObject`. When you use `useChat` or `useCompletion`, you need to enable text streaming by setting the `streamProtocol` options to `text`.
|
||||
|
||||
You can generate text streams with `streamText` in the backend. When you call `toTextStreamResponse()` on the result object, a streaming HTTP response is returned.
|
||||
|
||||
Text streams only support basic text data. If you need to stream other types of data such as tool calls, use data streams.
|
||||
|
||||
### Text Stream Example
|
||||
|
||||
Here is a Next.js example that uses the text stream protocol:
|
||||
|
||||
app/page.tsx
|
||||
|
||||
```tsx
|
||||
'use client';
|
||||
|
||||
import { useChat } from '@ai-sdk/react';
|
||||
import { TextStreamChatTransport } from 'ai';
|
||||
import { useState } from 'react';
|
||||
|
||||
export default function Chat() {
|
||||
const [input, setInput] = useState('');
|
||||
const { messages, sendMessage } = useChat({
|
||||
transport: new TextStreamChatTransport({ api: '/api/chat' }),
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
|
||||
{messages.map(message => (
|
||||
<div key={message.id} className="whitespace-pre-wrap">
|
||||
{message.role === 'user' ? 'User: ' : 'AI: '}
|
||||
{message.parts.map((part, i) => {
|
||||
switch (part.type) {
|
||||
case 'text':
|
||||
return <div key={\`${message.id}-${i}\`}>{part.text}</div>;
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
))}
|
||||
|
||||
<form
|
||||
onSubmit={e => {
|
||||
e.preventDefault();
|
||||
sendMessage({ text: input });
|
||||
setInput('');
|
||||
}}
|
||||
>
|
||||
<input
|
||||
className="fixed dark:bg-zinc-900 bottom-0 w-full max-w-md p-2 mb-8 border border-zinc-300 dark:border-zinc-800 rounded shadow-xl"
|
||||
value={input}
|
||||
placeholder="Say something..."
|
||||
onChange={e => setInput(e.currentTarget.value)}
|
||||
/>
|
||||
</form>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
## Data Stream Protocol
|
||||
|
||||
A data stream follows a special protocol that the AI SDK provides to send information to the frontend.
|
||||
|
||||
The data stream protocol uses Server-Sent Events (SSE) format for improved standardization, keep-alive through ping, reconnect capabilities, and better cache handling.
|
||||
|
||||
The following stream parts are currently supported:
|
||||
|
||||
### Message Start Part
|
||||
|
||||
Indicates the beginning of a new message with metadata.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"start","messageId":"..."}
|
||||
```
|
||||
|
||||
### Text Parts
|
||||
|
||||
Text content is streamed using a start/delta/end pattern with unique IDs for each text block.
|
||||
|
||||
#### Text Start Part
|
||||
|
||||
Indicates the beginning of a text block.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"text-start","id":"msg_68679a454370819ca74c8eb3d04379630dd1afb72306ca5d"}
|
||||
```
|
||||
|
||||
#### Text Delta Part
|
||||
|
||||
Contains incremental text content for the text block.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"text-delta","id":"msg_68679a454370819ca74c8eb3d04379630dd1afb72306ca5d","delta":"Hello"}
|
||||
```
|
||||
|
||||
#### Text End Part
|
||||
|
||||
Indicates the completion of a text block.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"text-end","id":"msg_68679a454370819ca74c8eb3d04379630dd1afb72306ca5d"}
|
||||
```
|
||||
|
||||
### Reasoning Parts
|
||||
|
||||
Reasoning content is streamed using a start/delta/end pattern with unique IDs for each reasoning block.
|
||||
|
||||
#### Reasoning Start Part
|
||||
|
||||
Indicates the beginning of a reasoning block.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"reasoning-start","id":"reasoning_123"}
|
||||
```
|
||||
|
||||
#### Reasoning Delta Part
|
||||
|
||||
Contains incremental reasoning content for the reasoning block.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"reasoning-delta","id":"reasoning_123","delta":"This is some reasoning"}
|
||||
```
|
||||
|
||||
#### Reasoning End Part
|
||||
|
||||
Indicates the completion of a reasoning block.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"reasoning-end","id":"reasoning_123"}
|
||||
```
|
||||
|
||||
### Source Parts
|
||||
|
||||
Source parts provide references to external content sources.
|
||||
|
||||
#### Source URL Part
|
||||
|
||||
References to external URLs.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"source-url","sourceId":"https://example.com","url":"https://example.com"}
|
||||
```
|
||||
|
||||
#### Source Document Part
|
||||
|
||||
References to documents or files.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"source-document","sourceId":"https://example.com","mediaType":"file","title":"Title"}
|
||||
```
|
||||
|
||||
### File Part
|
||||
|
||||
The file parts contain references to files with their media type.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"file","url":"https://example.com/file.png","mediaType":"image/png"}
|
||||
```
|
||||
|
||||
### Data Parts
|
||||
|
||||
Custom data parts allow streaming of arbitrary structured data with type-specific handling.
|
||||
|
||||
Format: Server-Sent Event with JSON object where the type includes a custom suffix
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"data-weather","data":{"location":"SF","temperature":100}}
|
||||
```
|
||||
|
||||
The `data-*` type pattern allows you to define custom data types that your frontend can handle specifically.
|
||||
|
||||
The error parts are appended to the message as they are received.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"error","errorText":"error message"}
|
||||
```
|
||||
|
||||
### Tool Input Start Part
|
||||
|
||||
Indicates the beginning of tool input streaming.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"tool-input-start","toolCallId":"call_fJdQDqnXeGxTmr4E3YPSR7Ar","toolName":"getWeatherInformation"}
|
||||
```
|
||||
|
||||
### Tool Input Delta Part
|
||||
|
||||
Incremental chunks of tool input as it's being generated.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"tool-input-delta","toolCallId":"call_fJdQDqnXeGxTmr4E3YPSR7Ar","inputTextDelta":"San Francisco"}
|
||||
```
|
||||
|
||||
### Tool Input Available Part
|
||||
|
||||
Indicates that tool input is complete and ready for execution.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"tool-input-available","toolCallId":"call_fJdQDqnXeGxTmr4E3YPSR7Ar","toolName":"getWeatherInformation","input":{"city":"San Francisco"}}
|
||||
```
|
||||
|
||||
### Tool Output Available Part
|
||||
|
||||
Contains the result of tool execution.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"tool-output-available","toolCallId":"call_fJdQDqnXeGxTmr4E3YPSR7Ar","output":{"city":"San Francisco","weather":"sunny"}}
|
||||
```
|
||||
|
||||
### Start Step Part
|
||||
|
||||
A part indicating the start of a step.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"start-step"}
|
||||
```
|
||||
|
||||
### Finish Step Part
|
||||
|
||||
A part indicating that a step (i.e., one LLM API call in the backend) has been completed.
|
||||
|
||||
This part is necessary to correctly process multiple stitched assistant calls, e.g. when calling tools in the backend, and using steps in `useChat` at the same time.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"finish-step"}
|
||||
```
|
||||
|
||||
### Finish Message Part
|
||||
|
||||
A part indicating the completion of a message.
|
||||
|
||||
Format: Server-Sent Event with JSON object
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: {"type":"finish"}
|
||||
```
|
||||
|
||||
### Stream Termination
|
||||
|
||||
The stream ends with a special `[DONE]` marker.
|
||||
|
||||
Format: Server-Sent Event with literal `[DONE]`
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
The data stream protocol is supported by `useChat` and `useCompletion` on the frontend and used by default.`useCompletion` only supports the `text` and `data` stream parts.
|
||||
|
||||
On the backend, you can use `toUIMessageStreamResponse()` from the `streamText` result object to return a streaming HTTP response.
|
||||
|
||||
### UI Message Stream Example
|
||||
|
||||
Here is a Next.js example that uses the UI message stream protocol:
|
||||
|
||||
app/page.tsx
|
||||
|
||||
```tsx
|
||||
'use client';
|
||||
|
||||
import { useChat } from '@ai-sdk/react';
|
||||
import { useState } from 'react';
|
||||
|
||||
export default function Chat() {
|
||||
const [input, setInput] = useState('');
|
||||
const { messages, sendMessage } = useChat();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
|
||||
{messages.map(message => (
|
||||
<div key={message.id} className="whitespace-pre-wrap">
|
||||
{message.role === 'user' ? 'User: ' : 'AI: '}
|
||||
{message.parts.map((part, i) => {
|
||||
switch (part.type) {
|
||||
case 'text':
|
||||
return <div key={\`${message.id}-${i}\`}>{part.text}</div>;
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
))}
|
||||
|
||||
<form
|
||||
onSubmit={e => {
|
||||
e.preventDefault();
|
||||
sendMessage({ text: input });
|
||||
setInput('');
|
||||
}}
|
||||
>
|
||||
<input
|
||||
className="fixed dark:bg-zinc-900 bottom-0 w-full max-w-md p-2 mb-8 border border-zinc-300 dark:border-zinc-800 rounded shadow-xl"
|
||||
value={input}
|
||||
placeholder="Say something..."
|
||||
onChange={e => setInput(e.currentTarget.value)}
|
||||
/>
|
||||
</form>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
@ -9,7 +9,7 @@ import { validateModelId } from '@main/apiServer/utils'
|
||||
|
||||
import { GetAgentSessionResponse } from '../..'
|
||||
import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
||||
import { transformSDKMessageToUIChunk } from './transform'
|
||||
import { transformSDKMessageToStreamParts } from './transform'
|
||||
|
||||
const require_ = createRequire(import.meta.url)
|
||||
const logger = loggerService.withContext('ClaudeCodeService')
|
||||
@ -157,7 +157,7 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
}
|
||||
|
||||
// Transform SDKMessage to UIMessageChunks
|
||||
const chunks = transformSDKMessageToUIChunk(message)
|
||||
const chunks = transformSDKMessageToStreamParts(message)
|
||||
for (const chunk of chunks) {
|
||||
stream.emit('data', {
|
||||
type: 'chunk',
|
||||
|
||||
@ -0,0 +1,34 @@
|
||||
// ported from https://github.com/ben-vargas/ai-sdk-provider-claude-code/blob/main/src/map-claude-code-finish-reason.ts#L22
|
||||
import type { LanguageModelV2FinishReason } from '@ai-sdk/provider'
|
||||
|
||||
/**
|
||||
* Maps Claude Code SDK result subtypes to AI SDK finish reasons.
|
||||
*
|
||||
* @param subtype - The result subtype from Claude Code SDK
|
||||
* @returns The corresponding AI SDK finish reason
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const finishReason = mapClaudeCodeFinishReason('error_max_turns');
|
||||
* // Returns: 'length'
|
||||
* ```
|
||||
*
|
||||
* @remarks
|
||||
* Mappings:
|
||||
* - 'success' -> 'stop' (normal completion)
|
||||
* - 'error_max_turns' -> 'length' (hit turn limit)
|
||||
* - 'error_during_execution' -> 'error' (execution error)
|
||||
* - default -> 'stop' (unknown subtypes treated as normal completion)
|
||||
*/
|
||||
export function mapClaudeCodeFinishReason(subtype?: string): LanguageModelV2FinishReason {
|
||||
switch (subtype) {
|
||||
case 'success':
|
||||
return 'stop'
|
||||
case 'error_max_turns':
|
||||
return 'length'
|
||||
case 'error_during_execution':
|
||||
return 'error'
|
||||
default:
|
||||
return 'stop'
|
||||
}
|
||||
}
|
||||
@ -1,21 +1,34 @@
|
||||
// This file is used to transform claude code json response to aisdk streaming format
|
||||
|
||||
import type { LanguageModelV2Usage } from '@ai-sdk/provider'
|
||||
import { SDKMessage } from '@anthropic-ai/claude-code'
|
||||
import { loggerService } from '@logger'
|
||||
import { ProviderMetadata, UIMessageChunk } from 'ai'
|
||||
import type { ProviderMetadata, TextStreamPart } from 'ai'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import { mapClaudeCodeFinishReason } from './map-claude-code-finish-reason'
|
||||
|
||||
const logger = loggerService.withContext('ClaudeCodeTransform')
|
||||
|
||||
type AgentStreamPart = TextStreamPart<Record<string, any>>
|
||||
|
||||
const contentBlockState = new Map<
|
||||
string,
|
||||
{
|
||||
type: 'text' | 'tool-call'
|
||||
toolCallId?: string
|
||||
toolName?: string
|
||||
input?: string
|
||||
}
|
||||
>()
|
||||
|
||||
// Helper function to generate unique IDs for text blocks
|
||||
const generateMessageId = (): string => {
|
||||
return `msg_${uuidv4().replace(/-/g, '')}`
|
||||
}
|
||||
const generateMessageId = (): string => `msg_${uuidv4().replace(/-/g, '')}`
|
||||
|
||||
// Main transform function
|
||||
export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageChunk[] {
|
||||
const chunks: UIMessageChunk[] = []
|
||||
|
||||
export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage): AgentStreamPart[] {
|
||||
const chunks: AgentStreamPart[] = []
|
||||
logger.debug('Transforming SDKMessage to stream parts', sdkMessage)
|
||||
switch (sdkMessage.type) {
|
||||
case 'assistant':
|
||||
case 'user':
|
||||
@ -35,7 +48,6 @@ export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageC
|
||||
break
|
||||
|
||||
default:
|
||||
// Handle unknown message types gracefully
|
||||
logger.warn('Unknown SDKMessage type:', { type: (sdkMessage as any).type })
|
||||
break
|
||||
}
|
||||
@ -43,36 +55,45 @@ export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageC
|
||||
return chunks
|
||||
}
|
||||
|
||||
function sdkMessageToProviderMetadata(message: SDKMessage): ProviderMetadata {
|
||||
const meta: ProviderMetadata = {
|
||||
message: message as Record<string, any>
|
||||
const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata => {
|
||||
return {
|
||||
anthropic: {
|
||||
uuid: message.uuid || generateMessageId(),
|
||||
session_id: message.session_id
|
||||
},
|
||||
raw: message as Record<string, any>
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
function generateTextChunks(id: string, text: string, message: SDKMessage): UIMessageChunk[] {
|
||||
function generateTextChunks(id: string, text: string, message: SDKMessage): AgentStreamPart[] {
|
||||
const providerMetadata = sdkMessageToProviderMetadata(message)
|
||||
return [
|
||||
{
|
||||
type: 'text-start',
|
||||
id
|
||||
id,
|
||||
providerMetadata
|
||||
},
|
||||
{
|
||||
type: 'text-delta',
|
||||
id,
|
||||
delta: text
|
||||
text,
|
||||
providerMetadata
|
||||
},
|
||||
{
|
||||
type: 'text-end',
|
||||
id,
|
||||
providerMetadata: {
|
||||
rawMessage: sdkMessageToProviderMetadata(message)
|
||||
...providerMetadata,
|
||||
text: {
|
||||
value: text
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assistant' | 'user' }>): UIMessageChunk[] {
|
||||
const chunks: UIMessageChunk[] = []
|
||||
function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assistant' | 'user' }>): AgentStreamPart[] {
|
||||
const chunks: AgentStreamPart[] = []
|
||||
const messageId = message.uuid?.toString() || generateMessageId()
|
||||
|
||||
// handle normal text content
|
||||
@ -89,29 +110,25 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
|
||||
break
|
||||
case 'tool_use':
|
||||
chunks.push({
|
||||
type: 'tool-input-available',
|
||||
type: 'tool-call',
|
||||
toolCallId: block.id,
|
||||
toolName: block.name,
|
||||
input: block.input,
|
||||
providerExecuted: true,
|
||||
providerMetadata: {
|
||||
rawMessage: sdkMessageToProviderMetadata(message)
|
||||
}
|
||||
providerMetadata: sdkMessageToProviderMetadata(message)
|
||||
})
|
||||
break
|
||||
case 'tool_result':
|
||||
chunks.push({
|
||||
type: 'tool-output-available',
|
||||
toolCallId: block.tool_use_id,
|
||||
output: block.content,
|
||||
providerExecuted: true,
|
||||
dynamic: false,
|
||||
preliminary: false
|
||||
})
|
||||
// chunks.push({
|
||||
// type: 'tool-result',
|
||||
// toolCallId: block.tool_use_id,
|
||||
// output: block.content,
|
||||
// providerMetadata: sdkMessageToProviderMetadata(message)
|
||||
// })
|
||||
break
|
||||
default:
|
||||
logger.warn('Unknown content block type in user/assistant message:', {
|
||||
type: (block as any).type
|
||||
type: block.type
|
||||
})
|
||||
break
|
||||
}
|
||||
@ -122,9 +139,10 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
|
||||
}
|
||||
|
||||
// Handle stream events (real-time streaming)
|
||||
function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }>): UIMessageChunk[] {
|
||||
const chunks: UIMessageChunk[] = []
|
||||
function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }>): AgentStreamPart[] {
|
||||
const chunks: AgentStreamPart[] = []
|
||||
const event = message.event
|
||||
const blockKey = `${message.uuid ?? message.session_id ?? 'session'}:${event.index}`
|
||||
|
||||
switch (event.type) {
|
||||
case 'message_start':
|
||||
@ -132,69 +150,110 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }
|
||||
break
|
||||
|
||||
case 'content_block_start':
|
||||
if (event.content_block?.type === 'text') {
|
||||
chunks.push({
|
||||
type: 'text-start',
|
||||
id: event.index?.toString() || generateMessageId(),
|
||||
providerMetadata: {
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id,
|
||||
content_block_index: event.index
|
||||
},
|
||||
raw: sdkMessageToProviderMetadata(message)
|
||||
}
|
||||
})
|
||||
} else if (event.content_block?.type === 'tool_use') {
|
||||
chunks.push({
|
||||
type: 'tool-input-start',
|
||||
toolCallId: event.content_block.id,
|
||||
toolName: event.content_block.name,
|
||||
providerExecuted: true
|
||||
})
|
||||
const contentBlockType = event.content_block.type
|
||||
switch (contentBlockType) {
|
||||
case 'text': {
|
||||
contentBlockState.set(blockKey, { type: 'text' })
|
||||
chunks.push({
|
||||
type: 'text-start',
|
||||
id: String(event.index),
|
||||
providerMetadata: {
|
||||
...sdkMessageToProviderMetadata(message),
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id,
|
||||
content_block_index: event.index
|
||||
}
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'tool_use': {
|
||||
contentBlockState.set(blockKey, {
|
||||
type: 'tool-call',
|
||||
toolCallId: event.content_block.id,
|
||||
toolName: event.content_block.name,
|
||||
input: ''
|
||||
})
|
||||
chunks.push({
|
||||
type: 'tool-call',
|
||||
toolCallId: event.content_block.id,
|
||||
toolName: event.content_block.name,
|
||||
input: event.content_block.input,
|
||||
providerExecuted: true,
|
||||
providerMetadata: sdkMessageToProviderMetadata(message)
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
|
||||
case 'content_block_delta':
|
||||
if (event.delta?.type === 'text_delta') {
|
||||
chunks.push({
|
||||
type: 'text-delta',
|
||||
id: event.index?.toString() || generateMessageId(),
|
||||
delta: event.delta.text,
|
||||
providerMetadata: {
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id,
|
||||
content_block_index: event.index
|
||||
},
|
||||
raw: sdkMessageToProviderMetadata(message)
|
||||
switch (event.delta.type) {
|
||||
case 'text_delta': {
|
||||
chunks.push({
|
||||
type: 'text-delta',
|
||||
id: String(event.index),
|
||||
text: event.delta.text,
|
||||
providerMetadata: {
|
||||
...sdkMessageToProviderMetadata(message),
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id,
|
||||
content_block_index: event.index
|
||||
}
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
// case 'thinking_delta': {
|
||||
// chunks.push({
|
||||
// type: 'reasoning-delta',
|
||||
// id: String(event.index),
|
||||
// text: event.delta.thinking,
|
||||
// });
|
||||
// break
|
||||
// }
|
||||
// case 'signature_delta': {
|
||||
// if (blockType === 'thinking') {
|
||||
// chunks.push({
|
||||
// type: 'reasoning-delta',
|
||||
// id: String(event.index),
|
||||
// text: '',
|
||||
// providerMetadata: {
|
||||
// ...sdkMessageToProviderMetadata(message),
|
||||
// anthropic: {
|
||||
// uuid: message.uuid,
|
||||
// session_id: message.session_id,
|
||||
// content_block_index: event.index,
|
||||
// signature: event.delta.signature
|
||||
// }
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// break
|
||||
// }
|
||||
case 'input_json_delta': {
|
||||
const contentBlock = contentBlockState.get(blockKey)
|
||||
if (contentBlock && contentBlock.type === 'tool-call') {
|
||||
contentBlockState.set(blockKey, {
|
||||
...contentBlock,
|
||||
input: `${contentBlock.input ?? ''}${event.delta.partial_json ?? ''}`
|
||||
})
|
||||
}
|
||||
})
|
||||
} else if (event.delta?.type === 'input_json_delta') {
|
||||
chunks.push({
|
||||
type: 'tool-input-delta',
|
||||
toolCallId: (event as any).content_block?.id || '',
|
||||
inputTextDelta: event.delta.partial_json
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
|
||||
case 'content_block_stop': {
|
||||
// Determine if this was a text block or tool use block
|
||||
const blockId = event.index?.toString() || generateMessageId()
|
||||
chunks.push({
|
||||
type: 'text-end',
|
||||
id: blockId,
|
||||
providerMetadata: {
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id,
|
||||
content_block_index: event.index
|
||||
},
|
||||
raw: sdkMessageToProviderMetadata(message)
|
||||
}
|
||||
})
|
||||
break
|
||||
const contentBlock = contentBlockState.get(blockKey)
|
||||
if (contentBlock?.type === 'text') {
|
||||
chunks.push({
|
||||
type: 'text-end',
|
||||
id: String(event.index)
|
||||
})
|
||||
}
|
||||
contentBlockState.delete(blockKey)
|
||||
}
|
||||
|
||||
case 'message_delta':
|
||||
@ -214,80 +273,68 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }
|
||||
}
|
||||
|
||||
// Handle system messages
|
||||
function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>): UIMessageChunk[] {
|
||||
const chunks: UIMessageChunk[] = []
|
||||
|
||||
if (message.subtype === 'init') {
|
||||
chunks.push({
|
||||
type: 'start',
|
||||
messageId: message.session_id
|
||||
})
|
||||
|
||||
// System initialization - could emit as a data chunk or skip
|
||||
chunks.push({
|
||||
type: 'data-system' as any,
|
||||
data: {
|
||||
type: 'init',
|
||||
session_id: message.session_id,
|
||||
raw: message
|
||||
}
|
||||
})
|
||||
} else if (message.subtype === 'compact_boundary') {
|
||||
chunks.push({
|
||||
type: 'data-system' as any,
|
||||
data: {
|
||||
type: 'compact_boundary',
|
||||
metadata: message.compact_metadata,
|
||||
raw: message
|
||||
}
|
||||
})
|
||||
function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>): AgentStreamPart[] {
|
||||
const chunks: AgentStreamPart[] = []
|
||||
logger.debug('Received system message', {
|
||||
subtype: message.subtype
|
||||
})
|
||||
switch (message.subtype) {
|
||||
case 'init': {
|
||||
chunks.push({
|
||||
type: 'start'
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return chunks
|
||||
return []
|
||||
}
|
||||
|
||||
// Handle result messages (completion with usage stats)
|
||||
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): UIMessageChunk[] {
|
||||
const chunks: UIMessageChunk[] = []
|
||||
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
|
||||
const chunks: AgentStreamPart[] = []
|
||||
|
||||
const messageId = message.uuid
|
||||
let usage: LanguageModelV2Usage | undefined
|
||||
if ('usage' in message) {
|
||||
usage = {
|
||||
inputTokens:
|
||||
(message.usage.cache_creation_input_tokens ?? 0) +
|
||||
(message.usage.cache_read_input_tokens ?? 0) +
|
||||
(message.usage.input_tokens ?? 0),
|
||||
outputTokens: message.usage.output_tokens ?? 0,
|
||||
totalTokens:
|
||||
(message.usage.cache_creation_input_tokens ?? 0) +
|
||||
(message.usage.cache_read_input_tokens ?? 0) +
|
||||
(message.usage.input_tokens ?? 0) +
|
||||
(message.usage.output_tokens ?? 0)
|
||||
}
|
||||
}
|
||||
if (message.subtype === 'success') {
|
||||
// Emit final result data
|
||||
chunks.push({
|
||||
type: 'data-result' as any,
|
||||
id: messageId,
|
||||
data: message,
|
||||
transient: true
|
||||
})
|
||||
type: 'finish',
|
||||
totalUsage: usage,
|
||||
finishReason: mapClaudeCodeFinishReason(message.subtype),
|
||||
providerMetadata: {
|
||||
...sdkMessageToProviderMetadata(message),
|
||||
usage: message.usage,
|
||||
durationMs: message.duration_ms,
|
||||
costUsd: message.total_cost_usd,
|
||||
raw: message
|
||||
}
|
||||
} as AgentStreamPart)
|
||||
} else {
|
||||
// Handle error cases
|
||||
chunks.push({
|
||||
type: 'error',
|
||||
errorText: `${message.subtype}: Process failed after ${message.num_turns} turns`
|
||||
})
|
||||
}
|
||||
|
||||
// Emit usage and cost data
|
||||
chunks.push({
|
||||
type: 'data-usage' as any,
|
||||
data: {
|
||||
cost: message.total_cost_usd,
|
||||
usage: {
|
||||
input_tokens: message.usage.input_tokens,
|
||||
cache_creation_input_tokens: message.usage.cache_creation_input_tokens,
|
||||
cache_read_input_tokens: message.usage.cache_read_input_tokens,
|
||||
output_tokens: message.usage.output_tokens,
|
||||
service_tier: 'standard'
|
||||
error: {
|
||||
message: `${message.subtype}: Process failed after ${message.num_turns} turns`
|
||||
}
|
||||
}
|
||||
})
|
||||
} as AgentStreamPart)
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// Convenience function to transform a stream of SDKMessages
|
||||
export function* transformSDKMessageStream(sdkMessages: SDKMessage[]): Generator<UIMessageChunk> {
|
||||
export function* transformSDKMessageStream(sdkMessages: SDKMessage[]): Generator<AgentStreamPart> {
|
||||
for (const sdkMessage of sdkMessages) {
|
||||
const chunks = transformSDKMessageToUIChunk(sdkMessage)
|
||||
const chunks = transformSDKMessageToStreamParts(sdkMessage)
|
||||
for (const chunk of chunks) {
|
||||
yield chunk
|
||||
}
|
||||
@ -297,9 +344,9 @@ export function* transformSDKMessageStream(sdkMessages: SDKMessage[]): Generator
|
||||
// Async version for async iterables
|
||||
export async function* transformSDKMessageStreamAsync(
|
||||
sdkMessages: AsyncIterable<SDKMessage>
|
||||
): AsyncGenerator<UIMessageChunk> {
|
||||
): AsyncGenerator<AgentStreamPart> {
|
||||
for await (const sdkMessage of sdkMessages) {
|
||||
const chunks = transformSDKMessageToUIChunk(sdkMessage)
|
||||
const chunks = transformSDKMessageToStreamParts(sdkMessage)
|
||||
for (const chunk of chunks) {
|
||||
yield chunk
|
||||
}
|
||||
|
||||
@ -32,16 +32,19 @@ export class AiSdkToChunkAdapter {
|
||||
private accumulate: boolean | undefined
|
||||
private isFirstChunk = true
|
||||
private enableWebSearch: boolean = false
|
||||
private onSessionUpdate?: (sessionId: string) => void
|
||||
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
mcpTools: MCPTool[] = [],
|
||||
accumulate?: boolean,
|
||||
enableWebSearch?: boolean
|
||||
enableWebSearch?: boolean,
|
||||
onSessionUpdate?: (sessionId: string) => void
|
||||
) {
|
||||
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
|
||||
this.accumulate = accumulate
|
||||
this.enableWebSearch = enableWebSearch || false
|
||||
this.onSessionUpdate = onSessionUpdate
|
||||
}
|
||||
|
||||
/**
|
||||
@ -108,6 +111,15 @@ export class AiSdkToChunkAdapter {
|
||||
chunk: TextStreamPart<any>,
|
||||
final: { text: string; reasoningContent: string; webSearchResults: AISDKWebSearchResult[]; reasoningId: string }
|
||||
) {
|
||||
const sessionId =
|
||||
(chunk.providerMetadata as any)?.anthropic?.session_id ??
|
||||
(chunk.providerMetadata as any)?.anthropic?.sessionId ??
|
||||
(chunk.providerMetadata as any)?.raw?.session_id
|
||||
|
||||
if (typeof sessionId === 'string' && sessionId) {
|
||||
this.onSessionUpdate?.(sessionId)
|
||||
}
|
||||
|
||||
logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk)
|
||||
switch (chunk.type) {
|
||||
// === 文本相关事件 ===
|
||||
|
||||
@ -10,7 +10,6 @@ import {
|
||||
CreateAgentResponse,
|
||||
CreateAgentResponseSchema,
|
||||
CreateSessionForm,
|
||||
CreateSessionMessageRequest,
|
||||
CreateSessionRequest,
|
||||
CreateSessionResponse,
|
||||
CreateSessionResponseSchema,
|
||||
@ -225,16 +224,6 @@ export class AgentApiClient {
|
||||
}
|
||||
}
|
||||
|
||||
public async createMessage(agentId: string, sessionId: string, content: string): Promise<void> {
|
||||
const url = this.getSessionMessagesPath(agentId, sessionId)
|
||||
try {
|
||||
const payload = { content } satisfies CreateSessionMessageRequest
|
||||
await this.axios.post(url, payload)
|
||||
} catch (error) {
|
||||
throw processError(error, 'Failed to post message.')
|
||||
}
|
||||
}
|
||||
|
||||
public async getModels(props?: ApiModelsFilter): Promise<ApiModelsResponse> {
|
||||
const url = this.getModelsPath(props)
|
||||
try {
|
||||
|
||||
@ -118,6 +118,35 @@ export const SessionModal: React.FC<Props> = ({ agentId, session, trigger, isOpe
|
||||
}))
|
||||
}, [])
|
||||
|
||||
const addAccessiblePath = useCallback(async () => {
|
||||
try {
|
||||
const selected = await window.api.file.selectFolder()
|
||||
if (!selected) {
|
||||
return
|
||||
}
|
||||
setForm((prev) => {
|
||||
if (prev.accessible_paths.includes(selected)) {
|
||||
window.toast.warning(t('agent.session.accessible_paths.duplicate'))
|
||||
return prev
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
accessible_paths: [...prev.accessible_paths, selected]
|
||||
}
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to select accessible path:', error as Error)
|
||||
window.toast.error(t('agent.session.accessible_paths.select_failed'))
|
||||
}
|
||||
}, [t])
|
||||
|
||||
const removeAccessiblePath = useCallback((path: string) => {
|
||||
setForm((prev) => ({
|
||||
...prev,
|
||||
accessible_paths: prev.accessible_paths.filter((item) => item !== path)
|
||||
}))
|
||||
}, [])
|
||||
|
||||
const modelOptions = useMemo(() => {
|
||||
// mocked data. not final version
|
||||
return (models ?? []).map((model) => ({
|
||||
@ -152,6 +181,12 @@ export const SessionModal: React.FC<Props> = ({ agentId, session, trigger, isOpe
|
||||
return
|
||||
}
|
||||
|
||||
if (form.accessible_paths.length === 0) {
|
||||
window.toast.error(t('agent.session.accessible_paths.required'))
|
||||
loadingRef.current = false
|
||||
return
|
||||
}
|
||||
|
||||
if (isEditing(session)) {
|
||||
if (!session) {
|
||||
throw new Error('Agent is required for editing mode')
|
||||
@ -162,7 +197,8 @@ export const SessionModal: React.FC<Props> = ({ agentId, session, trigger, isOpe
|
||||
name: form.name,
|
||||
description: form.description,
|
||||
instructions: form.instructions,
|
||||
model: form.model
|
||||
model: form.model,
|
||||
accessible_paths: [...form.accessible_paths]
|
||||
} satisfies UpdateSessionForm
|
||||
|
||||
updateSession(updatePayload)
|
||||
@ -248,6 +284,34 @@ export const SessionModal: React.FC<Props> = ({ agentId, session, trigger, isOpe
|
||||
value={form.description ?? ''}
|
||||
onValueChange={onDescChange}
|
||||
/>
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm font-medium text-foreground">
|
||||
{t('agent.session.accessible_paths.label')}
|
||||
</span>
|
||||
<Button size="sm" variant="flat" onPress={addAccessiblePath}>
|
||||
{t('agent.session.accessible_paths.add')}
|
||||
</Button>
|
||||
</div>
|
||||
{form.accessible_paths.length > 0 ? (
|
||||
<div className="space-y-2">
|
||||
{form.accessible_paths.map((path) => (
|
||||
<div
|
||||
key={path}
|
||||
className="flex items-center justify-between gap-2 rounded-medium border border-default-200 px-3 py-2">
|
||||
<span className="truncate text-sm" title={path}>
|
||||
{path}
|
||||
</span>
|
||||
<Button size="sm" variant="light" color="danger" onPress={() => removeAccessiblePath(path)}>
|
||||
{t('common.remove')}
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-sm text-foreground-400">{t('agent.session.accessible_paths.empty')}</p>
|
||||
)}
|
||||
</div>
|
||||
{/* TODO: accessible paths */}
|
||||
<Textarea label={t('common.prompt')} value={form.instructions ?? ''} onValueChange={onInstChange} />
|
||||
</ModalBody>
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
import { AgentSessionMessageEntity, UpdateSessionForm } from '@renderer/types'
|
||||
import { cloneDeep } from 'lodash'
|
||||
import { useCallback } from 'react'
|
||||
import { useAppDispatch } from '@renderer/store'
|
||||
import { removeManyBlocks,upsertManyBlocks } from '@renderer/store/messageBlock'
|
||||
import { newMessagesActions } from '@renderer/store/newMessage'
|
||||
import { AgentPersistedMessage, UpdateSessionForm } from '@renderer/types'
|
||||
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||
import { useCallback, useEffect, useMemo, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import useSWR from 'swr'
|
||||
|
||||
@ -10,6 +13,9 @@ export const useSession = (agentId: string, sessionId: string) => {
|
||||
const { t } = useTranslation()
|
||||
const client = useAgentClient()
|
||||
const key = client.getSessionPaths(agentId).withId(sessionId)
|
||||
const dispatch = useAppDispatch()
|
||||
const sessionTopicId = useMemo(() => buildAgentSessionTopicId(sessionId), [sessionId])
|
||||
const blockIdsRef = useRef<string[]>([])
|
||||
|
||||
const fetcher = async () => {
|
||||
const data = await client.getSession(agentId, sessionId)
|
||||
@ -17,6 +23,38 @@ export const useSession = (agentId: string, sessionId: string) => {
|
||||
}
|
||||
const { data, error, isLoading, mutate } = useSWR(key, fetcher)
|
||||
|
||||
useEffect(() => {
|
||||
const messages = data?.messages ?? []
|
||||
if (!messages.length) {
|
||||
dispatch(newMessagesActions.messagesReceived({ topicId: sessionTopicId, messages: [] }))
|
||||
blockIdsRef.current = []
|
||||
return
|
||||
}
|
||||
|
||||
const persistedEntries = messages
|
||||
.map((entity) => entity.content as AgentPersistedMessage | undefined)
|
||||
.filter((entry): entry is AgentPersistedMessage => Boolean(entry))
|
||||
|
||||
const allBlocks = persistedEntries.flatMap((entry) => entry.blocks)
|
||||
if (allBlocks.length > 0) {
|
||||
dispatch(upsertManyBlocks(allBlocks))
|
||||
}
|
||||
|
||||
blockIdsRef.current = allBlocks.map((block) => block.id)
|
||||
|
||||
const messageRecords = persistedEntries.map((entry) => entry.message)
|
||||
dispatch(newMessagesActions.messagesReceived({ topicId: sessionTopicId, messages: messageRecords }))
|
||||
}, [data?.messages, dispatch, sessionTopicId])
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (blockIdsRef.current.length > 0) {
|
||||
dispatch(removeManyBlocks(blockIdsRef.current))
|
||||
}
|
||||
dispatch(newMessagesActions.clearTopicMessages(sessionTopicId))
|
||||
}
|
||||
}, [dispatch, sessionTopicId])
|
||||
|
||||
const updateSession = useCallback(
|
||||
async (form: UpdateSessionForm) => {
|
||||
if (!agentId) return
|
||||
@ -30,53 +68,11 @@ export const useSession = (agentId: string, sessionId: string) => {
|
||||
[agentId, client, mutate, t]
|
||||
)
|
||||
|
||||
const createSessionMessage = useCallback(
|
||||
async (content: string) => {
|
||||
if (!agentId || !sessionId || !data) return
|
||||
const origin = cloneDeep(data)
|
||||
const newMessageDraft = {
|
||||
id: 77777,
|
||||
session_id: '',
|
||||
role: 'user',
|
||||
content: {
|
||||
role: 'user',
|
||||
content: content,
|
||||
providerOptions: undefined
|
||||
},
|
||||
agent_session_id: '',
|
||||
created_at: '',
|
||||
updated_at: ''
|
||||
} satisfies AgentSessionMessageEntity
|
||||
try {
|
||||
mutate(
|
||||
(prev) => ({
|
||||
...prev,
|
||||
accessible_paths: prev?.accessible_paths ?? [],
|
||||
model: prev?.model ?? '',
|
||||
id: prev?.id ?? '',
|
||||
agent_id: prev?.id ?? '',
|
||||
agent_type: prev?.agent_type ?? 'claude-code',
|
||||
created_at: prev?.created_at ?? '',
|
||||
updated_at: prev?.updated_at ?? '',
|
||||
messages: [...(prev?.messages ?? []), newMessageDraft]
|
||||
}),
|
||||
false
|
||||
)
|
||||
await client.createMessage(agentId, sessionId, content)
|
||||
} catch (error) {
|
||||
mutate(origin)
|
||||
window.toast.error(t('common.errors.create_message'))
|
||||
}
|
||||
},
|
||||
[agentId, sessionId, data, mutate, client, t]
|
||||
)
|
||||
|
||||
return {
|
||||
session: data,
|
||||
messages: data?.messages ?? [],
|
||||
error,
|
||||
isLoading,
|
||||
updateSession,
|
||||
createSessionMessage
|
||||
mutate
|
||||
}
|
||||
}
|
||||
|
||||
@ -51,6 +51,14 @@
|
||||
"error": {
|
||||
"failed": "Failed to update the session"
|
||||
}
|
||||
},
|
||||
"accessible_paths": {
|
||||
"label": "Accessible directories",
|
||||
"add": "Add directory",
|
||||
"empty": "Select at least one directory that the agent can access.",
|
||||
"required": "Please select at least one accessible directory.",
|
||||
"duplicate": "This directory is already included.",
|
||||
"select_failed": "Failed to select directory."
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
|
||||
@ -51,6 +51,14 @@
|
||||
"error": {
|
||||
"failed": "更新会话失败"
|
||||
}
|
||||
},
|
||||
"accessible_paths": {
|
||||
"label": "工作目录",
|
||||
"add": "添加目录",
|
||||
"empty": "请选择至少一个智能体可访问的目录。",
|
||||
"required": "请至少选择一个可访问的目录。",
|
||||
"duplicate": "该目录已添加。",
|
||||
"select_failed": "选择目录失败"
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
|
||||
@ -51,6 +51,14 @@
|
||||
"error": {
|
||||
"failed": "[to be translated]:Failed to update the session"
|
||||
}
|
||||
},
|
||||
"accessible_paths": {
|
||||
"label": "[to be translated]:Accessible directories",
|
||||
"add": "[to be translated]:Add directory",
|
||||
"empty": "[to be translated]:Select at least one directory that the agent can access.",
|
||||
"required": "[to be translated]:Please select at least one accessible directory.",
|
||||
"duplicate": "[to be translated]:This directory is already included.",
|
||||
"select_failed": "[to be translated]:Failed to select directory."
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
|
||||
@ -51,6 +51,14 @@
|
||||
"error": {
|
||||
"failed": "[to be translated]:Failed to update the session"
|
||||
}
|
||||
},
|
||||
"accessible_paths": {
|
||||
"label": "[to be translated]:Accessible directories",
|
||||
"add": "[to be translated]:Add directory",
|
||||
"empty": "[to be translated]:Select at least one directory that the agent can access.",
|
||||
"required": "[to be translated]:Please select at least one accessible directory.",
|
||||
"duplicate": "[to be translated]:This directory is already included.",
|
||||
"select_failed": "[to be translated]:Failed to select directory."
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
|
||||
@ -51,6 +51,14 @@
|
||||
"error": {
|
||||
"failed": "[to be translated]:Failed to update the session"
|
||||
}
|
||||
},
|
||||
"accessible_paths": {
|
||||
"label": "[to be translated]:Accessible directories",
|
||||
"add": "[to be translated]:Add directory",
|
||||
"empty": "[to be translated]:Select at least one directory that the agent can access.",
|
||||
"required": "[to be translated]:Please select at least one accessible directory.",
|
||||
"duplicate": "[to be translated]:This directory is already included.",
|
||||
"select_failed": "[to be translated]:Failed to select directory."
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
|
||||
@ -51,6 +51,14 @@
|
||||
"error": {
|
||||
"failed": "[to be translated]:Failed to update the session"
|
||||
}
|
||||
},
|
||||
"accessible_paths": {
|
||||
"label": "[to be translated]:Accessible directories",
|
||||
"add": "[to be translated]:Add directory",
|
||||
"empty": "[to be translated]:Select at least one directory that the agent can access.",
|
||||
"required": "[to be translated]:Please select at least one accessible directory.",
|
||||
"duplicate": "[to be translated]:This directory is already included.",
|
||||
"select_failed": "[to be translated]:Failed to select directory."
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
|
||||
@ -51,6 +51,14 @@
|
||||
"error": {
|
||||
"failed": "[to be translated]:Failed to update the session"
|
||||
}
|
||||
},
|
||||
"accessible_paths": {
|
||||
"label": "[to be translated]:Accessible directories",
|
||||
"add": "[to be translated]:Add directory",
|
||||
"empty": "[to be translated]:Select at least one directory that the agent can access.",
|
||||
"required": "[to be translated]:Please select at least one accessible directory.",
|
||||
"duplicate": "[to be translated]:This directory is already included.",
|
||||
"select_failed": "[to be translated]:Failed to select directory."
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
|
||||
@ -51,6 +51,14 @@
|
||||
"error": {
|
||||
"failed": "[to be translated]:Failed to update the session"
|
||||
}
|
||||
},
|
||||
"accessible_paths": {
|
||||
"label": "[to be translated]:Accessible directories",
|
||||
"add": "[to be translated]:Add directory",
|
||||
"empty": "[to be translated]:Select at least one directory that the agent can access.",
|
||||
"required": "[to be translated]:Please select at least one accessible directory.",
|
||||
"duplicate": "[to be translated]:This directory is already included.",
|
||||
"select_failed": "[to be translated]:Failed to select directory."
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
|
||||
@ -4,13 +4,20 @@ import { useSession } from '@renderer/hooks/agents/useSession'
|
||||
import { useSettings } from '@renderer/hooks/useSettings'
|
||||
import { useTimer } from '@renderer/hooks/useTimer'
|
||||
import PasteService from '@renderer/services/PasteService'
|
||||
import { useAppDispatch } from '@renderer/store'
|
||||
import { sendMessage as dispatchSendMessage } from '@renderer/store/thunk/messageThunk'
|
||||
import type { Assistant, Message, MessageBlock, Model, Topic } from '@renderer/types'
|
||||
import { MessageBlockStatus } from '@renderer/types/newMessage'
|
||||
import { classNames } from '@renderer/utils'
|
||||
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||
import { getSendMessageShortcutLabel, isSendMessageKeyPressed } from '@renderer/utils/input'
|
||||
import { createMainTextBlock, createMessage } from '@renderer/utils/messageUtils/create'
|
||||
import TextArea, { TextAreaRef } from 'antd/es/input/TextArea'
|
||||
import { isEmpty } from 'lodash'
|
||||
import React, { CSSProperties, FC, useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
import { v4 as uuid } from 'uuid'
|
||||
|
||||
import NarrowLayout from '../Messages/NarrowLayout'
|
||||
import SendMessageButton from './SendMessageButton'
|
||||
@ -27,7 +34,7 @@ const _text = ''
|
||||
const AgentSessionInputbar: FC<Props> = ({ agentId, sessionId }) => {
|
||||
const [text, setText] = useState(_text)
|
||||
const [inputFocus, setInputFocus] = useState(false)
|
||||
const { createSessionMessage } = useSession(agentId, sessionId)
|
||||
const { session } = useSession(agentId, sessionId)
|
||||
|
||||
const { sendMessageShortcut, fontSize, enableSpellCheck } = useSettings()
|
||||
const textareaRef = useRef<TextAreaRef>(null)
|
||||
@ -36,6 +43,8 @@ const AgentSessionInputbar: FC<Props> = ({ agentId, sessionId }) => {
|
||||
const containerRef = useRef(null)
|
||||
|
||||
const { setTimeoutTimer } = useTimer()
|
||||
const dispatch = useAppDispatch()
|
||||
const sessionTopicId = buildAgentSessionTopicId(sessionId)
|
||||
|
||||
const focusTextarea = useCallback(() => {
|
||||
textareaRef.current?.focus()
|
||||
@ -93,14 +102,65 @@ const AgentSessionInputbar: FC<Props> = ({ agentId, sessionId }) => {
|
||||
logger.info('Starting to send message')
|
||||
|
||||
try {
|
||||
createSessionMessage(text)
|
||||
// Clear input
|
||||
const userMessageId = uuid()
|
||||
const mainBlock = createMainTextBlock(userMessageId, text, {
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
})
|
||||
const userMessageBlocks: MessageBlock[] = [mainBlock]
|
||||
|
||||
const model: Model | undefined = session?.model
|
||||
? {
|
||||
id: session.model,
|
||||
name: session.model,
|
||||
provider: 'agent-session',
|
||||
group: 'agent-session'
|
||||
}
|
||||
: undefined
|
||||
|
||||
const userMessage: Message = createMessage('user', sessionTopicId, agentId, {
|
||||
id: userMessageId,
|
||||
blocks: userMessageBlocks.map((block) => block.id),
|
||||
model,
|
||||
modelId: model?.id
|
||||
})
|
||||
|
||||
const assistantStub: Assistant = {
|
||||
id: session?.agent_id ?? agentId,
|
||||
name: session?.name ?? 'Agent Session',
|
||||
prompt: session?.instructions ?? '',
|
||||
topics: [] as Topic[],
|
||||
type: 'agent-session',
|
||||
model,
|
||||
defaultModel: model,
|
||||
tags: [],
|
||||
enableWebSearch: false
|
||||
}
|
||||
|
||||
dispatch(
|
||||
dispatchSendMessage(userMessage, userMessageBlocks, assistantStub, sessionTopicId, {
|
||||
agentId,
|
||||
sessionId
|
||||
})
|
||||
)
|
||||
|
||||
setText('')
|
||||
setTimeoutTimer('sendMessage_1', () => setText(''), 500)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to send message:', error as Error)
|
||||
}
|
||||
}, [createSessionMessage, inputEmpty, setTimeoutTimer, text])
|
||||
}, [
|
||||
agentId,
|
||||
dispatch,
|
||||
inputEmpty,
|
||||
session?.agent_id,
|
||||
session?.instructions,
|
||||
session?.model,
|
||||
session?.name,
|
||||
sessionId,
|
||||
sessionTopicId,
|
||||
setTimeoutTimer,
|
||||
text
|
||||
])
|
||||
|
||||
const onChange = useCallback((e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
const newText = e.target.value
|
||||
|
||||
@ -2,8 +2,11 @@ import { loggerService } from '@logger'
|
||||
import ContextMenu from '@renderer/components/ContextMenu'
|
||||
import Scrollbar from '@renderer/components/Scrollbar'
|
||||
import { useSession } from '@renderer/hooks/agents/useSession'
|
||||
import { ModelMessage } from 'ai'
|
||||
import { memo } from 'react'
|
||||
import Blocks from '@renderer/pages/home/Messages/Blocks'
|
||||
import { useAppSelector } from '@renderer/store'
|
||||
import { selectMessagesForTopic } from '@renderer/store/newMessage'
|
||||
import { buildAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||
import { useMemo } from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import NarrowLayout from './NarrowLayout'
|
||||
@ -16,31 +19,29 @@ type Props = {
|
||||
}
|
||||
|
||||
const AgentSessionMessages: React.FC<Props> = ({ agentId, sessionId }) => {
|
||||
const { messages } = useSession(agentId, sessionId)
|
||||
const { session } = useSession(agentId, sessionId)
|
||||
const sessionTopicId = useMemo(() => buildAgentSessionTopicId(sessionId), [sessionId])
|
||||
const messages = useAppSelector((state) => selectMessagesForTopic(state, sessionTopicId))
|
||||
|
||||
const getTextFromContent = (content: string | ModelMessage): string => {
|
||||
logger.debug('content', { content })
|
||||
if (typeof content === 'string') {
|
||||
return content
|
||||
} else if (typeof content.content === 'string') {
|
||||
return content.content
|
||||
} else {
|
||||
return content.content
|
||||
.filter((part) => part.type === 'text')
|
||||
.map((part) => part.text)
|
||||
.join('\n')
|
||||
}
|
||||
}
|
||||
logger.silly('Rendering agent session messages', {
|
||||
sessionId,
|
||||
messageCount: messages.length
|
||||
})
|
||||
|
||||
return (
|
||||
<MessagesContainer id="messages" className="messages-container">
|
||||
<NarrowLayout style={{ display: 'flex', flexDirection: 'column-reverse' }}>
|
||||
<ContextMenu>
|
||||
<ScrollContainer>
|
||||
{messages.toReversed().map((message) => {
|
||||
const content = getTextFromContent(message.content)
|
||||
return <div key={message.id}>{content}</div>
|
||||
})}
|
||||
{messages
|
||||
.slice()
|
||||
.reverse()
|
||||
.map((message) => (
|
||||
<MessageRow key={message.id} $role={message.role}>
|
||||
<Blocks blocks={message.blocks ?? []} message={message} />
|
||||
</MessageRow>
|
||||
))}
|
||||
{!messages.length && <EmptyState>{session ? 'No messages yet.' : 'Loading session...'}</EmptyState>}
|
||||
</ScrollContainer>
|
||||
</ContextMenu>
|
||||
</NarrowLayout>
|
||||
@ -51,12 +52,29 @@ const AgentSessionMessages: React.FC<Props> = ({ agentId, sessionId }) => {
|
||||
const ScrollContainer = styled.div`
|
||||
display: flex;
|
||||
flex-direction: column-reverse;
|
||||
gap: 12px;
|
||||
padding: 10px 10px 20px;
|
||||
.multi-select-mode & {
|
||||
padding-bottom: 60px;
|
||||
}
|
||||
`
|
||||
|
||||
const MessageRow = styled.div<{ $role: string }>`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: ${(props) => (props.$role === 'user' ? 'flex-end' : 'flex-start')};
|
||||
.block-wrapper {
|
||||
max-width: 700px;
|
||||
}
|
||||
`
|
||||
|
||||
const EmptyState = styled.div`
|
||||
color: var(--color-text-3);
|
||||
font-size: 12px;
|
||||
text-align: center;
|
||||
padding: 20px 0;
|
||||
`
|
||||
|
||||
interface ContainerProps {
|
||||
$right?: boolean
|
||||
}
|
||||
@ -69,4 +87,4 @@ const MessagesContainer = styled(Scrollbar)<ContainerProps>`
|
||||
position: relative;
|
||||
`
|
||||
|
||||
export default memo(AgentSessionMessages)
|
||||
export default AgentSessionMessages
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { AiSdkToChunkAdapter } from '@renderer/aiCore/chunk/AiSdkToChunkAdapter'
|
||||
import db from '@renderer/databases'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { BlockManager } from '@renderer/services/messageStreaming/BlockManager'
|
||||
@ -8,18 +9,23 @@ import { endSpan } from '@renderer/services/SpanManagerService'
|
||||
import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService'
|
||||
import store from '@renderer/store'
|
||||
import { updateTopicUpdatedAt } from '@renderer/store/assistants'
|
||||
import { type Assistant, type FileMetadata, type Model, type Topic } from '@renderer/types'
|
||||
import { type ApiServerConfig, type Assistant, type FileMetadata, type Model, type Topic } from '@renderer/types'
|
||||
import type { AgentPersistedMessage } from '@renderer/types/agent'
|
||||
import type { FileMessageBlock, ImageMessageBlock, Message, MessageBlock } from '@renderer/types/newMessage'
|
||||
import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
import { uuid } from '@renderer/utils'
|
||||
import { addAbortController } from '@renderer/utils/abortController'
|
||||
import { isAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||
import {
|
||||
createAssistantMessage,
|
||||
createTranslationBlock,
|
||||
resetAssistantMessage
|
||||
} from '@renderer/utils/messageUtils/create'
|
||||
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import type { TextStreamPart } from 'ai'
|
||||
import { t } from 'i18next'
|
||||
import { isEmpty, throttle } from 'lodash'
|
||||
import { LRUCache } from 'lru-cache'
|
||||
@ -35,9 +41,158 @@ const finishTopicLoading = async (topicId: string) => {
|
||||
store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
store.dispatch(newMessagesActions.setTopicFulfilled({ topicId, fulfilled: true }))
|
||||
}
|
||||
|
||||
type AgentSessionContext = {
|
||||
agentId: string
|
||||
sessionId: string
|
||||
}
|
||||
|
||||
const buildAgentBaseURL = (apiServer: ApiServerConfig) => {
|
||||
const hasProtocol = apiServer.host.startsWith('http://') || apiServer.host.startsWith('https://')
|
||||
const baseHost = hasProtocol ? apiServer.host : `http://${apiServer.host}`
|
||||
const portSegment = apiServer.port ? `:${apiServer.port}` : ''
|
||||
return `${baseHost}${portSegment}`
|
||||
}
|
||||
|
||||
const createSSEReadableStream = (
|
||||
source: ReadableStream<Uint8Array>,
|
||||
signal: AbortSignal
|
||||
): ReadableStream<TextStreamPart<Record<string, any>>> => {
|
||||
return new ReadableStream<TextStreamPart<Record<string, any>>>({
|
||||
start(controller) {
|
||||
const reader = source.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
|
||||
const cancelReader = (reason?: any) => reader.cancel(reason).catch(() => {})
|
||||
|
||||
const abortHandler = () => {
|
||||
cancelReader(signal.reason ?? 'aborted')
|
||||
controller.error(new DOMException('Aborted', 'AbortError'))
|
||||
}
|
||||
|
||||
if (signal.aborted) {
|
||||
abortHandler()
|
||||
return
|
||||
}
|
||||
|
||||
signal.addEventListener('abort', abortHandler, { once: true })
|
||||
|
||||
const emitEvent = (eventString: string): boolean => {
|
||||
const lines = eventString.split(/\r?\n/)
|
||||
let dataPayload = ''
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data:')) {
|
||||
dataPayload += line.slice(5).trimStart()
|
||||
}
|
||||
}
|
||||
|
||||
if (!dataPayload) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (dataPayload === '[DONE]') {
|
||||
signal.removeEventListener('abort', abortHandler)
|
||||
cancelReader()
|
||||
controller.close()
|
||||
return true
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(dataPayload) as TextStreamPart<Record<string, any>>
|
||||
controller.enqueue(parsed)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to parse agent SSE chunk', { dataPayload })
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const pump = async () => {
|
||||
try {
|
||||
while (true) {
|
||||
const { value, done } = await reader.read()
|
||||
if (done) break
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
|
||||
let separatorIndex = buffer.indexOf('\n\n')
|
||||
while (separatorIndex !== -1) {
|
||||
const rawEvent = buffer.slice(0, separatorIndex).trim()
|
||||
buffer = buffer.slice(separatorIndex + 2)
|
||||
if (rawEvent) {
|
||||
const shouldStop = emitEvent(rawEvent)
|
||||
if (shouldStop) {
|
||||
return
|
||||
}
|
||||
}
|
||||
separatorIndex = buffer.indexOf('\n\n')
|
||||
}
|
||||
}
|
||||
|
||||
buffer += decoder.decode()
|
||||
if (buffer.trim()) {
|
||||
emitEvent(buffer.trim())
|
||||
}
|
||||
signal.removeEventListener('abort', abortHandler)
|
||||
controller.close()
|
||||
} catch (error) {
|
||||
signal.removeEventListener('abort', abortHandler)
|
||||
controller.error(error)
|
||||
}
|
||||
}
|
||||
|
||||
pump().catch((error) => {
|
||||
signal.removeEventListener('abort', abortHandler)
|
||||
controller.error(error)
|
||||
})
|
||||
},
|
||||
cancel(reason) {
|
||||
return source.cancel(reason).catch(() => {})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const createAgentMessageStream = async (
|
||||
apiServer: ApiServerConfig,
|
||||
agentSession: AgentSessionContext,
|
||||
content: string,
|
||||
signal: AbortSignal
|
||||
): Promise<ReadableStream<TextStreamPart<Record<string, any>>>> => {
|
||||
if (!apiServer.enabled) {
|
||||
throw new Error('Agent API server is disabled')
|
||||
}
|
||||
|
||||
const baseURL = buildAgentBaseURL(apiServer)
|
||||
const url = `${baseURL}/v1/agents/${agentSession.agentId}/sessions/${agentSession.sessionId}/messages`
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiServer.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'text/event-stream',
|
||||
'Cache-Control': 'no-cache'
|
||||
},
|
||||
body: JSON.stringify({ content }),
|
||||
signal
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text().catch(() => '')
|
||||
throw new Error(errorText || `Failed to stream agent message: ${response.status}`)
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('Agent message stream has no body')
|
||||
}
|
||||
|
||||
return createSSEReadableStream(response.body, signal)
|
||||
}
|
||||
// TODO: 后续可以将db操作移到Listener Middleware中
|
||||
export const saveMessageAndBlocksToDB = async (message: Message, blocks: MessageBlock[], messageIndex: number = -1) => {
|
||||
try {
|
||||
if (isAgentSessionTopicId(message.topicId)) {
|
||||
return
|
||||
}
|
||||
if (blocks.length > 0) {
|
||||
await db.message_blocks.bulkPut(blocks)
|
||||
}
|
||||
@ -70,6 +225,9 @@ const updateExistingMessageAndBlocksInDB = async (
|
||||
updatedBlocks: MessageBlock[]
|
||||
) => {
|
||||
try {
|
||||
if (isAgentSessionTopicId(updatedMessage.topicId)) {
|
||||
return
|
||||
}
|
||||
await db.transaction('rw', db.topics, db.message_blocks, async () => {
|
||||
// Always update blocks if provided
|
||||
if (updatedBlocks.length > 0) {
|
||||
@ -244,6 +402,157 @@ const saveUpdatedBlockToDB = async (
|
||||
}
|
||||
}
|
||||
|
||||
interface AgentStreamParams {
|
||||
topicId: string
|
||||
assistant: Assistant
|
||||
assistantMessage: Message
|
||||
agentSession: AgentSessionContext
|
||||
userMessageId: string
|
||||
}
|
||||
|
||||
const fetchAndProcessAgentResponseImpl = async (
|
||||
dispatch: AppDispatch,
|
||||
getState: () => RootState,
|
||||
{ topicId, assistant, assistantMessage, agentSession, userMessageId }: AgentStreamParams
|
||||
) => {
|
||||
let callbacks: StreamProcessorCallbacks = {}
|
||||
try {
|
||||
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: true }))
|
||||
|
||||
const blockManager = new BlockManager({
|
||||
dispatch,
|
||||
getState,
|
||||
saveUpdatedBlockToDB,
|
||||
saveUpdatesToDB,
|
||||
assistantMsgId: assistantMessage.id,
|
||||
topicId,
|
||||
throttledBlockUpdate,
|
||||
cancelThrottledBlockUpdate
|
||||
})
|
||||
|
||||
callbacks = createCallbacks({
|
||||
blockManager,
|
||||
dispatch,
|
||||
getState,
|
||||
topicId,
|
||||
assistantMsgId: assistantMessage.id,
|
||||
saveUpdatesToDB,
|
||||
assistant
|
||||
})
|
||||
|
||||
const streamProcessorCallbacks = createStreamProcessor(callbacks)
|
||||
|
||||
const state = getState()
|
||||
const userMessageEntity = state.messages.entities[userMessageId]
|
||||
const userContent = userMessageEntity ? getMainTextContent(userMessageEntity) : ''
|
||||
|
||||
const abortController = new AbortController()
|
||||
addAbortController(userMessageId, () => abortController.abort())
|
||||
|
||||
const stream = await createAgentMessageStream(
|
||||
state.settings.apiServer,
|
||||
agentSession,
|
||||
userContent,
|
||||
abortController.signal
|
||||
)
|
||||
|
||||
let latestAgentSessionId = ''
|
||||
const adapter = new AiSdkToChunkAdapter(streamProcessorCallbacks, [], false, false, (sessionId) => {
|
||||
latestAgentSessionId = sessionId
|
||||
})
|
||||
|
||||
await adapter.processStream({
|
||||
fullStream: stream,
|
||||
text: Promise.resolve('')
|
||||
})
|
||||
|
||||
await persistAgentExchange({
|
||||
getState,
|
||||
agentSession,
|
||||
userMessageId,
|
||||
assistantMessageId: assistantMessage.id,
|
||||
latestAgentSessionId
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Error in fetchAndProcessAgentResponseImpl:', error)
|
||||
try {
|
||||
callbacks.onError?.(error)
|
||||
} catch (callbackError) {
|
||||
logger.error('Error in agent onError callback:', callbackError as Error)
|
||||
} finally {
|
||||
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interface PersistAgentExchangeParams {
|
||||
getState: () => RootState
|
||||
agentSession: AgentSessionContext
|
||||
userMessageId: string
|
||||
assistantMessageId: string
|
||||
latestAgentSessionId: string
|
||||
}
|
||||
|
||||
const persistAgentExchange = async ({
|
||||
getState,
|
||||
agentSession,
|
||||
userMessageId,
|
||||
assistantMessageId,
|
||||
latestAgentSessionId
|
||||
}: PersistAgentExchangeParams) => {
|
||||
if (!window.electron?.ipcRenderer) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const state = getState()
|
||||
const userMessage = state.messages.entities[userMessageId]
|
||||
const assistantMessage = state.messages.entities[assistantMessageId]
|
||||
|
||||
if (!userMessage || !assistantMessage) {
|
||||
logger.warn('persistAgentExchange: missing user or assistant message entity')
|
||||
return
|
||||
}
|
||||
|
||||
const userPersistedPayload = createPersistedMessagePayload(userMessage, state)
|
||||
const assistantPersistedPayload = createPersistedMessagePayload(assistantMessage, state)
|
||||
|
||||
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
||||
sessionId: agentSession.sessionId,
|
||||
agentSessionId: latestAgentSessionId || '',
|
||||
user: userPersistedPayload ? { payload: userPersistedPayload } : undefined,
|
||||
assistant: assistantPersistedPayload ? { payload: assistantPersistedPayload } : undefined
|
||||
})
|
||||
} catch (error) {
|
||||
logger.warn('Failed to persist agent exchange', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
const createPersistedMessagePayload = (
|
||||
message: Message | undefined,
|
||||
state: RootState
|
||||
): AgentPersistedMessage | undefined => {
|
||||
if (!message) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
try {
|
||||
const clonedMessage = JSON.parse(JSON.stringify(message)) as Message
|
||||
const blockEntities = (message.blocks || [])
|
||||
.map((blockId) => state.messageBlocks.entities[blockId])
|
||||
.filter((block): block is MessageBlock => Boolean(block))
|
||||
.map((block) => JSON.parse(JSON.stringify(block)) as MessageBlock)
|
||||
|
||||
return {
|
||||
message: clonedMessage,
|
||||
blocks: blockEntities
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Failed to build persisted payload for message', error as Error)
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper Function for Multi-Model Dispatch ---
|
||||
// 多模型创建和发送请求的逻辑,用于用户消息多模型发送和重发
|
||||
const dispatchMultiModelResponses = async (
|
||||
@ -385,7 +694,7 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
})
|
||||
// 统一错误处理:确保 loading 状态被正确设置,避免队列任务卡住
|
||||
try {
|
||||
await callbacks.onError?.(error)
|
||||
callbacks.onError?.(error)
|
||||
} catch (callbackError) {
|
||||
logger.error('Error in onError callback:', callbackError as Error)
|
||||
} finally {
|
||||
@ -403,7 +712,13 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
* @param topicId 主题ID
|
||||
*/
|
||||
export const sendMessage =
|
||||
(userMessage: Message, userMessageBlocks: MessageBlock[], assistant: Assistant, topicId: Topic['id']) =>
|
||||
(
|
||||
userMessage: Message,
|
||||
userMessageBlocks: MessageBlock[],
|
||||
assistant: Assistant,
|
||||
topicId: Topic['id'],
|
||||
agentSession?: AgentSessionContext
|
||||
) =>
|
||||
async (dispatch: AppDispatch, getState: () => RootState) => {
|
||||
try {
|
||||
if (userMessage.blocks.length === 0) {
|
||||
@ -417,12 +732,9 @@ export const sendMessage =
|
||||
}
|
||||
dispatch(updateTopicUpdatedAt({ topicId }))
|
||||
|
||||
const mentionedModels = userMessage.mentions
|
||||
const queue = getTopicQueue(topicId)
|
||||
|
||||
if (mentionedModels && mentionedModels.length > 0) {
|
||||
await dispatchMultiModelResponses(dispatch, getState, topicId, userMessage, assistant, mentionedModels)
|
||||
} else {
|
||||
if (agentSession) {
|
||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||
askId: userMessage.id,
|
||||
model: assistant.model,
|
||||
@ -432,8 +744,32 @@ export const sendMessage =
|
||||
dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage }))
|
||||
|
||||
queue.add(async () => {
|
||||
await fetchAndProcessAssistantResponseImpl(dispatch, getState, topicId, assistant, assistantMessage)
|
||||
await fetchAndProcessAgentResponseImpl(dispatch, getState, {
|
||||
topicId,
|
||||
assistant,
|
||||
assistantMessage,
|
||||
agentSession,
|
||||
userMessageId: userMessage.id
|
||||
})
|
||||
})
|
||||
} else {
|
||||
const mentionedModels = userMessage.mentions
|
||||
|
||||
if (mentionedModels && mentionedModels.length > 0) {
|
||||
await dispatchMultiModelResponses(dispatch, getState, topicId, userMessage, assistant, mentionedModels)
|
||||
} else {
|
||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||
askId: userMessage.id,
|
||||
model: assistant.model,
|
||||
traceId: userMessage.traceId
|
||||
})
|
||||
await saveMessageAndBlocksToDB(assistantMessage, [])
|
||||
dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage }))
|
||||
|
||||
queue.add(async () => {
|
||||
await fetchAndProcessAssistantResponseImpl(dispatch, getState, topicId, assistant, assistantMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error in sendMessage thunk:', error as Error)
|
||||
|
||||
@ -4,9 +4,11 @@
|
||||
*
|
||||
* WARNING: Any null value will be converted to undefined from api.
|
||||
*/
|
||||
import { ModelMessage, modelMessageSchema, TextStreamPart } from 'ai'
|
||||
import { TextStreamPart } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
import type { Message, MessageBlock } from './newMessage'
|
||||
|
||||
// ------------------ Core enums and helper types ------------------
|
||||
export const PermissionModeSchema = z.enum(['default', 'acceptEdits', 'bypassPermissions', 'plan'])
|
||||
export type PermissionMode = z.infer<typeof PermissionModeSchema>
|
||||
@ -109,8 +111,8 @@ export const AgentSessionMessageEntitySchema = z.object({
|
||||
id: z.number(), // Auto-increment primary key
|
||||
session_id: z.string(), // Reference to session
|
||||
// manual defined. may not synced with ai sdk definition
|
||||
role: SessionMessageRoleSchema, // Enforce roles supported by modelMessageSchema
|
||||
content: modelMessageSchema,
|
||||
role: SessionMessageRoleSchema,
|
||||
content: z.unknown(),
|
||||
agent_session_id: z.string(), // agent session id, use to resume agent session
|
||||
metadata: z.record(z.string(), z.any()).optional(), // Additional metadata (optional)
|
||||
created_at: z.iso.datetime(), // ISO timestamp
|
||||
@ -119,6 +121,35 @@ export const AgentSessionMessageEntitySchema = z.object({
|
||||
|
||||
export type AgentSessionMessageEntity = z.infer<typeof AgentSessionMessageEntitySchema>
|
||||
|
||||
export interface AgentPersistedMessage {
|
||||
message: Message
|
||||
blocks: MessageBlock[]
|
||||
}
|
||||
|
||||
export interface AgentMessageUserPersistPayload {
|
||||
payload: AgentPersistedMessage
|
||||
metadata?: Record<string, unknown>
|
||||
createdAt?: string
|
||||
}
|
||||
|
||||
export interface AgentMessageAssistantPersistPayload {
|
||||
payload: AgentPersistedMessage
|
||||
metadata?: Record<string, unknown>
|
||||
createdAt?: string
|
||||
}
|
||||
|
||||
export interface AgentMessagePersistExchangePayload {
|
||||
sessionId: string
|
||||
agentSessionId: string
|
||||
user?: AgentMessageUserPersistPayload
|
||||
assistant?: AgentMessageAssistantPersistPayload
|
||||
}
|
||||
|
||||
export interface AgentMessagePersistExchangeResult {
|
||||
userMessage?: AgentSessionMessageEntity
|
||||
assistantMessage?: AgentSessionMessageEntity
|
||||
}
|
||||
|
||||
// ------------------ Session message payload ------------------
|
||||
|
||||
// Not implemented fields:
|
||||
|
||||
9
src/renderer/src/utils/agentSession.ts
Normal file
9
src/renderer/src/utils/agentSession.ts
Normal file
@ -0,0 +1,9 @@
|
||||
const SESSION_TOPIC_PREFIX = 'agent-session:'
|
||||
|
||||
export const buildAgentSessionTopicId = (sessionId: string): string => {
|
||||
return `${SESSION_TOPIC_PREFIX}${sessionId}`
|
||||
}
|
||||
|
||||
export const isAgentSessionTopicId = (topicId: string): boolean => {
|
||||
return topicId.startsWith(SESSION_TOPIC_PREFIX)
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user