feat: Enhance thinking block management and tool conversion in unified messages

This commit is contained in:
suyao 2025-11-27 19:19:04 +08:00
parent a5e7aa1342
commit 192357a32e
No known key found for this signature in database
3 changed files with 188 additions and 39 deletions

View File

@ -59,7 +59,9 @@ interface AdapterState {
currentBlockIndex: number
blocks: Map<number, ContentBlockState>
textBlockIndex: number | null
thinkingBlockIndex: number | null
// Track multiple thinking blocks by their reasoning ID
thinkingBlocks: Map<string, number> // reasoningId -> blockIndex
currentThinkingId: string | null // Currently active thinking block ID
toolBlocks: Map<string, number> // toolCallId -> blockIndex
stopReason: StopReason | null
hasEmittedMessageStart: boolean
@ -95,7 +97,8 @@ export class AiSdkToAnthropicSSE {
currentBlockIndex: 0,
blocks: new Map(),
textBlockIndex: null,
thinkingBlockIndex: null,
thinkingBlocks: new Map(),
currentThinkingId: null,
toolBlocks: new Map(),
stopReason: null,
hasEmittedMessageStart: false
@ -133,7 +136,7 @@ export class AiSdkToAnthropicSSE {
* Process a single AI SDK chunk and emit corresponding Anthropic events
*/
private processChunk(chunk: TextStreamPart<ToolSet>): void {
logger.silly('AiSdkToAnthropicSSE - Processing chunk:', chunk)
logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
switch (chunk.type) {
// === Text Events ===
case 'text-start':
@ -149,17 +152,23 @@ export class AiSdkToAnthropicSSE {
break
// === Reasoning/Thinking Events ===
case 'reasoning-start':
this.startThinkingBlock()
case 'reasoning-start': {
const reasoningId = (chunk as { id?: string }).id || `reasoning_${Date.now()}`
this.startThinkingBlock(reasoningId)
break
}
case 'reasoning-delta':
this.emitThinkingDelta(chunk.text || '')
case 'reasoning-delta': {
const reasoningId = (chunk as { id?: string }).id
this.emitThinkingDelta(chunk.text || '', reasoningId)
break
}
case 'reasoning-end':
this.stopThinkingBlock()
case 'reasoning-end': {
const reasoningId = (chunk as { id?: string }).id
this.stopThinkingBlock(reasoningId)
break
}
// === Tool Events ===
case 'tool-call':
@ -190,9 +199,7 @@ export class AiSdkToAnthropicSSE {
// === Error Events ===
case 'error':
// Anthropic doesn't have a standard error event in the stream
// Errors are typically sent as separate HTTP responses
// For now, we'll just log and continue
this.handleError(chunk.error)
break
// Ignore other event types
@ -303,11 +310,13 @@ export class AiSdkToAnthropicSSE {
this.state.textBlockIndex = null
}
private startThinkingBlock(): void {
if (this.state.thinkingBlockIndex !== null) return
private startThinkingBlock(reasoningId: string): void {
// Check if this thinking block already exists
if (this.state.thinkingBlocks.has(reasoningId)) return
const index = this.state.currentBlockIndex++
this.state.thinkingBlockIndex = index
this.state.thinkingBlocks.set(reasoningId, index)
this.state.currentThinkingId = reasoningId
this.state.blocks.set(index, {
type: 'thinking',
index,
@ -330,15 +339,25 @@ export class AiSdkToAnthropicSSE {
this.onEvent(event)
}
private emitThinkingDelta(text: string): void {
private emitThinkingDelta(text: string, reasoningId?: string): void {
if (!text) return
// Auto-start thinking block if not started
if (this.state.thinkingBlockIndex === null) {
this.startThinkingBlock()
// Determine which thinking block to use
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) {
// Auto-start thinking block if not started
const newId = `reasoning_${Date.now()}`
this.startThinkingBlock(newId)
return this.emitThinkingDelta(text, newId)
}
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) {
// If the block doesn't exist, create it
this.startThinkingBlock(targetId)
return this.emitThinkingDelta(text, targetId)
}
const index = this.state.thinkingBlockIndex!
const block = this.state.blocks.get(index)
if (block) {
block.content += text
@ -358,10 +377,12 @@ export class AiSdkToAnthropicSSE {
this.onEvent(event)
}
private stopThinkingBlock(): void {
if (this.state.thinkingBlockIndex === null) return
private stopThinkingBlock(reasoningId?: string): void {
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) return
const index = this.state.thinkingBlockIndex
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) return
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
@ -369,7 +390,14 @@ export class AiSdkToAnthropicSSE {
}
this.onEvent(event)
this.state.thinkingBlockIndex = null
this.state.thinkingBlocks.delete(targetId)
// Update currentThinkingId if we just closed the current one
if (this.state.currentThinkingId === targetId) {
// Set to the most recent remaining thinking block, or null if none
const remaining = Array.from(this.state.thinkingBlocks.keys())
this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null
}
}
private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void {
@ -471,13 +499,41 @@ export class AiSdkToAnthropicSSE {
}
}
/**
 * Report a provider-side error to the client.
 *
 * The error is logged, reduced to a single display string, and then emitted through
 * `emitTextDelta`, after every thinking block that is still tracked as open has been stopped.
 *
 * @param error - The value carried by the `error` chunk; its shape is not known up front.
 */
private handleError(error: unknown): void {
  logger.warn('AiSdkToAnthropicSSE - Provider error received:', { error })

  // Reduce the error to one display string: plain strings are used verbatim; for objects the
  // raw metadata payload takes precedence over `message`; everything else keeps the fallback.
  const resolveMessage = (): string => {
    const fallback = 'Unknown error from provider'
    if (typeof error === 'string') return error
    if (!error || typeof error !== 'object') return fallback

    const details = error as { message?: string; metadata?: { raw?: string } }
    if (details.metadata?.raw) return `Provider error: ${details.metadata.raw}`
    return details.message ? details.message : fallback
  }
  const errorMessage = resolveMessage()

  // Stop the open thinking blocks first so the text below is not interleaved with them.
  // The ids are snapshotted because stopping a block removes it from the map.
  Array.from(this.state.thinkingBlocks.keys()).forEach((openId) => {
    this.stopThinkingBlock(openId)
  })

  // Surface the error to the user as regular text.
  this.emitTextDelta(`\n\n[Error: ${errorMessage}]\n`)
}
private finalize(): void {
// Close any open blocks
if (this.state.textBlockIndex !== null) {
this.stopTextBlock()
}
if (this.state.thinkingBlockIndex !== null) {
this.stopThinkingBlock()
// Close all open thinking blocks
for (const reasoningId of this.state.thinkingBlocks.keys()) {
this.stopThinkingBlock(reasoningId)
}
// Emit message_delta with final stop reason and usage

View File

@ -1,6 +1,11 @@
import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type { ImageBlockParam, MessageCreateParams, TextBlockParam } from '@anthropic-ai/sdk/resources/messages'
import type {
ImageBlockParam,
MessageCreateParams,
TextBlockParam,
Tool as AnthropicTool
} from '@anthropic-ai/sdk/resources/messages'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { reduxService } from '@main/services/ReduxService'
@ -16,8 +21,8 @@ import {
} from '@shared/provider'
import { defaultAppHeaders } from '@shared/utils'
import type { Provider } from '@types'
import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart } from 'ai'
import { stepCountIs, streamText } from 'ai'
import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
import { jsonSchema, stepCountIs, streamText, tool } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
@ -190,6 +195,39 @@ IANA media type.
}
}
/**
 * Convert Anthropic tool definitions into the AI SDK `tools` record.
 *
 * Only client-defined ("custom") tools are converted, i.e. entries that carry an
 * `input_schema`. Anthropic built-in tools (bash, text editor, web search, ...) are
 * identified by a versioned `type` and have no JSON schema to forward, so they are skipped.
 *
 * @param tools - The `tools` array of an Anthropic `MessageCreateParams` request, if any.
 * @returns A record keyed by tool name, or `undefined` when there is nothing to convert.
 */
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, Tool> | undefined {
  if (!tools || tools.length === 0) {
    return undefined
  }

  const aiSdkTools: Record<string, Tool> = {}

  for (const anthropicTool of tools) {
    // Custom tools have `type` unset or 'custom'; any other value marks an
    // Anthropic-specific built-in tool that cannot be expressed as an AI SDK tool.
    const declaredType: string | null | undefined = anthropicTool.type
    if (declaredType != null && declaredType !== 'custom') {
      continue
    }

    // Without a schema there is nothing to hand to jsonSchema(); skip instead of
    // registering a tool with an undefined input schema.
    if (!('input_schema' in anthropicTool) || !anthropicTool.input_schema) {
      continue
    }

    const toolDef = anthropicTool as AnthropicTool
    const parameters = toolDef.input_schema as Parameters<typeof jsonSchema>[0]

    aiSdkTools[toolDef.name] = tool({
      description: toolDef.description || '',
      inputSchema: jsonSchema(parameters),
      // NOTE(review): this echoes the call input back as the tool result. Confirm whether
      // tool calls should instead be left for the API client to execute.
      execute: async (input: Record<string, unknown>) => {
        return input
      }
    })
  }

  return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
}
/**
* Convert Anthropic MessageCreateParams to AI SDK message format
*/
@ -271,6 +309,13 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
}
}
if (toolResultParts.length > 0) {
messages.push({
role: 'tool',
content: [...toolResultParts]
})
}
// Build the message based on role
if (msg.role === 'user') {
messages.push({
@ -278,13 +323,11 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
content: [...textParts, ...imageParts]
})
} else {
// Assistant messages can only have text
if (textParts.length > 0) {
messages.push({
role: 'assistant',
content: [...reasoningParts, ...textParts, ...toolCallParts, ...toolResultParts]
})
}
// Assistant messages contain tool calls, not tool results
messages.push({
role: 'assistant',
content: [...reasoningParts, ...textParts, ...toolCallParts]
})
}
}
}
@ -315,10 +358,29 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
const coreMessages = convertAnthropicToAiMessages(params)
// Convert tools if present
const tools = convertAnthropicToolsToAiSdk(params.tools)
logger.debug('Converted messages', {
originalCount: params.messages.length,
convertedCount: coreMessages.length,
hasSystem: !!params.system
hasSystem: !!params.system,
hasTools: !!tools,
toolCount: tools ? Object.keys(tools).length : 0,
toolNames: tools ? Object.keys(tools).slice(0, 10) : [],
paramsToolCount: params.tools?.length || 0
})
// Debug: Log message structure to understand tool_result handling
logger.silly('Message structure for debugging', {
messages: coreMessages.map((m) => ({
role: m.role,
contentTypes: Array.isArray(m.content)
? m.content.map((c: { type: string }) => c.type)
: typeof m.content === 'string'
? ['string']
: ['unknown']
}))
})
// Create the adapter
@ -340,6 +402,7 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
stopSequences: params.stop_sequences,
stopWhen: stepCountIs(100),
headers: defaultAppHeaders(),
tools,
providerOptions: {}
})
@ -404,8 +467,9 @@ export async function generateUnifiedMessage(
// Create language model (async - uses @cherrystudio/ai-core)
const model = await createLanguageModel(provider, modelId)
// Convert messages
// Convert messages and tools
const coreMessages = convertAnthropicToAiMessages(params)
const tools = convertAnthropicToolsToAiSdk(params.tools)
// Create adapter to collect the response
let finalResponse: ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse> | null = null
@ -425,6 +489,7 @@ export async function generateUnifiedMessage(
topP: params.top_p,
stopSequences: params.stop_sequences,
headers: defaultAppHeaders(),
tools,
stopWhen: stepCountIs(100)
})

View File

@ -193,6 +193,30 @@ function handleAssistantMessage(
}
break
}
case 'thinking':
case 'redacted_thinking': {
const thinkingText = block.type === 'thinking' ? block.thinking : block.data
if (thinkingText) {
const id = generateMessageId()
chunks.push({
type: 'reasoning-start',
id,
providerMetadata
})
chunks.push({
type: 'reasoning-delta',
id,
text: thinkingText,
providerMetadata
})
chunks.push({
type: 'reasoning-end',
id,
providerMetadata
})
}
break
}
case 'tool_use':
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
break
@ -445,7 +469,11 @@ function handleStreamEvent(
case 'content_block_stop': {
const block = state.closeBlock(event.index)
if (!block) {
logger.warn('Received content_block_stop for unknown index', { index: event.index })
// Some providers (e.g., Gemini) send content via assistant message before stream events,
// so the block may not exist in state. This is expected behavior, not an error.
logger.debug('Received content_block_stop for unknown index (may be from non-streaming content)', {
index: event.index
})
break
}