From 76de357cbfe96bbb2145004131717fca90cee9ea Mon Sep 17 00:00:00 2001 From: SuYao Date: Tue, 15 Jul 2025 15:39:40 +0800 Subject: [PATCH] test: add integration test for message thunk and fix some bugs (#8148) * test: add integration test for message thunk and fix some bugs * fix: ci --- src/renderer/src/aiCore/index.ts | 4 + .../middleware/core/TextChunkMiddleware.ts | 11 +- .../middleware/core/ThinkChunkMiddleware.ts | 1 + src/renderer/src/aiCore/middleware/schemas.ts | 2 +- .../src/services/__tests__/ApiService.test.ts | 985 +++++++++++++++ src/renderer/src/store/messageBlock.ts | 2 +- src/renderer/src/store/newMessage.ts | 2 +- .../streamCallback.integration.test.ts | 766 ++++++++++++ src/renderer/src/store/thunk/messageThunk.ts | 1080 ++++++++--------- 9 files changed, 2300 insertions(+), 553 deletions(-) create mode 100644 src/renderer/src/services/__tests__/ApiService.test.ts create mode 100644 src/renderer/src/store/thunk/__tests__/streamCallback.integration.test.ts diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts index 51cb84df15..f9caa80f81 100644 --- a/src/renderer/src/aiCore/index.ts +++ b/src/renderer/src/aiCore/index.ts @@ -99,6 +99,10 @@ export default class AiProvider { if (params.callType !== 'chat') { builder.remove(AbortHandlerMiddlewareName) } + if (params.callType === 'test') { + builder.remove(ErrorHandlerMiddlewareName) + builder.remove(FinalChunkConsumerMiddlewareName) + } } const middlewares = builder.build() diff --git a/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts b/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts index 0affc6b382..db9fdb253a 100644 --- a/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts +++ b/src/renderer/src/aiCore/middleware/core/TextChunkMiddleware.ts @@ -1,5 +1,5 @@ import Logger from '@renderer/config/logger' -import { ChunkType, TextDeltaChunk } from '@renderer/types/chunk' +import { ChunkType } from '@renderer/types/chunk' import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas' import { CompletionsContext, CompletionsMiddleware } from '../types' @@ -42,16 +42,17 @@ export const TextChunkMiddleware: CompletionsMiddleware = new TransformStream({ transform(chunk: GenericChunk, controller) { if (chunk.type === ChunkType.TEXT_DELTA) { - const textChunk = chunk as TextDeltaChunk - accumulatedTextContent += textChunk.text + accumulatedTextContent += chunk.text // 处理 onResponse 回调 - 发送增量文本更新 if (params.onResponse) { params.onResponse(accumulatedTextContent, false) } - // 创建新的chunk,包含处理后的文本 - controller.enqueue(chunk) + controller.enqueue({ + ...chunk, + text: accumulatedTextContent // 增量更新 + }) } else if (accumulatedTextContent && chunk.type !== ChunkType.TEXT_START) { if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) { const finalText = accumulatedTextContent diff --git a/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts b/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts index 22eaabe96d..957b925400 100644 --- a/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts +++ b/src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts @@ -62,6 +62,7 @@ export const ThinkChunkMiddleware: CompletionsMiddleware = // 更新思考时间并传递 const enhancedChunk: ThinkingDeltaChunk = { ...thinkingChunk, + text: accumulatedThinkingContent, thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 } controller.enqueue(enhancedChunk) diff --git a/src/renderer/src/aiCore/middleware/schemas.ts b/src/renderer/src/aiCore/middleware/schemas.ts index 75247443c2..2e60214625 100644 --- a/src/renderer/src/aiCore/middleware/schemas.ts +++ b/src/renderer/src/aiCore/middleware/schemas.ts @@ -23,7 +23,7 @@ export interface CompletionsParams { * 'generate': 生成 * 'check': API检查 */ - callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check' + callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check' | 'test' // 基础对话数据 messages: Message[] | string // 联合类型方便判断是否为空 diff --git a/src/renderer/src/services/__tests__/ApiService.test.ts b/src/renderer/src/services/__tests__/ApiService.test.ts new file mode 100644 index 0000000000..b88d0d0775 --- /dev/null +++ b/src/renderer/src/services/__tests__/ApiService.test.ts @@ -0,0 +1,985 @@ +import { FinishReason, MediaModality } from '@google/genai' +import { FunctionCall } from '@google/genai' +import AiProvider from '@renderer/aiCore' +import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' +import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' +import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient' +import { GenericChunk } from '@renderer/aiCore/middleware/schemas' +import { Assistant, Provider, WebSearchSource } from '@renderer/types' +import { + ChunkType, + LLMResponseCompleteChunk, + LLMWebSearchCompleteChunk, + TextDeltaChunk, + TextStartChunk, + ThinkingStartChunk +} from '@renderer/types/chunk' +import { GeminiSdkRawChunk } from '@renderer/types/sdk' +import { cloneDeep } from 'lodash' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +// Mock the ApiClientFactory +vi.mock('@renderer/aiCore/clients/ApiClientFactory', () => ({ + ApiClientFactory: { + create: vi.fn() + } +})) + +// Mock the models config +vi.mock('@renderer/config/models', () => ({ + isDedicatedImageGenerationModel: vi.fn(() => false), + isTextToImageModel: vi.fn(() => false), + isEmbeddingModel: vi.fn(() => false), + isRerankModel: vi.fn(() => false), + isVisionModel: vi.fn(() => false), + isReasoningModel: vi.fn(() => false), + isWebSearchModel: vi.fn(() => false), + isOpenAIModel: vi.fn(() => false), + isFunctionCallingModel: vi.fn(() => true), + models: { + gemini: { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro' + } + } +})) + +// Mock uuid +vi.mock('uuid', () => ({ + v4: vi.fn(() => 'mock-uuid') +})) + +// Mock other necessary modules +vi.mock('@renderer/services/AssistantService', () => ({ + getProviderByModel: vi.fn(() => ({ id: 'gemini', name: 'Gemini' })), + getDefaultAssistant: vi.fn(() => ({ + id: 'mock-assistant', + name: 'Mock Assistant', + model: { id: 'gemini-2.5-pro', name: 'Gemini 2.5 Pro' } + })), + getDefaultTopic: vi.fn(() => ({ + id: 'mock-topic', + assistantId: 'mock-assistant', + createdAt: new Date().toISOString(), + messages: [] + })) +})) + +vi.mock('@renderer/utils', () => ({ + getLowerBaseModelName: vi.fn((name) => name.toLowerCase()) +})) + +vi.mock('@renderer/config/prompts', () => ({ + WEB_SEARCH_PROMPT_FOR_OPENROUTER: 'mock-prompt' +})) + +vi.mock('@renderer/config/systemModels', () => ({ + GENERATE_IMAGE_MODELS: [], + SUPPORTED_DISABLE_GENERATION_MODELS: [] +})) + +vi.mock('@renderer/config/tools', () => ({ + getWebSearchTools: vi.fn(() => []) +})) + +// Mock store modules +vi.mock('@renderer/store/assistants', () => ({ + default: (state = { assistants: [] }) => state +})) + +vi.mock('@renderer/store/agents', () => ({ + default: (state = { agents: [] }) => state +})) + +vi.mock('@renderer/store/backup', () => ({ + default: (state = { backups: [] }) => state +})) + +vi.mock('@renderer/store/chat', () => ({ + default: (state = { messages: [] }) => state +})) + +vi.mock('@renderer/store/files', () => ({ + default: (state = { files: [] }) => state +})) + +vi.mock('@renderer/store/knowledge', () => ({ + default: (state = { knowledge: [] }) => state +})) + +vi.mock('@renderer/store/paintings', () => ({ + default: (state = { paintings: [] }) => state +})) + +vi.mock('@renderer/store/runtime', () => ({ + default: (state = { runtime: {} }) => state +})) + +vi.mock('@renderer/store/settings', () => ({ + default: (state = { settings: {} }) => state +})) + +vi.mock('@renderer/store/topics', () => ({ + default: (state = { topics: [] }) => state +})) + +vi.mock('@renderer/store/translate', () => ({ + default: (state = { translate: {} }) => state +})) + +vi.mock('@renderer/store/websearch', () => ({ + default: (state = { websearch: {} }) => state +})) + +vi.mock('@renderer/store/migrate', () => ({ + default: vi.fn().mockResolvedValue(undefined) +})) + +// Mock the llm store with a proper reducer function +vi.mock('@renderer/store/llm.ts', () => { + const mockInitialState = { + providers: [ + { + id: 'gemini', + name: 'Gemini', + type: 'gemini', + apiKey: 'mock-api-key', + apiHost: 'mock-api-host', + models: [ + { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro', + provider: 'gemini' + } + ], + isSystem: true, + enabled: true + } + ], + defaultModel: { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro', + provider: 'gemini' + }, + topicNamingModel: { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro', + provider: 'gemini' + }, + translateModel: { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro', + provider: 'gemini' + }, + quickAssistantId: '', + settings: { + ollama: { keepAliveTime: 0 }, + lmstudio: { keepAliveTime: 0 }, + gpustack: { keepAliveTime: 0 }, + vertexai: { + serviceAccount: { + privateKey: '', + clientEmail: '' + }, + projectId: '', + location: '' + } + } + } + + const mockReducer = (state = mockInitialState) => { + return state + } + + return { + default: mockReducer, + initialState: mockInitialState + } +}) + +// 测试用例:将 Gemini API 响应数据转换为 geminiChunks 数组 +const geminiChunks: GeminiSdkRawChunk[] = [ + { + candidates: [ + { + content: { + parts: [{ text: 'Hi, 1212312312' }], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 383, + candidatesTokenCount: 5, + totalTokenCount: 1157, + promptTokensDetails: [ + { + modality: MediaModality.TEXT, + tokenCount: 383 + } + ], + thoughtsTokenCount: 769 + }, + modelVersion: 'gemini-2.5-pro', + responseId: 'C75waL7rNsPRjrEP3MS5-A8' + } as GeminiSdkRawChunk, + + // 第二个 chunk - 中间响应 + { + candidates: [ + { + content: { + parts: [{ text: '!\n\n我是 Gemini 2.5 Pro,很高兴能为您服务。\n\n今天有什么可以帮您的吗?无论您是' }], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 383, + candidatesTokenCount: 32, + totalTokenCount: 1184, + promptTokensDetails: [ + { + modality: MediaModality.TEXT, + tokenCount: 383 + } + ], + thoughtsTokenCount: 769 + }, + modelVersion: 'gemini-2.5-pro', + responseId: 'C75waL7rNsPRjrEP3MS5-A8' + } as GeminiSdkRawChunk, + + // 第三个 chunk - 结束响应 + { + candidates: [ + { + content: { + parts: [{ text: '想寻找信息、进行创作,还是有任何其他问题,我都在这里准备好提供帮助。' }], + role: 'model' + }, + finishReason: FinishReason.STOP, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 383, + candidatesTokenCount: 53, + totalTokenCount: 1205, + promptTokensDetails: [ + { + modality: MediaModality.TEXT, + tokenCount: 383 + } + ], + thoughtsTokenCount: 769 + }, + modelVersion: 'gemini-2.5-pro', + responseId: 'C75waL7rNsPRjrEP3MS5-A8' + } as GeminiSdkRawChunk +] + +const geminiThinkingChunks: GeminiSdkRawChunk[] = [ + { + candidates: [ + { + content: { + parts: [ + { + text: `**Analyzing Core Functionality**\n\nI've identified the core query: "What can I do?" expressed in Chinese. Recognizing my nature as a Google-trained language model is the foundation for a relevant response. This fundamental understanding guides the development of an effective answer.\n\n\n`, + thought: true + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + candidatesTokenCount: 1, + totalTokenCount: 1020, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 69 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [ + { + text: `**Formulating the Chinese Response**\n\nI'm now drafting a Chinese response. I've moved past the initial simple sentence and am incorporating more detail. The goal is a clear, concise list of my key capabilities, tailored for a user asking about my function. I'm focusing on "understanding and generating text," "answering questions," and "translating languages" for now. Refining the exact phrasing for optimal clarity is an ongoing focus.\n\n\n`, + thought: true + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + candidatesTokenCount: 1, + totalTokenCount: 1020, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 318 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [ + { + text: `**Categorizing My Abilities**\n\nI'm organizing my thoughts to classify the capabilities in my response. I'm grouping functions like text generation and question answering, then differentiating them from specialized features, such as translation and creative writing. Considering the best structure for clarity and comprehensiveness, I am refining the outline for my reply. I'm aiming for concise categories that clearly illustrate the range of my functionality in the response. I'm adding an optional explanation of my training to enrich the overall response.\n\n\n`, + thought: true + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + totalTokenCount: 820, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 826 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [ + { + text: `**Developing the Chinese Draft**\n\nI'm now iterating on the final Chinese response. I've refined the categories to highlight my versatility. I'm focusing on "understanding and generating text" and "answering questions", and adding a section on how I can perform creative writing tasks. I'm aiming for concise explanations for clarity. I will also include a call to action at the end. I'm considering adding an optional sentence describing how I learned the data I know.\n\n\n`, + thought: true + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + totalTokenCount: 1019, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 1013 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [{ text: '我是一个大型语言模型,' }], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + candidatesTokenCount: 1, + totalTokenCount: 1020, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 1019 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [ + { + text: '由 Google 训练。\n\n我的能力主要包括:\n\n1. **理解和生成文本:** 我可以阅读、理解并创作各种形式的文本,包括文章、故事、对话、代码等。\n2. ' + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + candidatesTokenCount: 48, + totalTokenCount: 1067, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 1019 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [ + { + text: '**回答问题:** 基于我所学习到的信息,我可以回答你提出的各种问题,无论是事实性的、概念性的还是需要解释的。\n3. **语言翻译:** 我可以翻译多种语言之间的文本。\n' + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + candidatesTokenCount: 95, + totalTokenCount: 1109, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 1013 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [ + { + text: '4. **信息总结:** 我可以阅读长篇文本并提炼出关键信息或进行总结。\n5. **创意写作:** 我可以帮助你创作诗歌、代码、剧本、音乐作品' + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + candidatesTokenCount: 143, + totalTokenCount: 1162, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 1109 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [ + { + text: '、电子邮件、信件等各种创意内容。\n6. **解释概念:** 我可以解释复杂的术语、概念或主题,使其更容易理解。\n7. **对话交流:** 我可以和你进行自然' + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + candidatesTokenCount: 191, + totalTokenCount: 1210, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 1109 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [ + { + text: '流畅的对话,就像与人交流一样。\n8. **学习和研究助手:** 我可以帮助你查找信息、学习新知识、整理思路等。\n\n总的来说,我的目标是为你提供信息、帮助你完成' + } + ], + role: 'model' + }, + index: 0 + } + ], + usageMetadata: { + promptTokenCount: 6, + candidatesTokenCount: 191, + totalTokenCount: 1210, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 1109 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as GeminiSdkRawChunk, + { + candidates: [ + { + content: { + parts: [{ text: '任务,并以有益和富有成效的方式与你互动。\n\n你有什么具体想让我做的吗?' }], + role: 'model' + }, + index: 0, + finishReason: FinishReason.STOP + } + ], + usageMetadata: { + promptTokenCount: 6, + candidatesTokenCount: 266, + totalTokenCount: 1285, + promptTokensDetails: [{ modality: MediaModality.TEXT, tokenCount: 6 }], + thoughtsTokenCount: 1109 + }, + modelVersion: 'gemini-2.5-flash-lite-preview-06-17', + responseId: 'hNRzaKyMG4DVz7IP6NfaqAs' + } as unknown as GeminiSdkRawChunk +] + +// 正确的 async generator 函数 +async function* geminiChunkGenerator(): AsyncGenerator { + for (const chunk of geminiChunks) { + yield chunk + } +} + +async function* geminiThinkingChunkGenerator(): AsyncGenerator { + for (const chunk of geminiThinkingChunks) { + yield chunk + } +} + +// 创建 mock 的 GeminiAPIClient +const mockGeminiApiClient = { + createCompletions: vi.fn().mockImplementation(() => geminiChunkGenerator()), + getResponseChunkTransformer: vi.fn().mockImplementation(() => { + const toolCalls: FunctionCall[] = [] + let isFirstTextChunk = true + let isFirstThinkingChunk = true + return () => ({ + async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController) { + if (chunk.candidates && chunk.candidates.length > 0) { + for (const candidate of chunk.candidates) { + if (candidate.content) { + candidate.content.parts?.forEach((part) => { + const text = part.text || '' + if (part.thought) { + if (isFirstThinkingChunk) { + controller.enqueue({ + type: ChunkType.THINKING_START + } as ThinkingStartChunk) + isFirstThinkingChunk = false + } + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: text + }) + } else if (part.text) { + if (isFirstTextChunk) { + controller.enqueue({ + type: ChunkType.TEXT_START + } as TextStartChunk) + isFirstTextChunk = false + } + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: text + }) + } else if (part.inlineData) { + controller.enqueue({ + type: ChunkType.IMAGE_COMPLETE, + image: { + type: 'base64', + images: [ + part.inlineData?.data?.startsWith('data:') + ? part.inlineData?.data + : `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}` + ] + } + }) + } else if (part.functionCall) { + toolCalls.push(part.functionCall) + } + }) + } + + if (candidate.finishReason) { + if (candidate.groundingMetadata) { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: candidate.groundingMetadata, + source: WebSearchSource.GEMINI + } + } as LLMWebSearchCompleteChunk) + } + if (toolCalls.length > 0) { + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: [...toolCalls] + }) + toolCalls.length = 0 + } + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, + completion_tokens: + (chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0), + total_tokens: chunk.usageMetadata?.totalTokenCount || 0 + } + } + }) + } + } + } + + if (toolCalls.length > 0) { + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: toolCalls + }) + } + } + }) + }), + getSdkInstance: vi.fn(), + getRequestTransformer: vi.fn().mockImplementation(() => ({ + async transform(params: any) { + return { + payload: { + model: params.assistant?.model?.id || 'gemini-2.5-pro', + messages: params.messages || [], + tools: params.tools || [] + }, + metadata: {} + } + } + })), + convertMcpToolsToSdkTools: vi.fn(() => []), + convertSdkToolCallToMcpToolResponse: vi.fn(), + buildSdkMessages: vi.fn(() => []), + extractMessagesFromSdkPayload: vi.fn(() => []), + provider: {} as Provider, + useSystemPromptForTools: true, + getBaseURL: vi.fn(() => 'https://api.gemini.com'), + getApiKey: vi.fn(() => 'mock-api-key') +} as unknown as GeminiAPIClient + +const mockGeminiThinkingApiClient = cloneDeep(mockGeminiApiClient) +mockGeminiThinkingApiClient.createCompletions = vi.fn().mockImplementation(() => geminiThinkingChunkGenerator()) + +const mockProvider = { + id: 'gemini', + type: 'gemini', + name: 'Gemini', + apiKey: 'mock-api-key', + apiHost: 'mock-api-host' +} as Provider + +const collectedChunks: GenericChunk[] = [] +const mockOnChunk = vi.fn((chunk: GenericChunk) => { + collectedChunks.push(chunk) +}) + +describe('ApiService', () => { + beforeEach(() => { + vi.clearAllMocks() + collectedChunks.length = 0 + }) + + it('should return a stream of chunks with correct types and content', async () => { + const mockCreate = vi.mocked(ApiClientFactory.create) + mockCreate.mockReturnValue(mockGeminiApiClient as unknown as BaseApiClient) + const AI = new AiProvider(mockProvider) + + const result = await AI.completions({ + callType: 'test', + messages: [], + assistant: { + id: '1', + name: 'test', + prompt: 'test', + model: { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro' + } + } as Assistant, + onChunk: mockOnChunk, + mcpTools: [], + maxTokens: 1000, + streamOutput: true + }) + + expect(result).toBeDefined() + expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider) + expect(result.stream).toBeDefined() + + // 验证stream中的chunks + const stream = result.stream! as ReadableStream + const reader = stream.getReader() + const chunks: GenericChunk[] = [] + + try { + while (true) { + const { done, value } = await reader.read() + if (done) break + chunks.push(value) + } + } finally { + reader.releaseLock() + } + + const expectedChunks: GenericChunk[] = [ + { + type: ChunkType.TEXT_START + }, + { + type: ChunkType.TEXT_DELTA, + text: 'Hi, 1212312312' + }, + { + type: ChunkType.TEXT_DELTA, + text: 'Hi, 1212312312' + '!\n\n我是 Gemini 2.5 Pro,很高兴能为您服务。\n\n今天有什么可以帮您的吗?无论您是' + }, + { + type: ChunkType.TEXT_DELTA, + text: + 'Hi, 1212312312' + + '!\n\n我是 Gemini 2.5 Pro,很高兴能为您服务。\n\n今天有什么可以帮您的吗?无论您是' + + '想寻找信息、进行创作,还是有任何其他问题,我都在这里准备好提供帮助。' + }, + { + type: ChunkType.TEXT_COMPLETE, + text: + 'Hi, 1212312312' + + '!\n\n我是 Gemini 2.5 Pro,很高兴能为您服务。\n\n今天有什么可以帮您的吗?无论您是' + + '想寻找信息、进行创作,还是有任何其他问题,我都在这里准备好提供帮助。' + }, + { + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + total_tokens: 1205, + prompt_tokens: 383, + completion_tokens: 822 + } + } + } + ] + + expect(chunks).toEqual(expectedChunks) + + // 验证chunk的数量和类型 + expect(chunks.length).toBeGreaterThan(0) + + // 验证第一个chunk应该是TEXT_START + const firstChunk = chunks[0] + expect(firstChunk.type).toBe(ChunkType.TEXT_START) + + // 验证TEXT_DELTA chunks的内容 + const textDeltaChunks = chunks.filter((chunk) => chunk.type === ChunkType.TEXT_DELTA) as TextDeltaChunk[] + expect(textDeltaChunks.length).toBeGreaterThan(0) + + // 验证文本内容 + const expectedTexts = [ + 'Hi, 1212312312', + 'Hi, 1212312312' + '!\n\n我是 Gemini 2.5 Pro,很高兴能为您服务。\n\n今天有什么可以帮您的吗?无论您是', + 'Hi, 1212312312' + + '!\n\n我是 Gemini 2.5 Pro,很高兴能为您服务。\n\n今天有什么可以帮您的吗?无论您是' + + '想寻找信息、进行创作,还是有任何其他问题,我都在这里准备好提供帮助。' + ] + + textDeltaChunks.forEach((chunk, index) => { + expect(chunk.text).toBe(expectedTexts[index]) + }) + + // 验证最后一个chunk应该是LLM_RESPONSE_COMPLETE + const lastChunk = chunks[chunks.length - 1] + expect(lastChunk.type).toBe(ChunkType.LLM_RESPONSE_COMPLETE) + + // 验证LLM_RESPONSE_COMPLETE chunk包含usage信息 + const completionChunk = lastChunk as LLMResponseCompleteChunk + expect(completionChunk.response?.usage).toBeDefined() + expect(completionChunk.response?.usage?.total_tokens).toBe(1205) + expect(completionChunk.response?.usage?.prompt_tokens).toBe(383) + expect(completionChunk.response?.usage?.completion_tokens).toBe(822) + }) + + it('should return a stream of thinking chunks with correct types and content', async () => { + const mockCreate = vi.mocked(ApiClientFactory.create) + mockCreate.mockReturnValue(mockGeminiThinkingApiClient as unknown as BaseApiClient) + const AI = new AiProvider(mockProvider) + + const result = await AI.completions({ + callType: 'test', + messages: [], + assistant: { + id: '1', + name: 'test', + prompt: 'test', + model: { + id: 'gemini-2.5-pro', + name: 'Gemini 2.5 Pro' + } + } as Assistant, + onChunk: mockOnChunk, + enableReasoning: true, + streamOutput: true + }) + + expect(result).toBeDefined() + expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider) + expect(result.stream).toBeDefined() + + const stream = result.stream! as ReadableStream + const reader = stream.getReader() + + const chunks: GenericChunk[] = [] + + while (true) { + const { done, value } = await reader.read() + if (done) break + chunks.push(value) + } + + reader.releaseLock() + + // 过滤掉 thinking_millsec 字段,因为它不是确定值 + const filteredChunks = chunks.map((chunk) => { + if (chunk.type === ChunkType.THINKING_DELTA || chunk.type === ChunkType.THINKING_COMPLETE) { + delete (chunk as any).thinking_millsec + return chunk + } + return chunk + }) + + const expectedChunks: GenericChunk[] = [ + { + type: ChunkType.THINKING_START + }, + { + type: ChunkType.THINKING_DELTA, + text: `**Analyzing Core Functionality**\n\nI've identified the core query: "What can I do?" expressed in Chinese. Recognizing my nature as a Google-trained language model is the foundation for a relevant response. This fundamental understanding guides the development of an effective answer.\n\n\n` + }, + { + type: ChunkType.THINKING_DELTA, + text: + `**Analyzing Core Functionality**\n\nI've identified the core query: "What can I do?" expressed in Chinese. Recognizing my nature as a Google-trained language model is the foundation for a relevant response. This fundamental understanding guides the development of an effective answer.\n\n\n` + + `**Formulating the Chinese Response**\n\nI'm now drafting a Chinese response. I've moved past the initial simple sentence and am incorporating more detail. The goal is a clear, concise list of my key capabilities, tailored for a user asking about my function. I'm focusing on "understanding and generating text," "answering questions," and "translating languages" for now. Refining the exact phrasing for optimal clarity is an ongoing focus.\n\n\n` + }, + { + type: ChunkType.THINKING_DELTA, + text: + `**Analyzing Core Functionality**\n\nI've identified the core query: "What can I do?" expressed in Chinese. Recognizing my nature as a Google-trained language model is the foundation for a relevant response. This fundamental understanding guides the development of an effective answer.\n\n\n` + + `**Formulating the Chinese Response**\n\nI'm now drafting a Chinese response. I've moved past the initial simple sentence and am incorporating more detail. The goal is a clear, concise list of my key capabilities, tailored for a user asking about my function. I'm focusing on "understanding and generating text," "answering questions," and "translating languages" for now. Refining the exact phrasing for optimal clarity is an ongoing focus.\n\n\n` + + `**Categorizing My Abilities**\n\nI'm organizing my thoughts to classify the capabilities in my response. I'm grouping functions like text generation and question answering, then differentiating them from specialized features, such as translation and creative writing. Considering the best structure for clarity and comprehensiveness, I am refining the outline for my reply. I'm aiming for concise categories that clearly illustrate the range of my functionality in the response. I'm adding an optional explanation of my training to enrich the overall response.\n\n\n` + }, + { + type: ChunkType.THINKING_DELTA, + text: + `**Analyzing Core Functionality**\n\nI've identified the core query: "What can I do?" expressed in Chinese. Recognizing my nature as a Google-trained language model is the foundation for a relevant response. This fundamental understanding guides the development of an effective answer.\n\n\n` + + `**Formulating the Chinese Response**\n\nI'm now drafting a Chinese response. I've moved past the initial simple sentence and am incorporating more detail. The goal is a clear, concise list of my key capabilities, tailored for a user asking about my function. I'm focusing on "understanding and generating text," "answering questions," and "translating languages" for now. Refining the exact phrasing for optimal clarity is an ongoing focus.\n\n\n` + + `**Categorizing My Abilities**\n\nI'm organizing my thoughts to classify the capabilities in my response. I'm grouping functions like text generation and question answering, then differentiating them from specialized features, such as translation and creative writing. Considering the best structure for clarity and comprehensiveness, I am refining the outline for my reply. I'm aiming for concise categories that clearly illustrate the range of my functionality in the response. I'm adding an optional explanation of my training to enrich the overall response.\n\n\n` + + `**Developing the Chinese Draft**\n\nI'm now iterating on the final Chinese response. I've refined the categories to highlight my versatility. I'm focusing on "understanding and generating text" and "answering questions", and adding a section on how I can perform creative writing tasks. I'm aiming for concise explanations for clarity. I will also include a call to action at the end. I'm considering adding an optional sentence describing how I learned the data I know.\n\n\n` + }, + { + type: ChunkType.THINKING_COMPLETE, + text: + `**Analyzing Core Functionality**\n\nI've identified the core query: "What can I do?" expressed in Chinese. Recognizing my nature as a Google-trained language model is the foundation for a relevant response. This fundamental understanding guides the development of an effective answer.\n\n\n` + + `**Formulating the Chinese Response**\n\nI'm now drafting a Chinese response. I've moved past the initial simple sentence and am incorporating more detail. The goal is a clear, concise list of my key capabilities, tailored for a user asking about my function. I'm focusing on "understanding and generating text," "answering questions," and "translating languages" for now. Refining the exact phrasing for optimal clarity is an ongoing focus.\n\n\n` + + `**Categorizing My Abilities**\n\nI'm organizing my thoughts to classify the capabilities in my response. I'm grouping functions like text generation and question answering, then differentiating them from specialized features, such as translation and creative writing. Considering the best structure for clarity and comprehensiveness, I am refining the outline for my reply. I'm aiming for concise categories that clearly illustrate the range of my functionality in the response. I'm adding an optional explanation of my training to enrich the overall response.\n\n\n` + + `**Developing the Chinese Draft**\n\nI'm now iterating on the final Chinese response. I've refined the categories to highlight my versatility. I'm focusing on "understanding and generating text" and "answering questions", and adding a section on how I can perform creative writing tasks. I'm aiming for concise explanations for clarity. I will also include a call to action at the end. I'm considering adding an optional sentence describing how I learned the data I know.\n\n\n` + }, + { + type: ChunkType.TEXT_START + }, + { + type: ChunkType.TEXT_DELTA, + text: '我是一个大型语言模型,' + }, + { + type: ChunkType.TEXT_DELTA, + text: + '我是一个大型语言模型,' + + '由 Google 训练。\n\n我的能力主要包括:\n\n1. **理解和生成文本:** 我可以阅读、理解并创作各种形式的文本,包括文章、故事、对话、代码等。\n2. ' + }, + { + type: ChunkType.TEXT_DELTA, + text: + '我是一个大型语言模型,' + + '由 Google 训练。\n\n我的能力主要包括:\n\n1. **理解和生成文本:** 我可以阅读、理解并创作各种形式的文本,包括文章、故事、对话、代码等。\n2. ' + + '**回答问题:** 基于我所学习到的信息,我可以回答你提出的各种问题,无论是事实性的、概念性的还是需要解释的。\n3. **语言翻译:** 我可以翻译多种语言之间的文本。\n' + }, + { + type: ChunkType.TEXT_DELTA, + text: + '我是一个大型语言模型,' + + '由 Google 训练。\n\n我的能力主要包括:\n\n1. **理解和生成文本:** 我可以阅读、理解并创作各种形式的文本,包括文章、故事、对话、代码等。\n2. ' + + '**回答问题:** 基于我所学习到的信息,我可以回答你提出的各种问题,无论是事实性的、概念性的还是需要解释的。\n3. **语言翻译:** 我可以翻译多种语言之间的文本。\n' + + '4. **信息总结:** 我可以阅读长篇文本并提炼出关键信息或进行总结。\n5. **创意写作:** 我可以帮助你创作诗歌、代码、剧本、音乐作品' + }, + { + type: ChunkType.TEXT_DELTA, + text: + '我是一个大型语言模型,' + + '由 Google 训练。\n\n我的能力主要包括:\n\n1. **理解和生成文本:** 我可以阅读、理解并创作各种形式的文本,包括文章、故事、对话、代码等。\n2. ' + + '**回答问题:** 基于我所学习到的信息,我可以回答你提出的各种问题,无论是事实性的、概念性的还是需要解释的。\n3. **语言翻译:** 我可以翻译多种语言之间的文本。\n' + + '4. **信息总结:** 我可以阅读长篇文本并提炼出关键信息或进行总结。\n5. **创意写作:** 我可以帮助你创作诗歌、代码、剧本、音乐作品' + + '、电子邮件、信件等各种创意内容。\n6. **解释概念:** 我可以解释复杂的术语、概念或主题,使其更容易理解。\n7. **对话交流:** 我可以和你进行自然' + }, + { + type: ChunkType.TEXT_DELTA, + text: + '我是一个大型语言模型,' + + '由 Google 训练。\n\n我的能力主要包括:\n\n1. **理解和生成文本:** 我可以阅读、理解并创作各种形式的文本,包括文章、故事、对话、代码等。\n2. ' + + '**回答问题:** 基于我所学习到的信息,我可以回答你提出的各种问题,无论是事实性的、概念性的还是需要解释的。\n3. **语言翻译:** 我可以翻译多种语言之间的文本。\n' + + '4. **信息总结:** 我可以阅读长篇文本并提炼出关键信息或进行总结。\n5. **创意写作:** 我可以帮助你创作诗歌、代码、剧本、音乐作品' + + '、电子邮件、信件等各种创意内容。\n6. **解释概念:** 我可以解释复杂的术语、概念或主题,使其更容易理解。\n7. **对话交流:** 我可以和你进行自然' + + '流畅的对话,就像与人交流一样。\n8. **学习和研究助手:** 我可以帮助你查找信息、学习新知识、整理思路等。\n\n总的来说,我的目标是为你提供信息、帮助你完成' + }, + { + type: ChunkType.TEXT_DELTA, + text: + '我是一个大型语言模型,' + + '由 Google 训练。\n\n我的能力主要包括:\n\n1. **理解和生成文本:** 我可以阅读、理解并创作各种形式的文本,包括文章、故事、对话、代码等。\n2. ' + + '**回答问题:** 基于我所学习到的信息,我可以回答你提出的各种问题,无论是事实性的、概念性的还是需要解释的。\n3. **语言翻译:** 我可以翻译多种语言之间的文本。\n' + + '4. **信息总结:** 我可以阅读长篇文本并提炼出关键信息或进行总结。\n5. **创意写作:** 我可以帮助你创作诗歌、代码、剧本、音乐作品' + + '、电子邮件、信件等各种创意内容。\n6. **解释概念:** 我可以解释复杂的术语、概念或主题,使其更容易理解。\n7. **对话交流:** 我可以和你进行自然' + + '流畅的对话,就像与人交流一样。\n8. **学习和研究助手:** 我可以帮助你查找信息、学习新知识、整理思路等。\n\n总的来说,我的目标是为你提供信息、帮助你完成' + + '任务,并以有益和富有成效的方式与你互动。\n\n你有什么具体想让我做的吗?' + }, + { + type: ChunkType.TEXT_COMPLETE, + text: + '我是一个大型语言模型,' + + '由 Google 训练。\n\n我的能力主要包括:\n\n1. **理解和生成文本:** 我可以阅读、理解并创作各种形式的文本,包括文章、故事、对话、代码等。\n2. ' + + '**回答问题:** 基于我所学习到的信息,我可以回答你提出的各种问题,无论是事实性的、概念性的还是需要解释的。\n3. **语言翻译:** 我可以翻译多种语言之间的文本。\n' + + '4. **信息总结:** 我可以阅读长篇文本并提炼出关键信息或进行总结。\n5. **创意写作:** 我可以帮助你创作诗歌、代码、剧本、音乐作品' + + '、电子邮件、信件等各种创意内容。\n6. **解释概念:** 我可以解释复杂的术语、概念或主题,使其更容易理解。\n7. **对话交流:** 我可以和你进行自然' + + '流畅的对话,就像与人交流一样。\n8. **学习和研究助手:** 我可以帮助你查找信息、学习新知识、整理思路等。\n\n总的来说,我的目标是为你提供信息、帮助你完成' + + '任务,并以有益和富有成效的方式与你互动。\n\n你有什么具体想让我做的吗?' + }, + { + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: 6, + completion_tokens: 1279, + total_tokens: 1285 + } + } + } + ] + + expect(filteredChunks).toEqual(expectedChunks) + }) +}) diff --git a/src/renderer/src/store/messageBlock.ts b/src/renderer/src/store/messageBlock.ts index a2860f4685..f06470f7cd 100644 --- a/src/renderer/src/store/messageBlock.ts +++ b/src/renderer/src/store/messageBlock.ts @@ -20,7 +20,7 @@ const initialState = messageBlocksAdapter.getInitialState({ }) // 3. 创建 Slice -const messageBlocksSlice = createSlice({ +export const messageBlocksSlice = createSlice({ name: 'messageBlocks', initialState, reducers: { diff --git a/src/renderer/src/store/newMessage.ts b/src/renderer/src/store/newMessage.ts index 1a90a73ec6..66030125ef 100644 --- a/src/renderer/src/store/newMessage.ts +++ b/src/renderer/src/store/newMessage.ts @@ -67,7 +67,7 @@ interface InsertMessageAtIndexPayload { } // 4. Create the Slice with Refactored Reducers -const messagesSlice = createSlice({ +export const messagesSlice = createSlice({ name: 'newMessages', initialState, reducers: { diff --git a/src/renderer/src/store/thunk/__tests__/streamCallback.integration.test.ts b/src/renderer/src/store/thunk/__tests__/streamCallback.integration.test.ts new file mode 100644 index 0000000000..15a86ce64d --- /dev/null +++ b/src/renderer/src/store/thunk/__tests__/streamCallback.integration.test.ts @@ -0,0 +1,766 @@ +import { combineReducers, configureStore } from '@reduxjs/toolkit' +import { createStreamProcessor } from '@renderer/services/StreamProcessingService' +import type { AppDispatch } from '@renderer/store' +import { messageBlocksSlice } from '@renderer/store/messageBlock' +import { messagesSlice } from '@renderer/store/newMessage' +import { streamCallback } from '@renderer/store/thunk/messageThunk' +import type { Assistant, ExternalToolResult, MCPTool, Model } from '@renderer/types' +import { WebSearchSource } from '@renderer/types' +import type { Chunk } from '@renderer/types/chunk' +import { ChunkType } from '@renderer/types/chunk' +import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import type { RootState } from '../../index' + +// Mock external dependencies +vi.mock('@renderer/config/models', () => ({ + SYSTEM_MODELS: { + defaultModel: [{}, {}, {}], + silicon: [], + aihubmix: [], + ocoolai: [], + deepseek: [], + ppio: [], + alayanew: [], + qiniu: [], + dmxapi: [], + burncloud: [], + tokenflux: [], + '302ai': [], + cephalon: [], + lanyun: [], + ph8: [], + openrouter: [], + ollama: [], + 'new-api': [], + lmstudio: [], + anthropic: [], + openai: [], + 'azure-openai': [], + gemini: [], + vertexai: [], + github: [], + copilot: [], + zhipu: [], + yi: [], + moonshot: [], + baichuan: [], + dashscope: [], + stepfun: [], + doubao: [], + infini: [], + minimax: [], + groq: [], + together: [], + fireworks: [], + nvidia: [], + grok: [], + hyperbolic: [], + mistral: [], + jina: [], + perplexity: [], + modelscope: [], + xirang: [], + hunyuan: [], + 'tencent-cloud-ti': [], + 'baidu-cloud': [], + gpustack: [], + voyageai: [] + }, + getModelLogo: vi.fn(), + isVisionModel: vi.fn(() => false), + isFunctionCallingModel: vi.fn(() => false), + isEmbeddingModel: vi.fn(() => false), + isReasoningModel: vi.fn(() => false) + // ... 其他需要用到的函数也可以在这里 mock +})) + +vi.mock('@renderer/databases', () => ({ + default: { + message_blocks: { + bulkPut: vi.fn(), + update: vi.fn(), + bulkDelete: vi.fn(), + put: vi.fn(), + bulkAdd: vi.fn(), + where: vi.fn().mockReturnValue({ + equals: vi.fn().mockReturnValue({ + modify: vi.fn() + }), + anyOf: vi.fn().mockReturnValue({ + toArray: vi.fn().mockResolvedValue([]) + }) + }) + }, + topics: { + get: vi.fn(), + update: vi.fn(), + where: vi.fn().mockReturnValue({ + equals: vi.fn().mockReturnValue({ + modify: vi.fn() + }) + }) + }, + files: { + where: vi.fn().mockReturnValue({ + equals: vi.fn().mockReturnValue({ + modify: vi.fn() + }) + }) + }, + transaction: vi.fn((callback) => { + if (typeof callback === 'function') { + return callback() + } + return Promise.resolve() + }) + } +})) + +vi.mock('@renderer/services/FileManager', () => ({ + default: { + deleteFile: vi.fn() + } +})) + +vi.mock('@renderer/services/NotificationService', () => ({ + NotificationService: { + getInstance: vi.fn(() => ({ + send: vi.fn() + })) + } +})) + +vi.mock('@renderer/services/EventService', () => ({ + EventEmitter: { + emit: vi.fn() + }, + EVENT_NAMES: { + MESSAGE_COMPLETE: 'MESSAGE_COMPLETE' + } +})) + +vi.mock('@renderer/utils/window', () => ({ + isOnHomePage: vi.fn(() => true) +})) + +vi.mock('@renderer/hooks/useTopic', () => ({ + autoRenameTopic: vi.fn() +})) + +vi.mock('@renderer/store/assistants', () => { + const mockAssistantsSlice = { + name: 'assistants', + reducer: vi.fn((state = { entities: {}, ids: [] }) => state), + actions: { + updateTopicUpdatedAt: vi.fn(() => ({ type: 'UPDATE_TOPIC_UPDATED_AT' })) + } + } + + return { + default: mockAssistantsSlice.reducer, + updateTopicUpdatedAt: vi.fn(() => ({ type: 'UPDATE_TOPIC_UPDATED_AT' })), + assistantsSlice: mockAssistantsSlice + } +}) + +vi.mock('@renderer/services/TokenService', () => ({ + estimateMessagesUsage: vi.fn(() => + Promise.resolve({ + prompt_tokens: 100, + completion_tokens: 50, + total_tokens: 150 + }) + ) +})) + +vi.mock('@renderer/utils/queue', () => ({ + getTopicQueue: vi.fn(() => ({ + add: vi.fn((task) => task()) + })), + waitForTopicQueue: vi.fn() +})) + +vi.mock('@renderer/utils/messageUtils/find', () => ({ + default: {}, + findMainTextBlocks: vi.fn(() => []), + getMainTextContent: vi.fn(() => 'Test content') +})) + +vi.mock('i18next', () => { + const mockI18n = { + use: vi.fn().mockReturnThis(), + init: vi.fn().mockResolvedValue(undefined), + t: vi.fn((key) => key), + changeLanguage: vi.fn().mockResolvedValue(undefined), + language: 'en', + languages: ['en', 'zh'], + on: vi.fn(), + off: vi.fn(), + emit: vi.fn(), + store: {}, + services: {}, + options: {} + } + + return { + default: mockI18n, + ...mockI18n + } +}) + +vi.mock('@renderer/utils/error', () => ({ + formatErrorMessage: vi.fn((error) => error.message || 'Unknown error'), + isAbortError: vi.fn((error) => error.name === 'AbortError') +})) + +vi.mock('@renderer/utils', () => ({ + default: {}, + uuid: vi.fn(() => 'mock-uuid-' + Math.random().toString(36).substr(2, 9)) +})) + +interface MockTopicsState { + entities: Record +} + +const reducer = combineReducers({ + messages: messagesSlice.reducer, + messageBlocks: messageBlocksSlice.reducer, + topics: (state: MockTopicsState = { entities: {} }) => state +}) + +const createMockStore = () => { + const store = configureStore({ + reducer: reducer, + middleware: (getDefaultMiddleware) => getDefaultMiddleware({ serializableCheck: false }) + }) + return store +} + +// Helper function to simulate processing chunks through the stream processor +const processChunks = async (chunks: Chunk[], callbacks: ReturnType) => { + const streamProcessor = createStreamProcessor(callbacks) + + const stream = new ReadableStream({ + start(controller) { + for (const chunk of chunks) { + controller.enqueue(chunk) + } + controller.close() + } + }) + + const reader = stream.getReader() + + try { + while (true) { + const { done, value: chunk } = await reader.read() + if (done) { + break + } + + if (chunk) { + streamProcessor(chunk) + + // Add small delay to simulate real streaming + await new Promise((resolve) => setTimeout(resolve, 10)) + } + } + } catch (error) { + console.error('Error processing chunks:', error) + throw error + } finally { + reader.releaseLock() + } +} + +describe('streamCallback Integration Tests', () => { + let store: ReturnType + let dispatch: AppDispatch + let getState: () => ReturnType & RootState + + const mockTopicId = 'test-topic-id' + const mockAssistantMsgId = 'test-assistant-msg-id' + const mockAssistant: Assistant = { + id: 'test-assistant', + name: 'Test Assistant', + model: { + id: 'test-model', + name: 'Test Model' + } as Model, + prompt: '', + enableWebSearch: false, + enableGenerateImage: false, + knowledge_bases: [], + topics: [], + type: 'test' + } + + beforeEach(() => { + vi.clearAllMocks() + store = createMockStore() + dispatch = store.dispatch + getState = store.getState as () => ReturnType & RootState + + // 为测试消息添加初始状态 + store.dispatch( + messagesSlice.actions.addMessage({ + topicId: mockTopicId, + message: { + id: mockAssistantMsgId, + assistantId: mockAssistant.id, + role: 'assistant', + topicId: mockTopicId, + blocks: [], + status: AssistantMessageStatus.PENDING, + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString() + } + }) + ) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + it('should handle complete text streaming flow', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + { type: ChunkType.TEXT_START }, + { type: ChunkType.TEXT_DELTA, text: 'Hello ' }, + { type: ChunkType.TEXT_DELTA, text: 'Hello world!' }, + { type: ChunkType.TEXT_COMPLETE, text: 'Hello world!' }, + { + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { prompt_tokens: 100, completion_tokens: 50, total_tokens: 150 }, + metrics: { completion_tokens: 50, time_completion_millsec: 1000 } + } + }, + { + type: ChunkType.BLOCK_COMPLETE, + response: { + usage: { prompt_tokens: 100, completion_tokens: 50, total_tokens: 150 }, + metrics: { completion_tokens: 50, time_completion_millsec: 1000 } + } + } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + expect(blocks.length).toBeGreaterThan(0) + + const textBlock = blocks.find((block) => block.type === MessageBlockType.MAIN_TEXT) + expect(textBlock).toBeDefined() + expect(textBlock?.content).toBe('Hello world!') + expect(textBlock?.status).toBe(MessageBlockStatus.SUCCESS) + + // 验证消息状态更新 + const message = state.messages.entities[mockAssistantMsgId] + expect(message?.status).toBe(AssistantMessageStatus.SUCCESS) + expect(message?.usage?.total_tokens).toBe(150) + }) + + it('should handle thinking flow', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + { type: ChunkType.THINKING_START }, + { type: ChunkType.THINKING_DELTA, text: 'Let me think...', thinking_millsec: 1000 }, + { type: ChunkType.THINKING_DELTA, text: 'I need to consider...', thinking_millsec: 2000 }, + { type: ChunkType.THINKING_COMPLETE, text: 'Final thoughts', thinking_millsec: 3000 }, + { type: ChunkType.BLOCK_COMPLETE } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + + const thinkingBlock = blocks.find((block) => block.type === MessageBlockType.THINKING) + expect(thinkingBlock).toBeDefined() + expect(thinkingBlock?.content).toBe('Final thoughts') + expect(thinkingBlock?.status).toBe(MessageBlockStatus.SUCCESS) + expect((thinkingBlock as any)?.thinking_millsec).toBe(3000) + }) + + it('should handle tool call flow', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + const mockTool: MCPTool = { + id: 'tool-1', + serverId: 'server-1', + serverName: 'Test Server', + name: 'test-tool', + description: 'Test tool', + inputSchema: { + type: 'object', + title: 'Test Tool Input', + properties: {} + } + } + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + { + type: ChunkType.MCP_TOOL_PENDING, + responses: [ + { + id: 'tool-call-1', + tool: mockTool, + arguments: { testArg: 'value' }, + status: 'pending' as const, + response: '' + } + ] + }, + { + type: ChunkType.MCP_TOOL_IN_PROGRESS, + responses: [ + { + id: 'tool-call-1', + tool: mockTool, + arguments: { testArg: 'value' }, + status: 'invoking' as const, + response: '' + } + ] + }, + { + type: ChunkType.MCP_TOOL_COMPLETE, + responses: [ + { + id: 'tool-call-1', + tool: mockTool, + arguments: { testArg: 'value' }, + status: 'done' as const, + response: 'Tool result' + } + ] + }, + { type: ChunkType.BLOCK_COMPLETE } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + + const toolBlock = blocks.find((block) => block.type === MessageBlockType.TOOL) + expect(toolBlock).toBeDefined() + expect(toolBlock?.content).toBe('Tool result') + expect(toolBlock?.status).toBe(MessageBlockStatus.SUCCESS) + expect((toolBlock as any)?.toolName).toBe('test-tool') + }) + + it('should handle image generation flow', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + { type: ChunkType.IMAGE_CREATED }, + { + type: ChunkType.IMAGE_DELTA, + image: { + type: 'base64', + images: [ + '' + ] + } + }, + { + type: ChunkType.IMAGE_COMPLETE, + image: { + type: 'base64', + images: [ + '' + ] + } + }, + { type: ChunkType.BLOCK_COMPLETE } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + const imageBlock = blocks.find((block) => block.type === MessageBlockType.IMAGE) + expect(imageBlock).toBeDefined() + expect(imageBlock?.url).toBe( + '' + ) + expect(imageBlock?.status).toBe(MessageBlockStatus.SUCCESS) + }) + + it('should handle web search flow', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + const mockWebSearchResult = { + source: WebSearchSource.WEBSEARCH, + results: [{ title: 'Test Result', url: 'http://example.com', snippet: 'Test snippet' }] + } + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + { type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS }, + { type: ChunkType.LLM_WEB_SEARCH_COMPLETE, llm_web_search: mockWebSearchResult }, + { type: ChunkType.BLOCK_COMPLETE } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + + const citationBlock = blocks.find((block) => block.type === MessageBlockType.CITATION) + expect(citationBlock).toBeDefined() + expect(citationBlock?.response?.source).toEqual(mockWebSearchResult.source) + expect(citationBlock?.status).toBe(MessageBlockStatus.SUCCESS) + }) + + it('should handle mixed content flow (thinking + tool + text)', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + const mockCalculatorTool: MCPTool = { + id: 'tool-1', + serverId: 'server-1', + serverName: 'Test Server', + name: 'calculator', + description: 'Calculator tool', + inputSchema: { + type: 'object', + title: 'Calculator Input', + properties: {} + } + } + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + + // 思考阶段 + { type: ChunkType.THINKING_START }, + { type: ChunkType.THINKING_DELTA, text: 'Let me calculate this...', thinking_millsec: 1000 }, + { + type: ChunkType.THINKING_DELTA, + text: 'Let me calculate this..., I need to use a calculator', + thinking_millsec: 1000 + }, + { + type: ChunkType.THINKING_COMPLETE, + text: 'Let me calculate this..., I need to use a calculator', + thinking_millsec: 2000 + }, + + // 工具调用阶段 + { + type: ChunkType.MCP_TOOL_PENDING, + responses: [ + { + id: 'tool-call-1', + tool: mockCalculatorTool, + arguments: { operation: 'add', a: 1, b: 2 }, + status: 'pending' as const, + response: '' + } + ] + }, + { + type: ChunkType.MCP_TOOL_IN_PROGRESS, + responses: [ + { + id: 'tool-call-1', + tool: mockCalculatorTool, + arguments: { operation: 'add', a: 1, b: 2 }, + status: 'invoking' as const, + response: '' + } + ] + }, + { + type: ChunkType.MCP_TOOL_COMPLETE, + responses: [ + { + id: 'tool-call-1', + tool: mockCalculatorTool, + arguments: { operation: 'add', a: 1, b: 2 }, + status: 'done' as const, + response: '42' + } + ] + }, + + // 文本响应阶段 + { type: ChunkType.TEXT_START }, + { type: ChunkType.TEXT_DELTA, text: 'The answer is ' }, + { type: ChunkType.TEXT_DELTA, text: '42' }, + { type: ChunkType.TEXT_COMPLETE, text: 'The answer is 42' }, + + { type: ChunkType.BLOCK_COMPLETE } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + + expect(blocks.length).toBeGreaterThan(2) // 至少有思考块、工具块、文本块 + + const thinkingBlock = blocks.find((block) => block.type === MessageBlockType.THINKING) + expect(thinkingBlock?.content).toBe('Let me calculate this..., I need to use a calculator') + expect(thinkingBlock?.status).toBe(MessageBlockStatus.SUCCESS) + + const toolBlock = blocks.find((block) => block.type === MessageBlockType.TOOL) + expect(toolBlock?.content).toBe('42') + expect(toolBlock?.status).toBe(MessageBlockStatus.SUCCESS) + + const textBlock = blocks.find((block) => block.type === MessageBlockType.MAIN_TEXT) + expect(textBlock?.content).toBe('The answer is 42') + expect(textBlock?.status).toBe(MessageBlockStatus.SUCCESS) + }) + + it('should handle error flow', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + const mockError = new Error('Test error') + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + { type: ChunkType.TEXT_START }, + { type: ChunkType.TEXT_DELTA, text: 'Hello ' }, + { type: ChunkType.ERROR, error: mockError } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + + expect(blocks.length).toBeGreaterThan(0) + + const errorBlock = blocks.find((block) => block.type === MessageBlockType.ERROR) + expect(errorBlock).toBeDefined() + expect(errorBlock?.status).toBe(MessageBlockStatus.SUCCESS) + expect((errorBlock as any)?.error?.message).toBe('Test error') + + // 验证消息状态更新 + const message = state.messages.entities[mockAssistantMsgId] + expect(message?.status).toBe(AssistantMessageStatus.ERROR) + }) + + it('should handle external tool flow', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + const mockExternalToolResult: ExternalToolResult = { + webSearch: { + source: WebSearchSource.WEBSEARCH, + results: [{ title: 'External Result', url: 'http://external.com', snippet: 'External snippet' }] + }, + knowledge: [ + { + id: 1, + content: 'Knowledge content', + sourceUrl: 'http://external.com', + type: 'url' + } + ] + } + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + { type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS }, + { type: ChunkType.EXTERNEL_TOOL_COMPLETE, external_tool: mockExternalToolResult }, + { type: ChunkType.BLOCK_COMPLETE } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + + const citationBlock = blocks.find((block) => block.type === MessageBlockType.CITATION) + expect(citationBlock).toBeDefined() + expect((citationBlock as any)?.response).toEqual(mockExternalToolResult.webSearch) + expect((citationBlock as any)?.knowledge).toEqual(mockExternalToolResult.knowledge) + expect(citationBlock?.status).toBe(MessageBlockStatus.SUCCESS) + }) + + it('should handle abort error correctly', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + // 创建一个模拟的 abort 错误 + const abortError = new Error('Request aborted') + abortError.name = 'AbortError' + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + { type: ChunkType.TEXT_START }, + { type: ChunkType.TEXT_DELTA, text: 'Partial text...' }, + { type: ChunkType.ERROR, error: abortError } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + + expect(blocks.length).toBeGreaterThan(0) + + const errorBlock = blocks.find((block) => block.type === MessageBlockType.ERROR) + expect(errorBlock).toBeDefined() + expect(errorBlock?.status).toBe(MessageBlockStatus.SUCCESS) + + // 验证消息状态更新为成功(因为是暂停,不是真正的错误) + const message = state.messages.entities[mockAssistantMsgId] + expect(message?.status).toBe(AssistantMessageStatus.SUCCESS) + }) + + it('should maintain block reference integrity during streaming', async () => { + const callbacks = streamCallback(dispatch, getState, mockTopicId, mockAssistant, mockAssistantMsgId) + + const chunks: Chunk[] = [ + { type: ChunkType.LLM_RESPONSE_CREATED }, + { type: ChunkType.TEXT_START }, + { type: ChunkType.TEXT_DELTA, text: 'First chunk' }, + { type: ChunkType.TEXT_DELTA, text: 'Second chunk' }, + { type: ChunkType.TEXT_COMPLETE, text: 'First chunkSecond chunk' }, + { type: ChunkType.BLOCK_COMPLETE } + ] + + await processChunks(chunks, callbacks) + + // 验证 Redux 状态 + const state = getState() + const blocks = Object.values(state.messageBlocks.entities) + const message = state.messages.entities[mockAssistantMsgId] + + // 验证消息的 blocks 数组包含正确的块ID + expect(message?.blocks).toBeDefined() + expect(message?.blocks?.length).toBeGreaterThan(0) + + // 验证所有块都存在于 messageBlocks 状态中 + message?.blocks?.forEach((blockId) => { + const block = state.messageBlocks.entities[blockId] + expect(block).toBeDefined() + expect(block?.messageId).toBe(mockAssistantMsgId) + }) + + // 验证blocks包含正确的内容 + expect(blocks.length).toBeGreaterThan(0) + }) +}) diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts index ad2e63e55c..5aba6499d9 100644 --- a/src/renderer/src/store/thunk/messageThunk.ts +++ b/src/renderer/src/store/thunk/messageThunk.ts @@ -311,6 +311,540 @@ const dispatchMultiModelResponses = async ( // --- End Helper Function --- +export const streamCallback = ( + dispatch: AppDispatch, + getState: () => RootState, + topicId: string, + assistant: Assistant, + assistantMsgId: string +) => { + let lastBlockId: string | null = null + let lastBlockType: MessageBlockType | null = null + // 专注于块内部的生命周期处理 + let initialPlaceholderBlockId: string | null = null + let citationBlockId: string | null = null + let mainTextBlockId: string | null = null + let thinkingBlockId: string | null = null + let imageBlockId: string | null = null + let toolBlockId: string | null = null + + const toolCallIdToBlockIdMap = new Map() + const notificationService = NotificationService.getInstance() + + /** + * 智能更新策略:根据块类型连续性自动判断使用节流还是立即更新 + * - 连续同类块:使用节流(减少重渲染) + * - 块类型切换:立即更新(确保状态正确) + * @param blockId 块ID + * @param changes 块更新内容 + * @param blockType 块类型 + * @param isComplete 是否完成,如果完成,则需要保存块更新到redux中 + */ + const smartBlockUpdate = ( + blockId: string, + changes: Partial, + blockType: MessageBlockType, + isComplete: boolean = false + ) => { + const isBlockTypeChanged = lastBlockType !== null && lastBlockType !== blockType + if (isBlockTypeChanged || isComplete) { + // 如果块类型改变,则取消上一个块的节流更新,并保存块更新到redux中(尽管有可能被上一个块本身的oncomplete事件的取消节流已经取消了) + if (isBlockTypeChanged && lastBlockId) { + cancelThrottledBlockUpdate(lastBlockId) + } + // 如果当前块完成,则取消当前块的节流更新,并保存块更新到redux中,避免streaming状态覆盖掉完成状态 + if (isComplete) { + cancelThrottledBlockUpdate(blockId) + } + dispatch(updateOneBlock({ id: blockId, changes })) + saveUpdatedBlockToDB(blockId, assistantMsgId, topicId, getState) + lastBlockType = blockType + } else { + throttledBlockUpdate(blockId, changes) + } + } + + const handleBlockTransition = async (newBlock: MessageBlock, newBlockType: MessageBlockType) => { + lastBlockId = newBlock.id + lastBlockType = newBlockType + dispatch( + newMessagesActions.updateMessage({ + topicId, + messageId: assistantMsgId, + updates: { blockInstruction: { id: newBlock.id } } + }) + ) + dispatch(upsertOneBlock(newBlock)) + dispatch( + newMessagesActions.upsertBlockReference({ + messageId: assistantMsgId, + blockId: newBlock.id, + status: newBlock.status + }) + ) + + const currentState = getState() + const updatedMessage = currentState.messages.entities[assistantMsgId] + if (updatedMessage) { + await saveUpdatesToDB(assistantMsgId, topicId, { blocks: updatedMessage.blocks }, [newBlock]) + } else { + console.error(`[handleBlockTransition] Failed to get updated message ${assistantMsgId} from state for DB save.`) + } + } + + let startTime = 0 + + return { + onLLMResponseCreated: async () => { + startTime = Date.now() + const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, { + status: MessageBlockStatus.PROCESSING + }) + initialPlaceholderBlockId = baseBlock.id + await handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN) + }, + onTextStart: async () => { + if (initialPlaceholderBlockId) { + const changes = { + type: MessageBlockType.MAIN_TEXT, + content: '', + status: MessageBlockStatus.STREAMING + } + smartBlockUpdate(initialPlaceholderBlockId, changes, MessageBlockType.MAIN_TEXT, true) + mainTextBlockId = initialPlaceholderBlockId + initialPlaceholderBlockId = null + } else if (!mainTextBlockId) { + const newBlock = createMainTextBlock(assistantMsgId, '', { + status: MessageBlockStatus.STREAMING + }) + mainTextBlockId = newBlock.id + await handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT) + } + }, + onTextChunk: async (text) => { + const citationBlockSource = citationBlockId + ? (getState().messageBlocks.entities[citationBlockId] as CitationMessageBlock).response?.source + : WebSearchSource.WEBSEARCH + if (mainTextBlockId) { + const blockChanges: Partial = { + content: text, + status: MessageBlockStatus.STREAMING, + citationReferences: citationBlockId ? [{ citationBlockId, citationBlockSource }] : [] + } + smartBlockUpdate(mainTextBlockId, blockChanges, MessageBlockType.MAIN_TEXT) + } + }, + onTextComplete: async (finalText) => { + if (mainTextBlockId) { + const changes = { + content: finalText, + status: MessageBlockStatus.SUCCESS + } + smartBlockUpdate(mainTextBlockId, changes, MessageBlockType.MAIN_TEXT, true) + mainTextBlockId = null + } else { + console.warn( + `[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${lastBlockType}) or lastBlockId is null.` + ) + } + }, + onThinkingStart: async () => { + if (initialPlaceholderBlockId) { + const changes = { + type: MessageBlockType.THINKING, + content: '', + status: MessageBlockStatus.STREAMING, + thinking_millsec: 0 + } + thinkingBlockId = initialPlaceholderBlockId + initialPlaceholderBlockId = null + smartBlockUpdate(thinkingBlockId, changes, MessageBlockType.THINKING, true) + } else if (!thinkingBlockId) { + const newBlock = createThinkingBlock(assistantMsgId, '', { + status: MessageBlockStatus.STREAMING, + thinking_millsec: 0 + }) + thinkingBlockId = newBlock.id + await handleBlockTransition(newBlock, MessageBlockType.THINKING) + } + }, + onThinkingChunk: async (text, thinking_millsec) => { + if (thinkingBlockId) { + const blockChanges: Partial = { + content: text, + status: MessageBlockStatus.STREAMING, + thinking_millsec: thinking_millsec + } + smartBlockUpdate(thinkingBlockId, blockChanges, MessageBlockType.THINKING) + } + }, + onThinkingComplete: (finalText, final_thinking_millsec) => { + if (thinkingBlockId) { + const changes = { + type: MessageBlockType.THINKING, + content: finalText, + status: MessageBlockStatus.SUCCESS, + thinking_millsec: final_thinking_millsec + } + smartBlockUpdate(thinkingBlockId, changes, MessageBlockType.THINKING, true) + } else { + console.warn( + `[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${lastBlockType}) or lastBlockId is null.` + ) + } + thinkingBlockId = null + }, + onToolCallPending: (toolResponse: MCPToolResponse) => { + if (initialPlaceholderBlockId) { + const changes = { + type: MessageBlockType.TOOL, + status: MessageBlockStatus.PENDING, + toolName: toolResponse.tool.name, + metadata: { rawMcpToolResponse: toolResponse } + } + toolBlockId = initialPlaceholderBlockId + initialPlaceholderBlockId = null + smartBlockUpdate(toolBlockId, changes, MessageBlockType.TOOL) + toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId) + } else if (toolResponse.status === 'pending') { + const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, { + toolName: toolResponse.tool.name, + status: MessageBlockStatus.PENDING, + metadata: { rawMcpToolResponse: toolResponse } + }) + toolBlockId = toolBlock.id + handleBlockTransition(toolBlock, MessageBlockType.TOOL) + toolCallIdToBlockIdMap.set(toolResponse.id, toolBlock.id) + } else { + console.warn( + `[onToolCallPending] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}` + ) + } + }, + onToolCallInProgress: (toolResponse: MCPToolResponse) => { + // 根据 toolResponse.id 查找对应的块ID + const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id) + + if (targetBlockId && toolResponse.status === 'invoking') { + const changes = { + status: MessageBlockStatus.PROCESSING, + metadata: { rawMcpToolResponse: toolResponse } + } + smartBlockUpdate(targetBlockId, changes, MessageBlockType.TOOL) + } else if (!targetBlockId) { + console.warn( + `[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`, + Array.from(toolCallIdToBlockIdMap.entries()) + ) + } else { + console.warn( + `[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}` + ) + } + }, + onToolCallComplete: (toolResponse: MCPToolResponse) => { + const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id) + toolCallIdToBlockIdMap.delete(toolResponse.id) + if (toolResponse.status === 'done' || toolResponse.status === 'error' || toolResponse.status === 'cancelled') { + if (!existingBlockId) { + console.error( + `[onToolCallComplete] No existing block found for completed/error tool call ID: ${toolResponse.id}. Cannot update.` + ) + return + } + const finalStatus = + toolResponse.status === 'done' || toolResponse.status === 'cancelled' + ? MessageBlockStatus.SUCCESS + : MessageBlockStatus.ERROR + const changes: Partial = { + content: toolResponse.response, + status: finalStatus, + metadata: { rawMcpToolResponse: toolResponse } + } + if (finalStatus === MessageBlockStatus.ERROR) { + changes.error = { message: `Tool execution failed/error`, details: toolResponse.response } + } + smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true) + } else { + console.warn( + `[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}` + ) + } + toolBlockId = null + }, + onExternalToolInProgress: async () => { + const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING }) + citationBlockId = citationBlock.id + await handleBlockTransition(citationBlock, MessageBlockType.CITATION) + // saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState) + }, + onExternalToolComplete: (externalToolResult: ExternalToolResult) => { + if (citationBlockId) { + const changes: Partial = { + response: externalToolResult.webSearch, + knowledge: externalToolResult.knowledge, + status: MessageBlockStatus.SUCCESS + } + smartBlockUpdate(citationBlockId, changes, MessageBlockType.CITATION, true) + } else { + console.error('[onExternalToolComplete] citationBlockId is null. Cannot update.') + } + }, + onLLMWebSearchInProgress: async () => { + if (initialPlaceholderBlockId) { + citationBlockId = initialPlaceholderBlockId + const changes = { + type: MessageBlockType.CITATION, + status: MessageBlockStatus.PROCESSING + } + smartBlockUpdate(initialPlaceholderBlockId, changes, MessageBlockType.CITATION) + initialPlaceholderBlockId = null + } else { + const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING }) + citationBlockId = citationBlock.id + await handleBlockTransition(citationBlock, MessageBlockType.CITATION) + } + }, + onLLMWebSearchComplete: async (llmWebSearchResult) => { + const blockId = citationBlockId || initialPlaceholderBlockId + if (blockId) { + const changes: Partial = { + type: MessageBlockType.CITATION, + response: llmWebSearchResult, + status: MessageBlockStatus.SUCCESS + } + smartBlockUpdate(blockId, changes, MessageBlockType.CITATION, true) + + const state = getState() + const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId]) + if (existingMainTextBlocks.length > 0) { + const existingMainTextBlock = existingMainTextBlocks[0] + const currentRefs = existingMainTextBlock.citationReferences || [] + const mainTextChanges = { + citationReferences: [...currentRefs, { blockId, citationBlockSource: llmWebSearchResult.source }] + } + smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true) + } + + if (initialPlaceholderBlockId) { + citationBlockId = initialPlaceholderBlockId + initialPlaceholderBlockId = null + } + } else { + const citationBlock = createCitationBlock( + assistantMsgId, + { + response: llmWebSearchResult + }, + { + status: MessageBlockStatus.SUCCESS + } + ) + citationBlockId = citationBlock.id + const state = getState() + const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId]) + if (existingMainTextBlocks.length > 0) { + const existingMainTextBlock = existingMainTextBlocks[0] + const currentRefs = existingMainTextBlock.citationReferences || [] + const mainTextChanges = { + citationReferences: [...currentRefs, { citationBlockId, citationBlockSource: llmWebSearchResult.source }] + } + smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true) + } + await handleBlockTransition(citationBlock, MessageBlockType.CITATION) + } + }, + onImageCreated: async () => { + if (initialPlaceholderBlockId) { + const initialChanges: Partial = { + type: MessageBlockType.IMAGE, + status: MessageBlockStatus.PENDING + } + imageBlockId = initialPlaceholderBlockId + initialPlaceholderBlockId = null + smartBlockUpdate(imageBlockId, initialChanges, MessageBlockType.IMAGE) + } else if (!imageBlockId) { + const imageBlock = createImageBlock(assistantMsgId, { + status: MessageBlockStatus.PENDING + }) + imageBlockId = imageBlock.id + await handleBlockTransition(imageBlock, MessageBlockType.IMAGE) + } + }, + onImageDelta: (imageData) => { + const imageUrl = imageData.images?.[0] || 'placeholder_image_url' + if (imageBlockId) { + const changes: Partial = { + url: imageUrl, + metadata: { generateImageResponse: imageData }, + status: MessageBlockStatus.STREAMING + } + smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE, true) + } + }, + onImageGenerated: (imageData) => { + if (imageBlockId) { + if (!imageData) { + const changes: Partial = { + status: MessageBlockStatus.SUCCESS + } + smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE) + } else { + const imageUrl = imageData.images?.[0] || 'placeholder_image_url' + const changes: Partial = { + url: imageUrl, + metadata: { generateImageResponse: imageData }, + status: MessageBlockStatus.SUCCESS + } + smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE, true) + } + } else { + console.error('[onImageGenerated] Last block was not an Image block or ID is missing.') + } + imageBlockId = null + }, + onError: async (error) => { + console.dir(error, { depth: null }) + const isErrorTypeAbort = isAbortError(error) + let pauseErrorLanguagePlaceholder = '' + if (isErrorTypeAbort) { + pauseErrorLanguagePlaceholder = 'pause_placeholder' + } + + const serializableError = { + name: error.name, + message: pauseErrorLanguagePlaceholder || error.message || formatErrorMessage(error), + originalMessage: error.message, + stack: error.stack, + status: error.status || error.code, + requestId: error.request_id + } + if (!isOnHomePage()) { + await notificationService.send({ + id: uuid(), + type: 'error', + title: t('notification.assistant'), + message: serializableError.message, + silent: false, + timestamp: Date.now(), + source: 'assistant' + }) + } + const possibleBlockId = + mainTextBlockId || + thinkingBlockId || + toolBlockId || + imageBlockId || + citationBlockId || + initialPlaceholderBlockId || + lastBlockId + + if (possibleBlockId) { + // 更改上一个block的状态为ERROR + const changes: Partial = { + status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR + } + smartBlockUpdate(possibleBlockId, changes, lastBlockType!, true) + } + + const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS }) + await handleBlockTransition(errorBlock, MessageBlockType.ERROR) + const messageErrorUpdate = { + status: isErrorTypeAbort ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR + } + dispatch(newMessagesActions.updateMessage({ topicId, messageId: assistantMsgId, updates: messageErrorUpdate })) + + saveUpdatesToDB(assistantMsgId, topicId, messageErrorUpdate, []) + + EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, { + id: assistantMsgId, + topicId, + status: isErrorTypeAbort ? 'pause' : 'error', + error: error.message + }) + }, + onComplete: async (status: AssistantMessageStatus, response?: Response) => { + const finalStateOnComplete = getState() + const finalAssistantMsg = finalStateOnComplete.messages.entities[assistantMsgId] + + if (status === 'success' && finalAssistantMsg) { + const userMsgId = finalAssistantMsg.askId + const orderedMsgs = selectMessagesForTopic(finalStateOnComplete, topicId) + const userMsgIndex = orderedMsgs.findIndex((m) => m.id === userMsgId) + const contextForUsage = userMsgIndex !== -1 ? orderedMsgs.slice(0, userMsgIndex + 1) : [] + const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg] + + const possibleBlockId = + mainTextBlockId || + thinkingBlockId || + toolBlockId || + imageBlockId || + citationBlockId || + initialPlaceholderBlockId || + lastBlockId + if (possibleBlockId) { + const changes: Partial = { + status: MessageBlockStatus.SUCCESS + } + smartBlockUpdate(possibleBlockId, changes, lastBlockType!, true) + } + + const endTime = Date.now() + const duration = endTime - startTime + const content = getMainTextContent(finalAssistantMsg) + if (!isOnHomePage() && duration > 60 * 1000) { + await notificationService.send({ + id: uuid(), + type: 'success', + title: t('notification.assistant'), + message: content.length > 50 ? content.slice(0, 47) + '...' : content, + silent: false, + timestamp: Date.now(), + source: 'assistant' + }) + } + + // 更新topic的name + autoRenameTopic(assistant, topicId) + + if ( + response && + (response.usage?.total_tokens === 0 || + response?.usage?.prompt_tokens === 0 || + response?.usage?.completion_tokens === 0) + ) { + const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant }) + response.usage = usage + } + // dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) + } + if (response && response.metrics) { + if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) { + response = { + ...response, + metrics: { + ...response.metrics, + completion_tokens: response.usage.completion_tokens + } + } + } + } + + const messageUpdates: Partial = { status, metrics: response?.metrics, usage: response?.usage } + dispatch( + newMessagesActions.updateMessage({ + topicId, + messageId: assistantMsgId, + updates: messageUpdates + }) + ) + saveUpdatesToDB(assistantMsgId, topicId, messageUpdates, []) + + EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, { id: assistantMsgId, topicId, status }) + } + } +} + // Internal function extracted from sendMessage to handle fetching and processing assistant response const fetchAndProcessAssistantResponseImpl = async ( dispatch: AppDispatch, @@ -320,95 +854,10 @@ const fetchAndProcessAssistantResponseImpl = async ( assistantMessage: Message // Pass the prepared assistant message (new or reset) ) => { const assistantMsgId = assistantMessage.id - let callbacks: StreamProcessorCallbacks = {} + const callbacks: StreamProcessorCallbacks = streamCallback(dispatch, getState, topicId, assistant, assistantMsgId) try { dispatch(newMessagesActions.setTopicLoading({ topicId, loading: true })) - let accumulatedContent = '' - let accumulatedThinking = '' - let lastBlockId: string | null = null - let lastBlockType: MessageBlockType | null = null - let currentActiveBlockType: MessageBlockType | null = null - // 专注于块内部的生命周期处理 - let initialPlaceholderBlockId: string | null = null - let citationBlockId: string | null = null - let mainTextBlockId: string | null = null - let thinkingBlockId: string | null = null - let imageBlockId: string | null = null - let toolBlockId: string | null = null - - const toolCallIdToBlockIdMap = new Map() - const notificationService = NotificationService.getInstance() - - /** - * 智能更新策略:根据块类型连续性自动判断使用节流还是立即更新 - * - 连续同类块:使用节流(减少重渲染) - * - 块类型切换:立即更新(确保状态正确) - * @param blockId 块ID - * @param changes 块更新内容 - * @param blockType 块类型 - * @param isComplete 是否完成,如果完成,则需要保存块更新到redux中 - */ - const smartBlockUpdate = ( - blockId: string, - changes: Partial, - blockType: MessageBlockType, - isComplete: boolean = false - ) => { - const isBlockTypeChanged = currentActiveBlockType !== null && currentActiveBlockType !== blockType - if (isBlockTypeChanged || isComplete) { - // 如果块类型改变,则取消上一个块的节流更新,并保存块更新到redux中(尽管有可能被上一个块本身的oncomplete事件的取消节流已经取消了) - if (isBlockTypeChanged && lastBlockId) { - cancelThrottledBlockUpdate(lastBlockId) - } - // 如果当前块完成,则取消当前块的节流更新,并保存块更新到redux中,避免streaming状态覆盖掉完成状态 - if (isComplete) { - cancelThrottledBlockUpdate(blockId) - } - dispatch(updateOneBlock({ id: blockId, changes })) - saveUpdatedBlockToDB(blockId, assistantMsgId, topicId, getState) - } else { - throttledBlockUpdate(blockId, changes) - } - - // 更新当前活跃块类型 - currentActiveBlockType = blockType - } - - const handleBlockTransition = async (newBlock: MessageBlock, newBlockType: MessageBlockType) => { - lastBlockId = newBlock.id - lastBlockType = newBlockType - if (newBlockType !== MessageBlockType.MAIN_TEXT) { - accumulatedContent = '' - } - if (newBlockType !== MessageBlockType.THINKING) { - accumulatedThinking = '' - } - dispatch( - newMessagesActions.updateMessage({ - topicId, - messageId: assistantMsgId, - updates: { blockInstruction: { id: newBlock.id } } - }) - ) - dispatch(upsertOneBlock(newBlock)) - dispatch( - newMessagesActions.upsertBlockReference({ - messageId: assistantMsgId, - blockId: newBlock.id, - status: newBlock.status - }) - ) - - const currentState = getState() - const updatedMessage = currentState.messages.entities[assistantMsgId] - if (updatedMessage) { - await saveUpdatesToDB(assistantMsgId, topicId, { blocks: updatedMessage.blocks }, [newBlock]) - } else { - console.error(`[handleBlockTransition] Failed to get updated message ${assistantMsgId} from state for DB save.`) - } - } - const allMessagesForTopic = selectMessagesForTopic(getState(), topicId) let messagesForContext: Message[] = [] @@ -430,467 +879,8 @@ const fetchAndProcessAssistantResponseImpl = async ( messagesForContext = contextSlice.filter((m) => m && !m.status?.includes('ing')) } - callbacks = { - onLLMResponseCreated: async () => { - const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, { - status: MessageBlockStatus.PROCESSING - }) - initialPlaceholderBlockId = baseBlock.id - await handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN) - }, - onTextStart: async () => { - if (initialPlaceholderBlockId) { - lastBlockType = MessageBlockType.MAIN_TEXT - const changes = { - type: MessageBlockType.MAIN_TEXT, - content: accumulatedContent, - status: MessageBlockStatus.STREAMING - } - smartBlockUpdate(initialPlaceholderBlockId, changes, MessageBlockType.MAIN_TEXT, true) - mainTextBlockId = initialPlaceholderBlockId - initialPlaceholderBlockId = null - } else if (!mainTextBlockId) { - const newBlock = createMainTextBlock(assistantMsgId, accumulatedContent, { - status: MessageBlockStatus.STREAMING - }) - mainTextBlockId = newBlock.id - await handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT) - } - }, - onTextChunk: async (text) => { - const citationBlockSource = citationBlockId - ? (getState().messageBlocks.entities[citationBlockId] as CitationMessageBlock).response?.source - : WebSearchSource.WEBSEARCH - accumulatedContent += text - if (mainTextBlockId) { - const blockChanges: Partial = { - content: accumulatedContent, - status: MessageBlockStatus.STREAMING, - citationReferences: citationBlockId ? [{ citationBlockId, citationBlockSource }] : [] - } - smartBlockUpdate(mainTextBlockId, blockChanges, MessageBlockType.MAIN_TEXT) - } - }, - onTextComplete: async (finalText) => { - if (mainTextBlockId) { - const changes = { - content: finalText, - status: MessageBlockStatus.SUCCESS - } - smartBlockUpdate(mainTextBlockId, changes, MessageBlockType.MAIN_TEXT, true) - mainTextBlockId = null - } else { - console.warn( - `[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${lastBlockType}) or lastBlockId is null.` - ) - } - }, - onThinkingStart: async () => { - if (initialPlaceholderBlockId) { - lastBlockType = MessageBlockType.THINKING - const changes = { - type: MessageBlockType.THINKING, - content: accumulatedThinking, - status: MessageBlockStatus.STREAMING, - thinking_millsec: 0 - } - thinkingBlockId = initialPlaceholderBlockId - initialPlaceholderBlockId = null - smartBlockUpdate(thinkingBlockId, changes, MessageBlockType.THINKING, true) - } else if (!thinkingBlockId) { - const newBlock = createThinkingBlock(assistantMsgId, accumulatedThinking, { - status: MessageBlockStatus.STREAMING, - thinking_millsec: 0 - }) - thinkingBlockId = newBlock.id - await handleBlockTransition(newBlock, MessageBlockType.THINKING) - } - }, - onThinkingChunk: async (text, thinking_millsec) => { - accumulatedThinking += text - if (thinkingBlockId) { - const blockChanges: Partial = { - content: accumulatedThinking, - status: MessageBlockStatus.STREAMING, - thinking_millsec: thinking_millsec - } - smartBlockUpdate(thinkingBlockId, blockChanges, MessageBlockType.THINKING) - } - }, - onThinkingComplete: (finalText, final_thinking_millsec) => { - if (thinkingBlockId) { - const changes = { - type: MessageBlockType.THINKING, - content: finalText, - status: MessageBlockStatus.SUCCESS, - thinking_millsec: final_thinking_millsec - } - smartBlockUpdate(thinkingBlockId, changes, MessageBlockType.THINKING, true) - } else { - console.warn( - `[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${lastBlockType}) or lastBlockId is null.` - ) - } - thinkingBlockId = null - }, - onToolCallPending: (toolResponse: MCPToolResponse) => { - if (initialPlaceholderBlockId) { - lastBlockType = MessageBlockType.TOOL - const changes = { - type: MessageBlockType.TOOL, - status: MessageBlockStatus.PENDING, - toolName: toolResponse.tool.name, - metadata: { rawMcpToolResponse: toolResponse } - } - toolBlockId = initialPlaceholderBlockId - initialPlaceholderBlockId = null - smartBlockUpdate(toolBlockId, changes, MessageBlockType.TOOL) - toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId) - } else if (toolResponse.status === 'pending') { - const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, { - toolName: toolResponse.tool.name, - status: MessageBlockStatus.PENDING, - metadata: { rawMcpToolResponse: toolResponse } - }) - toolBlockId = toolBlock.id - handleBlockTransition(toolBlock, MessageBlockType.TOOL) - toolCallIdToBlockIdMap.set(toolResponse.id, toolBlock.id) - } else { - console.warn( - `[onToolCallPending] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}` - ) - } - }, - onToolCallInProgress: (toolResponse: MCPToolResponse) => { - // 根据 toolResponse.id 查找对应的块ID - const targetBlockId = toolCallIdToBlockIdMap.get(toolResponse.id) - - if (targetBlockId && toolResponse.status === 'invoking') { - const changes = { - status: MessageBlockStatus.PROCESSING, - metadata: { rawMcpToolResponse: toolResponse } - } - smartBlockUpdate(targetBlockId, changes, MessageBlockType.TOOL) - } else if (!targetBlockId) { - console.warn( - `[onToolCallInProgress] No block ID found for tool ID: ${toolResponse.id}. Available mappings:`, - Array.from(toolCallIdToBlockIdMap.entries()) - ) - } else { - console.warn( - `[onToolCallInProgress] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}` - ) - } - }, - onToolCallComplete: (toolResponse: MCPToolResponse) => { - const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id) - toolCallIdToBlockIdMap.delete(toolResponse.id) - if (toolResponse.status === 'done' || toolResponse.status === 'error' || toolResponse.status === 'cancelled') { - if (!existingBlockId) { - console.error( - `[onToolCallComplete] No existing block found for completed/error tool call ID: ${toolResponse.id}. Cannot update.` - ) - return - } - const finalStatus = - toolResponse.status === 'done' || toolResponse.status === 'cancelled' - ? MessageBlockStatus.SUCCESS - : MessageBlockStatus.ERROR - const changes: Partial = { - content: toolResponse.response, - status: finalStatus, - metadata: { rawMcpToolResponse: toolResponse } - } - if (finalStatus === MessageBlockStatus.ERROR) { - changes.error = { message: `Tool execution failed/error`, details: toolResponse.response } - } - smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true) - } else { - console.warn( - `[onToolCallComplete] Received unhandled tool status: ${toolResponse.status} for ID: ${toolResponse.id}` - ) - } - toolBlockId = null - }, - onExternalToolInProgress: async () => { - const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING }) - citationBlockId = citationBlock.id - await handleBlockTransition(citationBlock, MessageBlockType.CITATION) - // saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState) - }, - onExternalToolComplete: (externalToolResult: ExternalToolResult) => { - if (citationBlockId) { - const changes: Partial = { - response: externalToolResult.webSearch, - knowledge: externalToolResult.knowledge, - status: MessageBlockStatus.SUCCESS - } - smartBlockUpdate(citationBlockId, changes, MessageBlockType.CITATION, true) - } else { - console.error('[onExternalToolComplete] citationBlockId is null. Cannot update.') - } - }, - onLLMWebSearchInProgress: async () => { - if (initialPlaceholderBlockId) { - lastBlockType = MessageBlockType.CITATION - citationBlockId = initialPlaceholderBlockId - const changes = { - type: MessageBlockType.CITATION, - status: MessageBlockStatus.PROCESSING - } - lastBlockType = MessageBlockType.CITATION - smartBlockUpdate(initialPlaceholderBlockId, changes, MessageBlockType.CITATION) - initialPlaceholderBlockId = null - } else { - const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING }) - citationBlockId = citationBlock.id - await handleBlockTransition(citationBlock, MessageBlockType.CITATION) - } - }, - onLLMWebSearchComplete: async (llmWebSearchResult) => { - const blockId = citationBlockId || initialPlaceholderBlockId - if (blockId) { - const changes: Partial = { - type: MessageBlockType.CITATION, - response: llmWebSearchResult, - status: MessageBlockStatus.SUCCESS - } - smartBlockUpdate(blockId, changes, MessageBlockType.CITATION) - - const state = getState() - const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId]) - if (existingMainTextBlocks.length > 0) { - const existingMainTextBlock = existingMainTextBlocks[0] - const currentRefs = existingMainTextBlock.citationReferences || [] - const mainTextChanges = { - citationReferences: [...currentRefs, { blockId, citationBlockSource: llmWebSearchResult.source }] - } - smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true) - } - - if (initialPlaceholderBlockId) { - citationBlockId = initialPlaceholderBlockId - initialPlaceholderBlockId = null - } - } else { - const citationBlock = createCitationBlock( - assistantMsgId, - { - response: llmWebSearchResult - }, - { - status: MessageBlockStatus.SUCCESS - } - ) - citationBlockId = citationBlock.id - const state = getState() - const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId]) - if (existingMainTextBlocks.length > 0) { - const existingMainTextBlock = existingMainTextBlocks[0] - const currentRefs = existingMainTextBlock.citationReferences || [] - const mainTextChanges = { - citationReferences: [...currentRefs, { citationBlockId, citationBlockSource: llmWebSearchResult.source }] - } - smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true) - } - await handleBlockTransition(citationBlock, MessageBlockType.CITATION) - } - }, - onImageCreated: async () => { - if (initialPlaceholderBlockId) { - lastBlockType = MessageBlockType.IMAGE - const initialChanges: Partial = { - type: MessageBlockType.IMAGE, - status: MessageBlockStatus.PENDING - } - lastBlockType = MessageBlockType.IMAGE - imageBlockId = initialPlaceholderBlockId - initialPlaceholderBlockId = null - smartBlockUpdate(imageBlockId, initialChanges, MessageBlockType.IMAGE) - } else if (!imageBlockId) { - const imageBlock = createImageBlock(assistantMsgId, { - status: MessageBlockStatus.PENDING - }) - imageBlockId = imageBlock.id - await handleBlockTransition(imageBlock, MessageBlockType.IMAGE) - } - }, - onImageDelta: (imageData) => { - const imageUrl = imageData.images?.[0] || 'placeholder_image_url' - if (imageBlockId) { - const changes: Partial = { - url: imageUrl, - metadata: { generateImageResponse: imageData }, - status: MessageBlockStatus.STREAMING - } - smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE, true) - } - }, - onImageGenerated: (imageData) => { - if (imageBlockId) { - if (!imageData) { - const changes: Partial = { - status: MessageBlockStatus.SUCCESS - } - smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE) - } else { - const imageUrl = imageData.images?.[0] || 'placeholder_image_url' - const changes: Partial = { - url: imageUrl, - metadata: { generateImageResponse: imageData }, - status: MessageBlockStatus.SUCCESS - } - smartBlockUpdate(imageBlockId, changes, MessageBlockType.IMAGE, true) - } - } else { - console.error('[onImageGenerated] Last block was not an Image block or ID is missing.') - } - imageBlockId = null - }, - onError: async (error) => { - console.dir(error, { depth: null }) - const isErrorTypeAbort = isAbortError(error) - let pauseErrorLanguagePlaceholder = '' - if (isErrorTypeAbort) { - pauseErrorLanguagePlaceholder = 'pause_placeholder' - } - - const serializableError = { - name: error.name, - message: pauseErrorLanguagePlaceholder || error.message || formatErrorMessage(error), - originalMessage: error.message, - stack: error.stack, - status: error.status || error.code, - requestId: error.request_id - } - if (!isOnHomePage()) { - await notificationService.send({ - id: uuid(), - type: 'error', - title: t('notification.assistant'), - message: serializableError.message, - silent: false, - timestamp: Date.now(), - source: 'assistant' - }) - } - const possibleBlockId = - mainTextBlockId || - thinkingBlockId || - toolBlockId || - imageBlockId || - citationBlockId || - initialPlaceholderBlockId || - lastBlockId - - if (possibleBlockId) { - // 更改上一个block的状态为ERROR - const changes: Partial = { - status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR - } - smartBlockUpdate(possibleBlockId, changes, lastBlockType!, true) - } - - const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS }) - await handleBlockTransition(errorBlock, MessageBlockType.ERROR) - const messageErrorUpdate = { - status: isErrorTypeAbort ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR - } - dispatch(newMessagesActions.updateMessage({ topicId, messageId: assistantMsgId, updates: messageErrorUpdate })) - - saveUpdatesToDB(assistantMsgId, topicId, messageErrorUpdate, []) - - EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, { - id: assistantMsgId, - topicId, - status: isErrorTypeAbort ? 'pause' : 'error', - error: error.message - }) - }, - onComplete: async (status: AssistantMessageStatus, response?: Response) => { - const finalStateOnComplete = getState() - const finalAssistantMsg = finalStateOnComplete.messages.entities[assistantMsgId] - - if (status === 'success' && finalAssistantMsg) { - const userMsgId = finalAssistantMsg.askId - const orderedMsgs = selectMessagesForTopic(finalStateOnComplete, topicId) - const userMsgIndex = orderedMsgs.findIndex((m) => m.id === userMsgId) - const contextForUsage = userMsgIndex !== -1 ? orderedMsgs.slice(0, userMsgIndex + 1) : [] - const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg] - - const possibleBlockId = - mainTextBlockId || - thinkingBlockId || - toolBlockId || - imageBlockId || - citationBlockId || - initialPlaceholderBlockId || - lastBlockId - if (possibleBlockId) { - const changes: Partial = { - status: MessageBlockStatus.SUCCESS - } - smartBlockUpdate(possibleBlockId, changes, lastBlockType!, true) - } - - const endTime = Date.now() - const duration = endTime - startTime - const content = getMainTextContent(finalAssistantMsg) - if (!isOnHomePage() && duration > 60 * 1000) { - await notificationService.send({ - id: uuid(), - type: 'success', - title: t('notification.assistant'), - message: content.length > 50 ? content.slice(0, 47) + '...' : content, - silent: false, - timestamp: Date.now(), - source: 'assistant' - }) - } - - // 更新topic的name - autoRenameTopic(assistant, topicId) - - if ( - response && - (response.usage?.total_tokens === 0 || - response?.usage?.prompt_tokens === 0 || - response?.usage?.completion_tokens === 0) - ) { - const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant }) - response.usage = usage - } - // dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false })) - } - if (response && response.metrics) { - if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) { - response = { - ...response, - metrics: { - ...response.metrics, - completion_tokens: response.usage.completion_tokens - } - } - } - } - - const messageUpdates: Partial = { status, metrics: response?.metrics, usage: response?.usage } - dispatch( - newMessagesActions.updateMessage({ - topicId, - messageId: assistantMsgId, - updates: messageUpdates - }) - ) - saveUpdatesToDB(assistantMsgId, topicId, messageUpdates, []) - - EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, { id: assistantMsgId, topicId, status }) - } - } - const streamProcessorCallbacks = createStreamProcessor(callbacks) - const startTime = Date.now() await fetchChatCompletion({ messages: messagesForContext, assistant: assistant,