mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-31 00:10:22 +08:00
feat: Enhance thinking block management and tool conversion in unified messages
This commit is contained in:
parent
a5e7aa1342
commit
192357a32e
@ -59,7 +59,9 @@ interface AdapterState {
|
||||
currentBlockIndex: number
|
||||
blocks: Map<number, ContentBlockState>
|
||||
textBlockIndex: number | null
|
||||
thinkingBlockIndex: number | null
|
||||
// Track multiple thinking blocks by their reasoning ID
|
||||
thinkingBlocks: Map<string, number> // reasoningId -> blockIndex
|
||||
currentThinkingId: string | null // Currently active thinking block ID
|
||||
toolBlocks: Map<string, number> // toolCallId -> blockIndex
|
||||
stopReason: StopReason | null
|
||||
hasEmittedMessageStart: boolean
|
||||
@ -95,7 +97,8 @@ export class AiSdkToAnthropicSSE {
|
||||
currentBlockIndex: 0,
|
||||
blocks: new Map(),
|
||||
textBlockIndex: null,
|
||||
thinkingBlockIndex: null,
|
||||
thinkingBlocks: new Map(),
|
||||
currentThinkingId: null,
|
||||
toolBlocks: new Map(),
|
||||
stopReason: null,
|
||||
hasEmittedMessageStart: false
|
||||
@ -133,7 +136,7 @@ export class AiSdkToAnthropicSSE {
|
||||
* Process a single AI SDK chunk and emit corresponding Anthropic events
|
||||
*/
|
||||
private processChunk(chunk: TextStreamPart<ToolSet>): void {
|
||||
logger.silly('AiSdkToAnthropicSSE - Processing chunk:', chunk)
|
||||
logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
|
||||
switch (chunk.type) {
|
||||
// === Text Events ===
|
||||
case 'text-start':
|
||||
@ -149,17 +152,23 @@ export class AiSdkToAnthropicSSE {
|
||||
break
|
||||
|
||||
// === Reasoning/Thinking Events ===
|
||||
case 'reasoning-start':
|
||||
this.startThinkingBlock()
|
||||
case 'reasoning-start': {
|
||||
const reasoningId = (chunk as { id?: string }).id || `reasoning_${Date.now()}`
|
||||
this.startThinkingBlock(reasoningId)
|
||||
break
|
||||
}
|
||||
|
||||
case 'reasoning-delta':
|
||||
this.emitThinkingDelta(chunk.text || '')
|
||||
case 'reasoning-delta': {
|
||||
const reasoningId = (chunk as { id?: string }).id
|
||||
this.emitThinkingDelta(chunk.text || '', reasoningId)
|
||||
break
|
||||
}
|
||||
|
||||
case 'reasoning-end':
|
||||
this.stopThinkingBlock()
|
||||
case 'reasoning-end': {
|
||||
const reasoningId = (chunk as { id?: string }).id
|
||||
this.stopThinkingBlock(reasoningId)
|
||||
break
|
||||
}
|
||||
|
||||
// === Tool Events ===
|
||||
case 'tool-call':
|
||||
@ -190,9 +199,7 @@ export class AiSdkToAnthropicSSE {
|
||||
|
||||
// === Error Events ===
|
||||
case 'error':
|
||||
// Anthropic doesn't have a standard error event in the stream
|
||||
// Errors are typically sent as separate HTTP responses
|
||||
// For now, we'll just log and continue
|
||||
this.handleError(chunk.error)
|
||||
break
|
||||
|
||||
// Ignore other event types
|
||||
@ -303,11 +310,13 @@ export class AiSdkToAnthropicSSE {
|
||||
this.state.textBlockIndex = null
|
||||
}
|
||||
|
||||
private startThinkingBlock(): void {
|
||||
if (this.state.thinkingBlockIndex !== null) return
|
||||
private startThinkingBlock(reasoningId: string): void {
|
||||
// Check if this thinking block already exists
|
||||
if (this.state.thinkingBlocks.has(reasoningId)) return
|
||||
|
||||
const index = this.state.currentBlockIndex++
|
||||
this.state.thinkingBlockIndex = index
|
||||
this.state.thinkingBlocks.set(reasoningId, index)
|
||||
this.state.currentThinkingId = reasoningId
|
||||
this.state.blocks.set(index, {
|
||||
type: 'thinking',
|
||||
index,
|
||||
@ -330,15 +339,25 @@ export class AiSdkToAnthropicSSE {
|
||||
this.onEvent(event)
|
||||
}
|
||||
|
||||
private emitThinkingDelta(text: string): void {
|
||||
private emitThinkingDelta(text: string, reasoningId?: string): void {
|
||||
if (!text) return
|
||||
|
||||
// Auto-start thinking block if not started
|
||||
if (this.state.thinkingBlockIndex === null) {
|
||||
this.startThinkingBlock()
|
||||
// Determine which thinking block to use
|
||||
const targetId = reasoningId || this.state.currentThinkingId
|
||||
if (!targetId) {
|
||||
// Auto-start thinking block if not started
|
||||
const newId = `reasoning_${Date.now()}`
|
||||
this.startThinkingBlock(newId)
|
||||
return this.emitThinkingDelta(text, newId)
|
||||
}
|
||||
|
||||
const index = this.state.thinkingBlocks.get(targetId)
|
||||
if (index === undefined) {
|
||||
// If the block doesn't exist, create it
|
||||
this.startThinkingBlock(targetId)
|
||||
return this.emitThinkingDelta(text, targetId)
|
||||
}
|
||||
|
||||
const index = this.state.thinkingBlockIndex!
|
||||
const block = this.state.blocks.get(index)
|
||||
if (block) {
|
||||
block.content += text
|
||||
@ -358,10 +377,12 @@ export class AiSdkToAnthropicSSE {
|
||||
this.onEvent(event)
|
||||
}
|
||||
|
||||
private stopThinkingBlock(): void {
|
||||
if (this.state.thinkingBlockIndex === null) return
|
||||
private stopThinkingBlock(reasoningId?: string): void {
|
||||
const targetId = reasoningId || this.state.currentThinkingId
|
||||
if (!targetId) return
|
||||
|
||||
const index = this.state.thinkingBlockIndex
|
||||
const index = this.state.thinkingBlocks.get(targetId)
|
||||
if (index === undefined) return
|
||||
|
||||
const event: RawContentBlockStopEvent = {
|
||||
type: 'content_block_stop',
|
||||
@ -369,7 +390,14 @@ export class AiSdkToAnthropicSSE {
|
||||
}
|
||||
|
||||
this.onEvent(event)
|
||||
this.state.thinkingBlockIndex = null
|
||||
this.state.thinkingBlocks.delete(targetId)
|
||||
|
||||
// Update currentThinkingId if we just closed the current one
|
||||
if (this.state.currentThinkingId === targetId) {
|
||||
// Set to the most recent remaining thinking block, or null if none
|
||||
const remaining = Array.from(this.state.thinkingBlocks.keys())
|
||||
this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null
|
||||
}
|
||||
}
|
||||
|
||||
private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void {
|
||||
@ -471,13 +499,41 @@ export class AiSdkToAnthropicSSE {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Handle an `error` chunk coming from the provider stream.
 *
 * The error is logged, reduced to a single human-readable message, and then
 * surfaced to the client as ordinary text so that it is visible in the
 * response instead of being dropped.
 *
 * Message selection:
 * - object with `metadata.raw` -> `Provider error: <raw>`
 * - object with `message`      -> that message
 * - string                     -> the string itself
 * - anything else              -> a generic fallback
 *
 * @param error - The raw error value carried by the chunk.
 */
private handleError(error: unknown): void {
  logger.warn('AiSdkToAnthropicSSE - Provider error received:', { error })

  const describe = (): string => {
    if (typeof error === 'string') {
      return error
    }
    if (!error || typeof error !== 'object') {
      return 'Unknown error from provider'
    }

    const details = error as { message?: string; metadata?: { raw?: string } }
    // The raw provider payload wins over the generic message when both exist
    if (details.metadata?.raw) {
      return `Provider error: ${details.metadata.raw}`
    }
    if (details.message) {
      return details.message
    }
    return 'Unknown error from provider'
  }
  const errorMessage = describe()

  // Close every open thinking block before emitting text, to keep event order valid.
  // Iterate over a snapshot because stopThinkingBlock removes entries from the map.
  const openThinkingIds = Array.from(this.state.thinkingBlocks.keys())
  openThinkingIds.forEach((reasoningId) => {
    this.stopThinkingBlock(reasoningId)
  })

  // Surface the error to the user as a text delta
  this.emitTextDelta(`\n\n[Error: ${errorMessage}]\n`)
}
|
||||
|
||||
private finalize(): void {
|
||||
// Close any open blocks
|
||||
if (this.state.textBlockIndex !== null) {
|
||||
this.stopTextBlock()
|
||||
}
|
||||
if (this.state.thinkingBlockIndex !== null) {
|
||||
this.stopThinkingBlock()
|
||||
// Close all open thinking blocks
|
||||
for (const reasoningId of this.state.thinkingBlocks.keys()) {
|
||||
this.stopThinkingBlock(reasoningId)
|
||||
}
|
||||
|
||||
// Emit message_delta with final stop reason and usage
|
||||
|
||||
@ -1,6 +1,11 @@
|
||||
import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
|
||||
import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
|
||||
import type { ImageBlockParam, MessageCreateParams, TextBlockParam } from '@anthropic-ai/sdk/resources/messages'
|
||||
import type {
|
||||
ImageBlockParam,
|
||||
MessageCreateParams,
|
||||
TextBlockParam,
|
||||
Tool as AnthropicTool
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { reduxService } from '@main/services/ReduxService'
|
||||
@ -16,8 +21,8 @@ import {
|
||||
} from '@shared/provider'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import type { Provider } from '@types'
|
||||
import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart } from 'ai'
|
||||
import { stepCountIs, streamText } from 'ai'
|
||||
import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
|
||||
import { jsonSchema, stepCountIs, streamText, tool } from 'ai'
|
||||
import { net } from 'electron'
|
||||
import type { Response } from 'express'
|
||||
|
||||
@ -190,6 +195,39 @@ IANA media type.
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Convert Anthropic tool definitions into the AI SDK tool map.
 *
 * Only client-defined tools (those with `type: 'custom'` or no `type` at all)
 * are converted, because they are the only ones that carry an `input_schema`.
 * Anthropic built-in tools (bash, text editor, computer use, web search, ...)
 * are Anthropic-specific and are skipped.
 *
 * @param tools - The `tools` array from the Anthropic request, if any.
 * @returns A map keyed by tool name, or `undefined` when there is nothing to
 *          convert (no tools given, or every tool was skipped).
 */
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, Tool> | undefined {
  if (!tools || tools.length === 0) {
    return undefined
  }

  const aiSdkTools: Record<string, Tool> = {}

  for (const anthropicTool of tools) {
    // Skip every built-in tool type, not just one bash version: they have no
    // input_schema, so converting them would register a tool with an undefined schema.
    const toolType = (anthropicTool as { type?: string | null }).type
    if (toolType != null && toolType !== 'custom') {
      continue
    }

    // Regular tool (type === 'custom' or no type)
    const toolDef = anthropicTool as AnthropicTool
    if (!toolDef.input_schema) {
      // Defensive: a custom tool without a schema cannot be described to the model
      continue
    }
    const parameters = toolDef.input_schema as Parameters<typeof jsonSchema>[0]

    aiSdkTools[toolDef.name] = tool({
      description: toolDef.description ?? '',
      inputSchema: jsonSchema(parameters),
      // The executor simply echoes the call input back as the tool result
      execute: async (input: Record<string, unknown>) => {
        return input
      }
    })
  }

  return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
}
|
||||
|
||||
/**
|
||||
* Convert Anthropic MessageCreateParams to AI SDK message format
|
||||
*/
|
||||
@ -271,6 +309,13 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
|
||||
}
|
||||
}
|
||||
|
||||
if (toolResultParts.length > 0) {
|
||||
messages.push({
|
||||
role: 'tool',
|
||||
content: [...toolResultParts]
|
||||
})
|
||||
}
|
||||
|
||||
// Build the message based on role
|
||||
if (msg.role === 'user') {
|
||||
messages.push({
|
||||
@ -278,13 +323,11 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
|
||||
content: [...textParts, ...imageParts]
|
||||
})
|
||||
} else {
|
||||
// Assistant messages can only have text
|
||||
if (textParts.length > 0) {
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content: [...reasoningParts, ...textParts, ...toolCallParts, ...toolResultParts]
|
||||
})
|
||||
}
|
||||
// Assistant messages contain tool calls, not tool results
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content: [...reasoningParts, ...textParts, ...toolCallParts]
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -315,10 +358,29 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
|
||||
|
||||
const coreMessages = convertAnthropicToAiMessages(params)
|
||||
|
||||
// Convert tools if present
|
||||
const tools = convertAnthropicToolsToAiSdk(params.tools)
|
||||
|
||||
logger.debug('Converted messages', {
|
||||
originalCount: params.messages.length,
|
||||
convertedCount: coreMessages.length,
|
||||
hasSystem: !!params.system
|
||||
hasSystem: !!params.system,
|
||||
hasTools: !!tools,
|
||||
toolCount: tools ? Object.keys(tools).length : 0,
|
||||
toolNames: tools ? Object.keys(tools).slice(0, 10) : [],
|
||||
paramsToolCount: params.tools?.length || 0
|
||||
})
|
||||
|
||||
// Debug: Log message structure to understand tool_result handling
|
||||
logger.silly('Message structure for debugging', {
|
||||
messages: coreMessages.map((m) => ({
|
||||
role: m.role,
|
||||
contentTypes: Array.isArray(m.content)
|
||||
? m.content.map((c: { type: string }) => c.type)
|
||||
: typeof m.content === 'string'
|
||||
? ['string']
|
||||
: ['unknown']
|
||||
}))
|
||||
})
|
||||
|
||||
// Create the adapter
|
||||
@ -340,6 +402,7 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
|
||||
stopSequences: params.stop_sequences,
|
||||
stopWhen: stepCountIs(100),
|
||||
headers: defaultAppHeaders(),
|
||||
tools,
|
||||
providerOptions: {}
|
||||
})
|
||||
|
||||
@ -404,8 +467,9 @@ export async function generateUnifiedMessage(
|
||||
// Create language model (async - uses @cherrystudio/ai-core)
|
||||
const model = await createLanguageModel(provider, modelId)
|
||||
|
||||
// Convert messages
|
||||
// Convert messages and tools
|
||||
const coreMessages = convertAnthropicToAiMessages(params)
|
||||
const tools = convertAnthropicToolsToAiSdk(params.tools)
|
||||
|
||||
// Create adapter to collect the response
|
||||
let finalResponse: ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse> | null = null
|
||||
@ -425,6 +489,7 @@ export async function generateUnifiedMessage(
|
||||
topP: params.top_p,
|
||||
stopSequences: params.stop_sequences,
|
||||
headers: defaultAppHeaders(),
|
||||
tools,
|
||||
stopWhen: stepCountIs(100)
|
||||
})
|
||||
|
||||
|
||||
@ -193,6 +193,30 @@ function handleAssistantMessage(
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'thinking':
|
||||
case 'redacted_thinking': {
|
||||
const thinkingText = block.type === 'thinking' ? block.thinking : block.data
|
||||
if (thinkingText) {
|
||||
const id = generateMessageId()
|
||||
chunks.push({
|
||||
type: 'reasoning-start',
|
||||
id,
|
||||
providerMetadata
|
||||
})
|
||||
chunks.push({
|
||||
type: 'reasoning-delta',
|
||||
id,
|
||||
text: thinkingText,
|
||||
providerMetadata
|
||||
})
|
||||
chunks.push({
|
||||
type: 'reasoning-end',
|
||||
id,
|
||||
providerMetadata
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'tool_use':
|
||||
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
|
||||
break
|
||||
@ -445,7 +469,11 @@ function handleStreamEvent(
|
||||
case 'content_block_stop': {
|
||||
const block = state.closeBlock(event.index)
|
||||
if (!block) {
|
||||
logger.warn('Received content_block_stop for unknown index', { index: event.index })
|
||||
// Some providers (e.g., Gemini) send content via assistant message before stream events,
|
||||
// so the block may not exist in state. This is expected behavior, not an error.
|
||||
logger.debug('Received content_block_stop for unknown index (may be from non-streaming content)', {
|
||||
index: event.index
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user