refactor: reorganize AiSdkToChunkAdapter and enhance tool call handling

- Moved AiSdkToChunkAdapter to a new directory structure for better organization.
- Implemented detailed handling for tool call events in ToolCallChunkHandler, including creation, updates, and completions.
- Added a new method to handle tool call creation and improved state management for active tool calls.
- Updated StreamProcessingService to support new chunk types and callbacks for block creation.
- Enhanced type definitions and added comments for clarity in the new chunk handling logic.
This commit is contained in:
MyPrototypeWhat 2025-07-17 16:30:26 +08:00
parent f38e4a87b8
commit 650650a68f
8 changed files with 116 additions and 23 deletions

View File

@ -7,7 +7,7 @@ import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
import { BaseTool, WebSearchResults, WebSearchSource } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { ToolCallChunkHandler } from './chunk/handleTooCallChunk'
import { ToolCallChunkHandler } from './handleTooCallChunk'
export interface CherryStudioChunk {
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
@ -56,7 +56,8 @@ export class AiSdkToChunkAdapter {
const final = {
text: '',
reasoningContent: '',
webSearchResults: []
webSearchResults: [],
reasoningId: ''
}
try {
while (true) {
@ -80,7 +81,7 @@ export class AiSdkToChunkAdapter {
*/
private convertAndEmitChunk(
chunk: TextStreamPart<any>,
final: { text: string; reasoningContent: string; webSearchResults: any[] }
final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string }
) {
console.log('AI SDK chunk type:', chunk.type, chunk)
switch (chunk.type) {
@ -96,6 +97,7 @@ export class AiSdkToChunkAdapter {
type: ChunkType.TEXT_DELTA,
text: final.text || ''
})
console.log('final.text', final.text)
break
case 'text-end':
this.onChunk({
@ -105,9 +107,12 @@ export class AiSdkToChunkAdapter {
final.text = ''
break
case 'reasoning-start':
this.onChunk({
type: ChunkType.THINKING_START
})
if (final.reasoningId !== chunk.id) {
final.reasoningId = chunk.id
this.onChunk({
type: ChunkType.THINKING_START
})
}
break
case 'reasoning':
this.onChunk({
@ -127,6 +132,16 @@ export class AiSdkToChunkAdapter {
break
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
case 'tool-input-start':
case 'tool-input-delta':
case 'tool-input-end':
this.toolCallHandler.handleToolCallCreated(chunk)
break
// case 'tool-input-delta':
// this.toolCallHandler.handleToolCallCreated(chunk)
// break
case 'tool-call':
// 原始的工具调用(未被中间件处理)
this.toolCallHandler.handleToolCall(chunk)
@ -136,17 +151,17 @@ export class AiSdkToChunkAdapter {
// 原始的工具调用结果(未被中间件处理)
this.toolCallHandler.handleToolResult(chunk)
break
// case 'start':
// this.onChunk({
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
// === 步骤相关事件 ===
case 'start':
this.onChunk({
type: ChunkType.LLM_RESPONSE_CREATED
})
break
// TODO: 需要区分接口开始和步骤开始
// case 'step-start':
// case 'start-step':
// this.onChunk({
// type: ChunkType.LLM_RESPONSE_CREATED
// type: ChunkType.BLOCK_CREATED
// })
// break
// case 'step-finish':
@ -178,8 +193,8 @@ export class AiSdkToChunkAdapter {
})
}
final.webSearchResults = []
// final.reasoningId = ''
break
// const { totalUsage, finishReason, providerMetadata } = chunk
}
case 'finish':
@ -256,8 +271,8 @@ export class AiSdkToChunkAdapter {
break
default:
// 其他类型的 chunk 可以忽略或记录日志
console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
// 其他类型的 chunk 可以忽略或记录日志
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
}
}
}

View File

@ -40,6 +40,69 @@ export class ToolCallChunkHandler {
// this.onChunk = callback
// }
handleToolCallCreated(chunk: { type: 'tool-input-start' | 'tool-input-delta' | 'tool-input-end' }): void {
switch (chunk.type) {
case 'tool-input-start': {
this.activeToolCalls.set(chunk.id, {
toolCallId: chunk.id,
toolName: chunk.toolName,
args: '',
mcpTool: {
id: chunk.id,
name: chunk.toolName,
description: chunk.toolName,
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
}
})
break
}
case 'tool-input-delta': {
const toolCall = this.activeToolCalls.get(chunk.id)
if (!toolCall) {
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
toolCall.args += chunk.delta
break
}
case 'tool-input-end': {
const toolCall = this.activeToolCalls.get(chunk.id)
this.activeToolCalls.delete(chunk.id)
if (!toolCall) {
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
const toolResponse: MCPToolResponse = {
id: toolCall.toolCallId,
tool: toolCall.mcpTool,
arguments: toolCall.args,
status: 'pending',
toolCallId: toolCall.toolCallId
}
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING,
responses: [toolResponse]
})
break
}
}
// if (!toolCall) {
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// return
// }
// this.onChunk({
// type: ChunkType.MCP_TOOL_CREATED,
// tool_calls: [
// {
// id: chunk.id,
// name: chunk.toolName,
// status: 'pending'
// }
// ]
// })
}
/**
*
*/

View File

@ -25,7 +25,7 @@ import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep } from 'lodash'
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './index'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsResult } from './middleware/schemas'

View File

@ -2,7 +2,7 @@ import type { RootState } from '@renderer/store'
import { messageBlocksSelectors } from '@renderer/store/messageBlock'
import type { ImageMessageBlock, MainTextMessageBlock, Message, MessageBlock } from '@renderer/types/newMessage'
import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
import { AnimatePresence, motion } from 'motion/react'
import { AnimatePresence, motion, type Variants } from 'motion/react'
import React, { useMemo } from 'react'
import { useSelector } from 'react-redux'
import styled from 'styled-components'
@ -22,7 +22,7 @@ interface AnimatedBlockWrapperProps {
enableAnimation: boolean
}
const blockWrapperVariants = {
const blockWrapperVariants: Variants = {
visible: {
opacity: 1,
x: 0,

View File

@ -16,7 +16,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import store from '@renderer/store'
import { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import { type Chunk } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import { SdkModel } from '@renderer/types/sdk'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
@ -435,7 +435,7 @@ export async function fetchChatCompletion({
mcpTools
}
// --- Call AI Completions ---
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
// onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
await AI.completions(modelId, aiSdkParams, middlewareConfig)
// if (enableWebSearch) {
// onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })

View File

@ -40,6 +40,8 @@ export interface StreamProcessorCallbacks {
onError?: (error: any) => void
// Called when the entire stream processing is signaled as complete (success or failure)
onComplete?: (status: AssistantMessageStatus, response?: Response) => void
// Called when a block is created
onBlockCreated?: () => void
}
// Function to create a stream processor instance
@ -133,9 +135,13 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {})
if (callbacks.onError) callbacks.onError(data.error)
break
}
case ChunkType.BLOCK_CREATED: {
if (callbacks.onBlockCreated) callbacks.onBlockCreated()
break
}
default: {
// Handle unknown chunk types or log an error
console.warn(`Unknown chunk type: ${data.type}`)
// console.warn(`Unknown chunk type: ${data.type}`)
}
}
} catch (error) {

View File

@ -76,7 +76,6 @@ export class BlockManager {
blockType: MessageBlockType,
isComplete: boolean = false
) {
console.log('smartBlockUpdate', blockId, changes, blockType, isComplete)
const isBlockTypeChanged = this._lastBlockType !== null && this._lastBlockType !== blockType
if (isBlockTypeChanged || isComplete) {
// 如果块类型改变,则取消上一个块的节流更新

View File

@ -66,6 +66,16 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
})
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
},
// onBlockCreated: async () => {
// if (blockManager.hasInitialPlaceholder) {
// return
// }
// console.log('onBlockCreated')
// const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
// status: MessageBlockStatus.PROCESSING
// })
// await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
// },
onError: async (error: any) => {
console.dir(error, { depth: null })