feat: implement reasoning cache for improved performance and error handling in AI SDK integration

This commit is contained in:
suyao 2025-11-28 13:11:13 +08:00
parent 356e828422
commit d367040fd4
No known key found for this signature in database
8 changed files with 225 additions and 77 deletions

View File

@ -36,7 +36,8 @@ import type {
Usage Usage
} from '@anthropic-ai/sdk/resources/messages' } from '@anthropic-ai/sdk/resources/messages'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai' import { reasoningCache } from '@main/apiServer/services/cache'
import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
const logger = loggerService.withContext('AiSdkToAnthropicSSE') const logger = loggerService.withContext('AiSdkToAnthropicSSE')
@ -125,6 +126,9 @@ export class AiSdkToAnthropicSSE {
// Ensure all blocks are closed and emit final events // Ensure all blocks are closed and emit final events
this.finalize() this.finalize()
} catch (error) {
await reader.cancel()
throw error
} finally { } finally {
reader.releaseLock() reader.releaseLock()
} }
@ -188,8 +192,13 @@ export class AiSdkToAnthropicSSE {
// }) // })
break break
// === Completion Events ===
case 'finish-step': case 'finish-step':
if (
chunk.providerMetadata?.openrouter?.reasoning_details &&
Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
) {
reasoningCache.set('openrouter', chunk.providerMetadata?.openrouter?.reasoning_details)
}
if (chunk.finishReason === 'tool-calls') { if (chunk.finishReason === 'tool-calls') {
this.state.stopReason = 'tool_use' this.state.stopReason = 'tool_use'
} }
@ -199,10 +208,8 @@ export class AiSdkToAnthropicSSE {
this.handleFinish(chunk) this.handleFinish(chunk)
break break
// === Error Events ===
case 'error': case 'error':
this.handleError(chunk.error) throw chunk.error
break
// Ignore other event types // Ignore other event types
default: default:
@ -496,33 +503,6 @@ export class AiSdkToAnthropicSSE {
} }
} }
private handleError(error: unknown): void {
// Log the error for debugging
logger.warn('AiSdkToAnthropicSSE - Provider error received:', { error })
// Extract error message
let errorMessage = 'Unknown error from provider'
if (error && typeof error === 'object') {
const err = error as { message?: string; metadata?: { raw?: string } }
if (err.metadata?.raw) {
errorMessage = `Provider error: ${err.metadata.raw}`
} else if (err.message) {
errorMessage = err.message
}
} else if (typeof error === 'string') {
errorMessage = error
}
// Emit error as a text block so the user can see it
// First close any open thinking blocks to maintain proper event order
for (const reasoningId of Array.from(this.state.thinkingBlocks.keys())) {
this.stopThinkingBlock(reasoningId)
}
// Emit the error as text
this.emitTextDelta(`\n\n[Error: ${errorMessage}]\n`)
}
private finalize(): void { private finalize(): void {
// Close any open blocks // Close any open blocks
if (this.state.textBlockIndex !== null) { if (this.state.textBlockIndex !== null) {

View File

@ -245,7 +245,6 @@ async function handleUnifiedProcessing({
res.json(response) res.json(response)
} }
} catch (error: any) { } catch (error: any) {
logger.error('Unified processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error) const { statusCode, errorResponse } = messagesService.transformError(error)
res.status(statusCode).json(errorResponse) res.status(statusCode).json(errorResponse)
} }

View File

@ -0,0 +1,116 @@
import { loggerService } from '@logger'
const logger = loggerService.withContext('Cache')
/**
 * Cache entry with TTL support.
 */
interface CacheEntry<T> {
  details: T[] // cached reasoning detail items stored under one signature
  timestamp: number // Date.now() at insertion; compared against ttlMs for expiry
}
/**
* In-memory cache for reasoning details
* Key: signature
* Value: reasoning array with timestamp
*/
export class ReasoningCache<T> {
  // signature -> entry; entries are expired lazily on read and by the periodic sweep.
  private cache = new Map<string, CacheEntry<T>>()
  private readonly ttlMs: number
  private cleanupInterval: ReturnType<typeof setInterval> | null = null

  /**
   * @param ttlMs How long a cached entry remains retrievable. Defaults to 30 minutes.
   */
  constructor(ttlMs: number = 30 * 60 * 1000) {
    // Default 30 minutes TTL
    this.ttlMs = ttlMs
    this.startCleanup()
  }

  /**
   * Store reasoning details by signature.
   * No-op for an empty signature or an empty details array.
   */
  set(signature: string, details: T[]): void {
    if (!signature || !details.length) return
    this.cache.set(signature, {
      details,
      timestamp: Date.now()
    })
    logger.debug('Cached reasoning details', {
      signature: signature.substring(0, 20) + '...',
      detailsCount: details.length
    })
  }

  /**
   * Retrieve reasoning details by signature.
   * Returns undefined on a miss or when the entry has outlived its TTL
   * (expired entries are deleted on access).
   */
  get(signature: string): T[] | undefined {
    const entry = this.cache.get(signature)
    if (!entry) return undefined
    // Check TTL; evict eagerly so stale data is never returned.
    if (Date.now() - entry.timestamp > this.ttlMs) {
      this.cache.delete(signature)
      return undefined
    }
    logger.debug('Retrieved reasoning details from cache', {
      signature: signature.substring(0, 20) + '...',
      detailsCount: entry.details.length
    })
    return entry.details
  }

  /**
   * Clear expired entries. Also invoked automatically by the periodic sweep.
   */
  cleanup(): void {
    const now = Date.now()
    let cleaned = 0
    for (const [key, entry] of this.cache) {
      if (now - entry.timestamp > this.ttlMs) {
        this.cache.delete(key)
        cleaned++
      }
    }
    if (cleaned > 0) {
      logger.debug('Cleaned up expired reasoning cache entries', { cleaned, remaining: this.cache.size })
    }
  }

  /**
   * Start periodic cleanup (every 5 minutes).
   */
  private startCleanup(): void {
    this.cleanupInterval = setInterval(() => this.cleanup(), 5 * 60 * 1000)
    // Fix: unref the timer so this background sweep does not keep the Node/Electron
    // process alive when everything else has shut down and destroy() was never called.
    // Guarded optional call so the code still type-checks/runs in environments where
    // setInterval returns a plain number without unref().
    ;(this.cleanupInterval as unknown as { unref?: () => void }).unref?.()
  }

  /**
   * Stop the periodic cleanup and drop all cached entries.
   */
  destroy(): void {
    if (this.cleanupInterval) {
      clearInterval(this.cleanupInterval)
      this.cleanupInterval = null
    }
    this.cache.clear()
  }

  /**
   * Get cache stats for debugging.
   */
  stats(): { size: number; ttlMs: number } {
    return {
      size: this.cache.size,
      ttlMs: this.ttlMs
    }
  }
}
// Singleton cache instance, created at module load with the default 30-minute TTL.
// NOTE(review): the constructor starts a periodic cleanup interval that lives for
// the process lifetime unless destroy() is called.
export const reasoningCache = new ReasoningCache()

View File

@ -4,6 +4,7 @@ import { loggerService } from '@logger'
import anthropicService from '@main/services/AnthropicService' import anthropicService from '@main/services/AnthropicService'
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic' import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import type { Provider } from '@types' import type { Provider } from '@types'
import { APICallError } from 'ai'
import { net } from 'electron' import { net } from 'electron'
import type { Response } from 'express' import type { Response } from 'express'
@ -253,9 +254,36 @@ export class MessagesService {
} }
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } { transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
let statusCode = 500 let statusCode: number | undefined = undefined
let errorType = 'api_error' let errorType: string | undefined = undefined
let errorMessage = 'Internal server error' let errorMessage: string | undefined = undefined
const errorMap: Record<number, string> = {
400: 'invalid_request_error',
401: 'authentication_error',
403: 'forbidden_error',
404: 'not_found_error',
429: 'rate_limit_error',
500: 'internal_server_error'
}
if (APICallError.isInstance(error)) {
statusCode = error.statusCode
errorMessage = error.message
if (statusCode) {
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorMap[statusCode] || 'api_error',
message: errorMessage,
requestId: error.name
}
}
}
}
}
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
const anthropicError = error?.error const anthropicError = error?.error
@ -297,11 +325,11 @@ export class MessagesService {
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error' typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
return { return {
statusCode, statusCode: statusCode ?? 500,
errorResponse: { errorResponse: {
type: 'error', type: 'error',
error: { error: {
type: errorType, type: errorType || 'api_error',
message: safeErrorMessage, message: safeErrorMessage,
requestId: error?.request_id requestId: error?.request_id
} }

View File

@ -23,11 +23,13 @@ import {
} from '@shared/provider' } from '@shared/provider'
import { defaultAppHeaders } from '@shared/utils' import { defaultAppHeaders } from '@shared/utils'
import type { Provider } from '@types' import type { Provider } from '@types'
import type { ImagePart, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai' import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai' import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai'
import { net } from 'electron' import { net } from 'electron'
import type { Response } from 'express' import type { Response } from 'express'
import { reasoningCache } from './cache'
const logger = loggerService.withContext('UnifiedMessagesService') const logger = loggerService.withContext('UnifiedMessagesService')
const MAGIC_STRING = 'skip_thought_signature_validator' const MAGIC_STRING = 'skip_thought_signature_validator'
@ -154,8 +156,6 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
} }
} }
// Build a map of tool_use_id -> toolName from all messages first
// This is needed because tool_result references tool_use from previous assistant messages
const toolCallIdToName = new Map<string, string>() const toolCallIdToName = new Map<string, string>()
for (const msg of params.messages) { for (const msg of params.messages) {
if (Array.isArray(msg.content)) { if (Array.isArray(msg.content)) {
@ -227,13 +227,16 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts] const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
if (assistantContent.length > 0) { if (assistantContent.length > 0) {
let providerOptions: ProviderOptions | undefined = undefined let providerOptions: ProviderOptions | undefined = undefined
if (isGemini3ModelId(params.model)) { if (reasoningCache.get('openrouter')) {
providerOptions = {
openrouter: {
reasoning_details: (reasoningCache.get('openrouter') as JSONValue[]) || []
}
}
} else if (isGemini3ModelId(params.model)) {
providerOptions = { providerOptions = {
google: { google: {
thoughtSignature: MAGIC_STRING thoughtSignature: MAGIC_STRING
},
openrouter: {
reasoning_details: []
} }
} }
} }
@ -367,6 +370,7 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
middlewares, middlewares,
plugins, plugins,
onEvent: (event) => { onEvent: (event) => {
logger.silly('Streaming event', { eventType: event.type })
const sseData = formatSSEEvent(event) const sseData = formatSSEEvent(event)
response.write(sseData) response.write(sseData)
} }
@ -380,22 +384,6 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
onComplete?.() onComplete?.()
} catch (error) { } catch (error) {
logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId }) logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
if (!response.writableEnded) {
try {
const errorMessage = error instanceof Error ? error.message : 'Unknown error'
response.write(
`event: error\ndata: ${JSON.stringify({
type: 'error',
error: { type: 'api_error', message: errorMessage }
})}\n\n`
)
response.end()
} catch {
// Response already ended
}
}
onError?.(error) onError?.(error)
throw error throw error
} }

View File

@ -87,6 +87,7 @@ export class ClaudeStreamState {
private pendingUsage: PendingUsageState = {} private pendingUsage: PendingUsageState = {}
private pendingToolCalls = new Map<string, PendingToolCall>() private pendingToolCalls = new Map<string, PendingToolCall>()
private stepActive = false private stepActive = false
private _streamFinished = false
constructor(options: ClaudeStreamStateOptions) { constructor(options: ClaudeStreamStateOptions) {
this.logger = loggerService.withContext('ClaudeStreamState') this.logger = loggerService.withContext('ClaudeStreamState')
@ -289,6 +290,16 @@ export class ClaudeStreamState {
getNamespacedToolCallId(rawToolCallId: string): string { getNamespacedToolCallId(rawToolCallId: string): string {
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId) return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
} }
/** Marks the stream as finished (either completed or errored). */
markFinished(): void {
this._streamFinished = true
}
/** Returns true if the stream has already emitted a terminal event. */
isFinished(): boolean {
return this._streamFinished
}
} }
export type { PendingToolCall } export type { PendingToolCall }

View File

@ -529,6 +529,19 @@ class ClaudeCodeService implements AgentServiceInterface {
return return
} }
// Skip emitting error if stream already finished (error was handled via result message)
if (streamState.isFinished()) {
logger.debug('SDK process exited after stream finished, skipping duplicate error event', {
duration,
error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj)
})
// Still emit complete to signal stream end
stream.emit('data', {
type: 'complete'
})
return
}
errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj)) errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj))
const errorMessage = errorChunks.join('\n\n') const errorMessage = errorChunks.join('\n\n')
logger.error('SDK query failed', { logger.error('SDK query failed', {

View File

@ -121,7 +121,7 @@ export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state:
case 'system': case 'system':
return handleSystemMessage(sdkMessage) return handleSystemMessage(sdkMessage)
case 'result': case 'result':
return handleResultMessage(sdkMessage) return handleResultMessage(sdkMessage, state)
default: default:
logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type }) logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type })
return [] return []
@ -707,7 +707,13 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
* Successful runs yield a `finish` frame with aggregated usage metrics, while * Successful runs yield a `finish` frame with aggregated usage metrics, while
* failures are surfaced as `error` frames. * failures are surfaced as `error` frames.
*/ */
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] { function handleResultMessage(
message: Extract<SDKMessage, { type: 'result' }>,
state: ClaudeStreamState
): AgentStreamPart[] {
// Mark stream as finished to prevent duplicate error events when SDK process exits
state.markFinished()
const chunks: AgentStreamPart[] = [] const chunks: AgentStreamPart[] = []
let usage: LanguageModelUsage | undefined let usage: LanguageModelUsage | undefined
@ -719,7 +725,6 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
} }
} }
if (message.subtype === 'success') {
chunks.push({ chunks.push({
type: 'finish', type: 'finish',
totalUsage: usage ?? emptyUsage, totalUsage: usage ?? emptyUsage,
@ -732,13 +737,21 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
raw: message raw: message
} }
} as AgentStreamPart) } as AgentStreamPart)
} else { if (message.subtype !== 'success') {
chunks.push({ chunks.push({
type: 'error', type: 'error',
error: { error: {
message: `${message.subtype}: Process failed after ${message.num_turns} turns` message: `${message.subtype}: Process failed after ${message.num_turns} turns`
} }
} as AgentStreamPart) } as AgentStreamPart)
} else {
if (message.is_error) {
const errorMatch = message.result.match(/\{.*\}/)
if (errorMatch) {
const errorDetail = JSON.parse(errorMatch[0])
chunks.push(errorDetail)
}
}
} }
return chunks return chunks
} }