mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-22 00:13:09 +08:00
feat: implement reasoning cache for improved performance and error handling in AI SDK integration
This commit is contained in:
parent
356e828422
commit
d367040fd4
@ -36,7 +36,8 @@ import type {
|
||||
Usage
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
import { loggerService } from '@logger'
|
||||
import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'
|
||||
import { reasoningCache } from '@main/apiServer/services/cache'
|
||||
import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
|
||||
|
||||
const logger = loggerService.withContext('AiSdkToAnthropicSSE')
|
||||
|
||||
@ -125,6 +126,9 @@ export class AiSdkToAnthropicSSE {
|
||||
|
||||
// Ensure all blocks are closed and emit final events
|
||||
this.finalize()
|
||||
} catch (error) {
|
||||
await reader.cancel()
|
||||
throw error
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
@ -188,8 +192,13 @@ export class AiSdkToAnthropicSSE {
|
||||
// })
|
||||
break
|
||||
|
||||
// === Completion Events ===
|
||||
case 'finish-step':
|
||||
if (
|
||||
chunk.providerMetadata?.openrouter?.reasoning_details &&
|
||||
Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
|
||||
) {
|
||||
reasoningCache.set('openrouter', chunk.providerMetadata?.openrouter?.reasoning_details)
|
||||
}
|
||||
if (chunk.finishReason === 'tool-calls') {
|
||||
this.state.stopReason = 'tool_use'
|
||||
}
|
||||
@ -199,10 +208,8 @@ export class AiSdkToAnthropicSSE {
|
||||
this.handleFinish(chunk)
|
||||
break
|
||||
|
||||
// === Error Events ===
|
||||
case 'error':
|
||||
this.handleError(chunk.error)
|
||||
break
|
||||
throw chunk.error
|
||||
|
||||
// Ignore other event types
|
||||
default:
|
||||
@ -496,33 +503,6 @@ export class AiSdkToAnthropicSSE {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Handles a provider-reported stream error without aborting the stream:
 * logs it, extracts a human-readable message, closes any open thinking
 * blocks (to keep the SSE event order valid), and surfaces the message
 * to the client as a text delta.
 *
 * @param error Unknown value emitted by the provider's error event —
 *              may be an Error-like object, a string, or anything else.
 */
private handleError(error: unknown): void {
  // Log the error for debugging
  logger.warn('AiSdkToAnthropicSSE - Provider error received:', { error })

  // Extract error message; fall back to a generic message when the
  // error value carries no usable text.
  let errorMessage = 'Unknown error from provider'
  if (error && typeof error === 'object') {
    const err = error as { message?: string; metadata?: { raw?: string } }
    // Prefer the provider's raw metadata payload — it usually carries
    // more detail than the wrapped message.
    if (err.metadata?.raw) {
      errorMessage = `Provider error: ${err.metadata.raw}`
    } else if (err.message) {
      errorMessage = err.message
    }
  } else if (typeof error === 'string') {
    errorMessage = error
  }

  // Emit error as a text block so the user can see it
  // First close any open thinking blocks to maintain proper event order
  // (Array.from snapshots the keys so closing blocks can't disturb iteration)
  for (const reasoningId of Array.from(this.state.thinkingBlocks.keys())) {
    this.stopThinkingBlock(reasoningId)
  }

  // Emit the error as text
  this.emitTextDelta(`\n\n[Error: ${errorMessage}]\n`)
}
|
||||
|
||||
private finalize(): void {
|
||||
// Close any open blocks
|
||||
if (this.state.textBlockIndex !== null) {
|
||||
|
||||
@ -245,7 +245,6 @@ async function handleUnifiedProcessing({
|
||||
res.json(response)
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error('Unified processing error', { error })
|
||||
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||
res.status(statusCode).json(errorResponse)
|
||||
}
|
||||
|
||||
116
src/main/apiServer/services/cache.ts
Normal file
116
src/main/apiServer/services/cache.ts
Normal file
@ -0,0 +1,116 @@
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
const logger = loggerService.withContext('Cache')
|
||||
/**
|
||||
* Cache entry with TTL support
|
||||
*/
|
||||
interface CacheEntry<T> {
|
||||
details: T[]
|
||||
timestamp: number
|
||||
}
|
||||
|
||||
/**
|
||||
* In-memory cache for reasoning details
|
||||
* Key: signature
|
||||
* Value: reasoning array with timestamp
|
||||
*/
|
||||
export class ReasoningCache<T> {
|
||||
private cache = new Map<string, CacheEntry<T>>()
|
||||
private readonly ttlMs: number
|
||||
private cleanupInterval: ReturnType<typeof setInterval> | null = null
|
||||
|
||||
constructor(ttlMs: number = 30 * 60 * 1000) {
|
||||
// Default 30 minutes TTL
|
||||
this.ttlMs = ttlMs
|
||||
this.startCleanup()
|
||||
}
|
||||
|
||||
/**
|
||||
* Store reasoning details by signature
|
||||
*/
|
||||
set(signature: string, details: T[]): void {
|
||||
if (!signature || !details.length) return
|
||||
|
||||
this.cache.set(signature, {
|
||||
details,
|
||||
timestamp: Date.now()
|
||||
})
|
||||
|
||||
logger.debug('Cached reasoning details', {
|
||||
signature: signature.substring(0, 20) + '...',
|
||||
detailsCount: details.length
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve reasoning details by signature
|
||||
*/
|
||||
get(signature: string): T[] | undefined {
|
||||
const entry = this.cache.get(signature)
|
||||
if (!entry) return undefined
|
||||
|
||||
// Check TTL
|
||||
if (Date.now() - entry.timestamp > this.ttlMs) {
|
||||
this.cache.delete(signature)
|
||||
return undefined
|
||||
}
|
||||
|
||||
logger.debug('Retrieved reasoning details from cache', {
|
||||
signature: signature.substring(0, 20) + '...',
|
||||
detailsCount: entry.details.length
|
||||
})
|
||||
|
||||
return entry.details
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear expired entries
|
||||
*/
|
||||
cleanup(): void {
|
||||
const now = Date.now()
|
||||
let cleaned = 0
|
||||
|
||||
for (const [key, entry] of this.cache) {
|
||||
if (now - entry.timestamp > this.ttlMs) {
|
||||
this.cache.delete(key)
|
||||
cleaned++
|
||||
}
|
||||
}
|
||||
|
||||
if (cleaned > 0) {
|
||||
logger.debug('Cleaned up expired reasoning cache entries', { cleaned, remaining: this.cache.size })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start periodic cleanup
|
||||
*/
|
||||
private startCleanup(): void {
|
||||
// Cleanup every 5 minutes
|
||||
this.cleanupInterval = setInterval(() => this.cleanup(), 5 * 60 * 1000)
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop cleanup and clear cache
|
||||
*/
|
||||
destroy(): void {
|
||||
if (this.cleanupInterval) {
|
||||
clearInterval(this.cleanupInterval)
|
||||
this.cleanupInterval = null
|
||||
}
|
||||
this.cache.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* Get cache stats for debugging
|
||||
*/
|
||||
stats(): { size: number; ttlMs: number } {
|
||||
return {
|
||||
size: this.cache.size,
|
||||
ttlMs: this.ttlMs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Singleton cache instance
|
||||
export const reasoningCache = new ReasoningCache()
|
||||
@ -4,6 +4,7 @@ import { loggerService } from '@logger'
|
||||
import anthropicService from '@main/services/AnthropicService'
|
||||
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
||||
import type { Provider } from '@types'
|
||||
import { APICallError } from 'ai'
|
||||
import { net } from 'electron'
|
||||
import type { Response } from 'express'
|
||||
|
||||
@ -253,9 +254,36 @@ export class MessagesService {
|
||||
}
|
||||
|
||||
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
|
||||
let statusCode = 500
|
||||
let errorType = 'api_error'
|
||||
let errorMessage = 'Internal server error'
|
||||
let statusCode: number | undefined = undefined
|
||||
let errorType: string | undefined = undefined
|
||||
let errorMessage: string | undefined = undefined
|
||||
|
||||
const errorMap: Record<number, string> = {
|
||||
400: 'invalid_request_error',
|
||||
401: 'authentication_error',
|
||||
403: 'forbidden_error',
|
||||
404: 'not_found_error',
|
||||
429: 'rate_limit_error',
|
||||
500: 'internal_server_error'
|
||||
}
|
||||
|
||||
if (APICallError.isInstance(error)) {
|
||||
statusCode = error.statusCode
|
||||
errorMessage = error.message
|
||||
if (statusCode) {
|
||||
return {
|
||||
statusCode,
|
||||
errorResponse: {
|
||||
type: 'error',
|
||||
error: {
|
||||
type: errorMap[statusCode] || 'api_error',
|
||||
message: errorMessage,
|
||||
requestId: error.name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
|
||||
const anthropicError = error?.error
|
||||
@ -297,11 +325,11 @@ export class MessagesService {
|
||||
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
|
||||
|
||||
return {
|
||||
statusCode,
|
||||
statusCode: statusCode ?? 500,
|
||||
errorResponse: {
|
||||
type: 'error',
|
||||
error: {
|
||||
type: errorType,
|
||||
type: errorType || 'api_error',
|
||||
message: safeErrorMessage,
|
||||
requestId: error?.request_id
|
||||
}
|
||||
|
||||
@ -23,11 +23,13 @@ import {
|
||||
} from '@shared/provider'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import type { Provider } from '@types'
|
||||
import type { ImagePart, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
|
||||
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
|
||||
import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai'
|
||||
import { net } from 'electron'
|
||||
import type { Response } from 'express'
|
||||
|
||||
import { reasoningCache } from './cache'
|
||||
|
||||
const logger = loggerService.withContext('UnifiedMessagesService')
|
||||
|
||||
const MAGIC_STRING = 'skip_thought_signature_validator'
|
||||
@ -154,8 +156,6 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
|
||||
}
|
||||
}
|
||||
|
||||
// Build a map of tool_use_id -> toolName from all messages first
|
||||
// This is needed because tool_result references tool_use from previous assistant messages
|
||||
const toolCallIdToName = new Map<string, string>()
|
||||
for (const msg of params.messages) {
|
||||
if (Array.isArray(msg.content)) {
|
||||
@ -227,13 +227,16 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
|
||||
const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
|
||||
if (assistantContent.length > 0) {
|
||||
let providerOptions: ProviderOptions | undefined = undefined
|
||||
if (isGemini3ModelId(params.model)) {
|
||||
if (reasoningCache.get('openrouter')) {
|
||||
providerOptions = {
|
||||
openrouter: {
|
||||
reasoning_details: (reasoningCache.get('openrouter') as JSONValue[]) || []
|
||||
}
|
||||
}
|
||||
} else if (isGemini3ModelId(params.model)) {
|
||||
providerOptions = {
|
||||
google: {
|
||||
thoughtSignature: MAGIC_STRING
|
||||
},
|
||||
openrouter: {
|
||||
reasoning_details: []
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -367,6 +370,7 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
|
||||
middlewares,
|
||||
plugins,
|
||||
onEvent: (event) => {
|
||||
logger.silly('Streaming event', { eventType: event.type })
|
||||
const sseData = formatSSEEvent(event)
|
||||
response.write(sseData)
|
||||
}
|
||||
@ -380,22 +384,6 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
|
||||
onComplete?.()
|
||||
} catch (error) {
|
||||
logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
|
||||
|
||||
if (!response.writableEnded) {
|
||||
try {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error'
|
||||
response.write(
|
||||
`event: error\ndata: ${JSON.stringify({
|
||||
type: 'error',
|
||||
error: { type: 'api_error', message: errorMessage }
|
||||
})}\n\n`
|
||||
)
|
||||
response.end()
|
||||
} catch {
|
||||
// Response already ended
|
||||
}
|
||||
}
|
||||
|
||||
onError?.(error)
|
||||
throw error
|
||||
}
|
||||
|
||||
@ -87,6 +87,7 @@ export class ClaudeStreamState {
|
||||
private pendingUsage: PendingUsageState = {}
|
||||
private pendingToolCalls = new Map<string, PendingToolCall>()
|
||||
private stepActive = false
|
||||
private _streamFinished = false
|
||||
|
||||
constructor(options: ClaudeStreamStateOptions) {
|
||||
this.logger = loggerService.withContext('ClaudeStreamState')
|
||||
@ -289,6 +290,16 @@ export class ClaudeStreamState {
|
||||
getNamespacedToolCallId(rawToolCallId: string): string {
|
||||
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
|
||||
}
|
||||
|
||||
/**
 * Marks the stream as finished (either completed or errored).
 * Set once a terminal event has been emitted so that a later SDK
 * process-exit error is not surfaced as a duplicate error event
 * (callers check isFinished() before emitting).
 */
markFinished(): void {
  this._streamFinished = true
}
|
||||
|
||||
/**
 * Returns true if the stream has already emitted a terminal event
 * (see markFinished) — i.e. further error events should be suppressed.
 */
isFinished(): boolean {
  return this._streamFinished
}
|
||||
}
|
||||
|
||||
export type { PendingToolCall }
|
||||
|
||||
@ -529,6 +529,19 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip emitting error if stream already finished (error was handled via result message)
|
||||
if (streamState.isFinished()) {
|
||||
logger.debug('SDK process exited after stream finished, skipping duplicate error event', {
|
||||
duration,
|
||||
error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj)
|
||||
})
|
||||
// Still emit complete to signal stream end
|
||||
stream.emit('data', {
|
||||
type: 'complete'
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj))
|
||||
const errorMessage = errorChunks.join('\n\n')
|
||||
logger.error('SDK query failed', {
|
||||
|
||||
@ -121,7 +121,7 @@ export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state:
|
||||
case 'system':
|
||||
return handleSystemMessage(sdkMessage)
|
||||
case 'result':
|
||||
return handleResultMessage(sdkMessage)
|
||||
return handleResultMessage(sdkMessage, state)
|
||||
default:
|
||||
logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type })
|
||||
return []
|
||||
@ -707,7 +707,13 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
|
||||
* Successful runs yield a `finish` frame with aggregated usage metrics, while
|
||||
* failures are surfaced as `error` frames.
|
||||
*/
|
||||
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
|
||||
function handleResultMessage(
|
||||
message: Extract<SDKMessage, { type: 'result' }>,
|
||||
state: ClaudeStreamState
|
||||
): AgentStreamPart[] {
|
||||
// Mark stream as finished to prevent duplicate error events when SDK process exits
|
||||
state.markFinished()
|
||||
|
||||
const chunks: AgentStreamPart[] = []
|
||||
|
||||
let usage: LanguageModelUsage | undefined
|
||||
@ -719,26 +725,33 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
||||
}
|
||||
}
|
||||
|
||||
if (message.subtype === 'success') {
|
||||
chunks.push({
|
||||
type: 'finish',
|
||||
totalUsage: usage ?? emptyUsage,
|
||||
finishReason: mapClaudeCodeFinishReason(message.subtype),
|
||||
providerMetadata: {
|
||||
...sdkMessageToProviderMetadata(message),
|
||||
usage: message.usage,
|
||||
durationMs: message.duration_ms,
|
||||
costUsd: message.total_cost_usd,
|
||||
raw: message
|
||||
}
|
||||
} as AgentStreamPart)
|
||||
} else {
|
||||
chunks.push({
|
||||
type: 'finish',
|
||||
totalUsage: usage ?? emptyUsage,
|
||||
finishReason: mapClaudeCodeFinishReason(message.subtype),
|
||||
providerMetadata: {
|
||||
...sdkMessageToProviderMetadata(message),
|
||||
usage: message.usage,
|
||||
durationMs: message.duration_ms,
|
||||
costUsd: message.total_cost_usd,
|
||||
raw: message
|
||||
}
|
||||
} as AgentStreamPart)
|
||||
if (message.subtype !== 'success') {
|
||||
chunks.push({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: `${message.subtype}: Process failed after ${message.num_turns} turns`
|
||||
}
|
||||
} as AgentStreamPart)
|
||||
} else {
|
||||
if (message.is_error) {
|
||||
const errorMatch = message.result.match(/\{.*\}/)
|
||||
if (errorMatch) {
|
||||
const errorDetail = JSON.parse(errorMatch[0])
|
||||
chunks.push(errorDetail)
|
||||
}
|
||||
}
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user