mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
refactor: reorganize AiSdkToChunkAdapter and enhance tool call handling
- Moved AiSdkToChunkAdapter to a new directory structure for better organization. - Implemented detailed handling for tool call events in ToolCallChunkHandler, including creation, updates, and completions. - Added a new method to handle tool call creation and improved state management for active tool calls. - Updated StreamProcessingService to support new chunk types and callbacks for block creation. - Enhanced type definitions and added comments for clarity in the new chunk handling logic.
This commit is contained in:
parent
f38e4a87b8
commit
650650a68f
@ -7,7 +7,7 @@ import { TextStreamPart, ToolSet } from '@cherrystudio/ai-core'
|
||||
import { BaseTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import { ToolCallChunkHandler } from './chunk/handleTooCallChunk'
|
||||
import { ToolCallChunkHandler } from './handleTooCallChunk'
|
||||
|
||||
export interface CherryStudioChunk {
|
||||
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
|
||||
@ -56,7 +56,8 @@ export class AiSdkToChunkAdapter {
|
||||
const final = {
|
||||
text: '',
|
||||
reasoningContent: '',
|
||||
webSearchResults: []
|
||||
webSearchResults: [],
|
||||
reasoningId: ''
|
||||
}
|
||||
try {
|
||||
while (true) {
|
||||
@ -80,7 +81,7 @@ export class AiSdkToChunkAdapter {
|
||||
*/
|
||||
private convertAndEmitChunk(
|
||||
chunk: TextStreamPart<any>,
|
||||
final: { text: string; reasoningContent: string; webSearchResults: any[] }
|
||||
final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string }
|
||||
) {
|
||||
console.log('AI SDK chunk type:', chunk.type, chunk)
|
||||
switch (chunk.type) {
|
||||
@ -96,6 +97,7 @@ export class AiSdkToChunkAdapter {
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: final.text || ''
|
||||
})
|
||||
console.log('final.text', final.text)
|
||||
break
|
||||
case 'text-end':
|
||||
this.onChunk({
|
||||
@ -105,9 +107,12 @@ export class AiSdkToChunkAdapter {
|
||||
final.text = ''
|
||||
break
|
||||
case 'reasoning-start':
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_START
|
||||
})
|
||||
if (final.reasoningId !== chunk.id) {
|
||||
final.reasoningId = chunk.id
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_START
|
||||
})
|
||||
}
|
||||
break
|
||||
case 'reasoning':
|
||||
this.onChunk({
|
||||
@ -127,6 +132,16 @@ export class AiSdkToChunkAdapter {
|
||||
break
|
||||
|
||||
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
|
||||
|
||||
case 'tool-input-start':
|
||||
case 'tool-input-delta':
|
||||
case 'tool-input-end':
|
||||
this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
break
|
||||
|
||||
// case 'tool-input-delta':
|
||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
// break
|
||||
case 'tool-call':
|
||||
// 原始的工具调用(未被中间件处理)
|
||||
this.toolCallHandler.handleToolCall(chunk)
|
||||
@ -136,17 +151,17 @@ export class AiSdkToChunkAdapter {
|
||||
// 原始的工具调用结果(未被中间件处理)
|
||||
this.toolCallHandler.handleToolResult(chunk)
|
||||
break
|
||||
// case 'start':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
|
||||
// === 步骤相关事件 ===
|
||||
case 'start':
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_CREATED
|
||||
})
|
||||
break
|
||||
// TODO: 需要区分接口开始和步骤开始
|
||||
// case 'step-start':
|
||||
// case 'start-step':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// type: ChunkType.BLOCK_CREATED
|
||||
// })
|
||||
// break
|
||||
// case 'step-finish':
|
||||
@ -178,8 +193,8 @@ export class AiSdkToChunkAdapter {
|
||||
})
|
||||
}
|
||||
final.webSearchResults = []
|
||||
// final.reasoningId = ''
|
||||
break
|
||||
// const { totalUsage, finishReason, providerMetadata } = chunk
|
||||
}
|
||||
|
||||
case 'finish':
|
||||
@ -256,8 +271,8 @@ export class AiSdkToChunkAdapter {
|
||||
break
|
||||
|
||||
default:
|
||||
// 其他类型的 chunk 可以忽略或记录日志
|
||||
console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
|
||||
// 其他类型的 chunk 可以忽略或记录日志
|
||||
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -40,6 +40,69 @@ export class ToolCallChunkHandler {
|
||||
// this.onChunk = callback
|
||||
// }
|
||||
|
||||
handleToolCallCreated(chunk: { type: 'tool-input-start' | 'tool-input-delta' | 'tool-input-end' }): void {
|
||||
switch (chunk.type) {
|
||||
case 'tool-input-start': {
|
||||
this.activeToolCalls.set(chunk.id, {
|
||||
toolCallId: chunk.id,
|
||||
toolName: chunk.toolName,
|
||||
args: '',
|
||||
mcpTool: {
|
||||
id: chunk.id,
|
||||
name: chunk.toolName,
|
||||
description: chunk.toolName,
|
||||
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'tool-input-delta': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
if (!toolCall) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
toolCall.args += chunk.delta
|
||||
break
|
||||
}
|
||||
case 'tool-input-end': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
this.activeToolCalls.delete(chunk.id)
|
||||
if (!toolCall) {
|
||||
Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: toolCall.toolCallId,
|
||||
tool: toolCall.mcpTool,
|
||||
arguments: toolCall.args,
|
||||
status: 'pending',
|
||||
toolCallId: toolCall.toolCallId
|
||||
}
|
||||
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
// if (!toolCall) {
|
||||
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// return
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_CREATED,
|
||||
// tool_calls: [
|
||||
// {
|
||||
// id: chunk.id,
|
||||
// name: chunk.toolName,
|
||||
// status: 'pending'
|
||||
// }
|
||||
// ]
|
||||
// })
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具调用事件
|
||||
*/
|
||||
|
||||
@ -25,7 +25,7 @@ import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep } from 'lodash'
|
||||
|
||||
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './index'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsResult } from './middleware/schemas'
|
||||
|
||||
@ -2,7 +2,7 @@ import type { RootState } from '@renderer/store'
|
||||
import { messageBlocksSelectors } from '@renderer/store/messageBlock'
|
||||
import type { ImageMessageBlock, MainTextMessageBlock, Message, MessageBlock } from '@renderer/types/newMessage'
|
||||
import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
import { AnimatePresence, motion } from 'motion/react'
|
||||
import { AnimatePresence, motion, type Variants } from 'motion/react'
|
||||
import React, { useMemo } from 'react'
|
||||
import { useSelector } from 'react-redux'
|
||||
import styled from 'styled-components'
|
||||
@ -22,7 +22,7 @@ interface AnimatedBlockWrapperProps {
|
||||
enableAnimation: boolean
|
||||
}
|
||||
|
||||
const blockWrapperVariants = {
|
||||
const blockWrapperVariants: Variants = {
|
||||
visible: {
|
||||
opacity: 1,
|
||||
x: 0,
|
||||
|
||||
@ -16,7 +16,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import store from '@renderer/store'
|
||||
import { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
|
||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { type Chunk } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import { SdkModel } from '@renderer/types/sdk'
|
||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||
@ -435,7 +435,7 @@ export async function fetchChatCompletion({
|
||||
mcpTools
|
||||
}
|
||||
// --- Call AI Completions ---
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
// onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
await AI.completions(modelId, aiSdkParams, middlewareConfig)
|
||||
// if (enableWebSearch) {
|
||||
// onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
|
||||
|
||||
@ -40,6 +40,8 @@ export interface StreamProcessorCallbacks {
|
||||
onError?: (error: any) => void
|
||||
// Called when the entire stream processing is signaled as complete (success or failure)
|
||||
onComplete?: (status: AssistantMessageStatus, response?: Response) => void
|
||||
// Called when a block is created
|
||||
onBlockCreated?: () => void
|
||||
}
|
||||
|
||||
// Function to create a stream processor instance
|
||||
@ -133,9 +135,13 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {})
|
||||
if (callbacks.onError) callbacks.onError(data.error)
|
||||
break
|
||||
}
|
||||
case ChunkType.BLOCK_CREATED: {
|
||||
if (callbacks.onBlockCreated) callbacks.onBlockCreated()
|
||||
break
|
||||
}
|
||||
default: {
|
||||
// Handle unknown chunk types or log an error
|
||||
console.warn(`Unknown chunk type: ${data.type}`)
|
||||
// console.warn(`Unknown chunk type: ${data.type}`)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
|
||||
@ -76,7 +76,6 @@ export class BlockManager {
|
||||
blockType: MessageBlockType,
|
||||
isComplete: boolean = false
|
||||
) {
|
||||
console.log('smartBlockUpdate', blockId, changes, blockType, isComplete)
|
||||
const isBlockTypeChanged = this._lastBlockType !== null && this._lastBlockType !== blockType
|
||||
if (isBlockTypeChanged || isComplete) {
|
||||
// 如果块类型改变,则取消上一个块的节流更新
|
||||
|
||||
@ -66,6 +66,16 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
})
|
||||
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
||||
},
|
||||
// onBlockCreated: async () => {
|
||||
// if (blockManager.hasInitialPlaceholder) {
|
||||
// return
|
||||
// }
|
||||
// console.log('onBlockCreated')
|
||||
// const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
|
||||
// status: MessageBlockStatus.PROCESSING
|
||||
// })
|
||||
// await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
||||
// },
|
||||
|
||||
onError: async (error: any) => {
|
||||
console.dir(error, { depth: null })
|
||||
|
||||
Loading…
Reference in New Issue
Block a user