mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-10 23:59:45 +08:00
fix: streamCallback.integration.test.ts test logic
- Refined `streamCallback.integration.test.ts` to streamline mock setups and enhance clarity in testing logic, including the addition of utility functions for handling persisted data. - Improved the organization of mock services and their integration into tests, ensuring better maintainability and readability of test cases. - Enhanced comments and documentation within tests to provide clearer guidance on the purpose and functionality of various mock utilities.
This commit is contained in:
parent
11843e21d5
commit
60204e2166
@ -84,9 +84,9 @@ vi.mock('electron-updater', () => ({
|
||||
// Import after mocks
|
||||
import { preferenceService } from '@data/PreferenceService'
|
||||
import { UpdateMirror } from '@shared/config/constant'
|
||||
import { MockMainPreferenceServiceUtils } from '@test-mocks/main/PreferenceService'
|
||||
import { app, net } from 'electron'
|
||||
|
||||
import { MockMainPreferenceServiceUtils } from '@test-mocks/main/PreferenceService'
|
||||
import AppUpdater from '../AppUpdater'
|
||||
|
||||
// Mock clientId for ConfigManager since it's not migrated yet
|
||||
|
||||
@ -9,7 +9,9 @@ import type { Assistant, ExternalToolResult, MCPTool, Model } from '@renderer/ty
|
||||
import { WebSearchSource } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
import { AssistantMessageStatus } from '@renderer/types/newMessage'
|
||||
import { MockCacheUtils } from '@test-mocks/renderer/CacheService'
|
||||
import { MockDataApiUtils } from '@test-mocks/renderer/DataApiService'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
/**
|
||||
@ -49,67 +51,33 @@ const createMockCallbacks = (
|
||||
}
|
||||
|
||||
// Mock external dependencies
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
SYSTEM_MODELS: {
|
||||
defaultModel: [{}, {}, {}],
|
||||
silicon: [],
|
||||
aihubmix: [],
|
||||
ocoolai: [],
|
||||
deepseek: [],
|
||||
ppio: [],
|
||||
alayanew: [],
|
||||
qiniu: [],
|
||||
dmxapi: [],
|
||||
burncloud: [],
|
||||
tokenflux: [],
|
||||
'302ai': [],
|
||||
cephalon: [],
|
||||
lanyun: [],
|
||||
ph8: [],
|
||||
openrouter: [],
|
||||
ollama: [],
|
||||
'new-api': [],
|
||||
lmstudio: [],
|
||||
anthropic: [],
|
||||
openai: [],
|
||||
'azure-openai': [],
|
||||
gemini: [],
|
||||
vertexai: [],
|
||||
github: [],
|
||||
copilot: [],
|
||||
zhipu: [],
|
||||
yi: [],
|
||||
moonshot: [],
|
||||
baichuan: [],
|
||||
dashscope: [],
|
||||
stepfun: [],
|
||||
doubao: [],
|
||||
infini: [],
|
||||
minimax: [],
|
||||
groq: [],
|
||||
together: [],
|
||||
fireworks: [],
|
||||
nvidia: [],
|
||||
grok: [],
|
||||
hyperbolic: [],
|
||||
mistral: [],
|
||||
jina: [],
|
||||
perplexity: [],
|
||||
modelscope: [],
|
||||
xirang: [],
|
||||
hunyuan: [],
|
||||
'tencent-cloud-ti': [],
|
||||
'baidu-cloud': [],
|
||||
gpustack: [],
|
||||
voyageai: []
|
||||
},
|
||||
getModelLogo: vi.fn(),
|
||||
isVisionModel: vi.fn(() => false),
|
||||
isFunctionCallingModel: vi.fn(() => false),
|
||||
isEmbeddingModel: vi.fn(() => false),
|
||||
isReasoningModel: vi.fn(() => false)
|
||||
// ... 其他需要用到的函数也可以在这里 mock
|
||||
}))
|
||||
// NOTE: CacheService and DataApiService are globally mocked in tests/renderer.setup.ts
|
||||
// Use MockCacheUtils and MockDataApiUtils for testing utilities
|
||||
|
||||
/**
|
||||
* Helper function to get persisted data from mock DataApiService calls
|
||||
* Finds the PATCH call for a specific message path and returns the body
|
||||
*/
|
||||
const getPersistedDataForMessage = (messageId: string) => {
|
||||
const patchCalls = MockDataApiUtils.getCalls('patch')
|
||||
// Find the last call for this message (most recent state)
|
||||
const matchingCalls = patchCalls.filter(([path]: [string]) => path === `/messages/${messageId}`)
|
||||
if (matchingCalls.length === 0) return undefined
|
||||
const lastCall = matchingCalls[matchingCalls.length - 1]
|
||||
return lastCall[1]?.body
|
||||
}
|
||||
|
||||
vi.mock('@renderer/config/models', async (importOriginal) => {
|
||||
const actual = (await importOriginal()) as Record<string, unknown>
|
||||
return {
|
||||
...actual,
|
||||
// Override functions that need mocking for tests
|
||||
isVisionModel: vi.fn(() => false),
|
||||
isFunctionCallingModel: vi.fn(() => false),
|
||||
isEmbeddingModel: vi.fn(() => false),
|
||||
isReasoningModel: vi.fn(() => false)
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@renderer/databases', () => ({
|
||||
default: {
|
||||
@ -167,12 +135,41 @@ vi.mock('@renderer/services/NotificationService', () => ({
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/db/DbService', () => ({
|
||||
DbService: {
|
||||
getInstance: vi.fn(() => ({
|
||||
createMessage: vi.fn(),
|
||||
updateMessage: vi.fn(),
|
||||
deleteMessage: vi.fn(),
|
||||
createBlock: vi.fn(),
|
||||
updateBlock: vi.fn(),
|
||||
deleteBlock: vi.fn(),
|
||||
createBlocks: vi.fn(),
|
||||
getMessageById: vi.fn(),
|
||||
getBlocksByMessageId: vi.fn()
|
||||
}))
|
||||
},
|
||||
dbService: {
|
||||
createMessage: vi.fn(),
|
||||
updateMessage: vi.fn(),
|
||||
deleteMessage: vi.fn(),
|
||||
createBlock: vi.fn(),
|
||||
updateBlock: vi.fn(),
|
||||
deleteBlock: vi.fn(),
|
||||
createBlocks: vi.fn(),
|
||||
getMessageById: vi.fn(),
|
||||
getBlocksByMessageId: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/EventService', () => ({
|
||||
EventEmitter: {
|
||||
emit: vi.fn()
|
||||
emit: vi.fn(),
|
||||
on: vi.fn()
|
||||
},
|
||||
EVENT_NAMES: {
|
||||
MESSAGE_COMPLETE: 'MESSAGE_COMPLETE'
|
||||
MESSAGE_COMPLETE: 'MESSAGE_COMPLETE',
|
||||
SEND_MESSAGE: 'SEND_MESSAGE'
|
||||
}
|
||||
}))
|
||||
|
||||
@ -340,6 +337,8 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
MockCacheUtils.resetMocks()
|
||||
MockDataApiUtils.resetMocks()
|
||||
store = createMockStore()
|
||||
|
||||
// Add initial message state for tests
|
||||
@ -391,20 +390,25 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
// 验证持久化数据 (v2架构通过DataApiService持久化)
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
status?: string
|
||||
stats?: { totalTokens?: number }
|
||||
data?: { blocks?: Array<{ type: string; content?: string }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
// 验证blocks (data.blocks 格式)
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
expect(blocks.length).toBeGreaterThan(0)
|
||||
|
||||
const textBlock = blocks.find((block) => block.type === MessageBlockType.MAIN_TEXT)
|
||||
const textBlock = blocks.find((block) => block.type === 'main_text')
|
||||
expect(textBlock).toBeDefined()
|
||||
expect(textBlock?.content).toBe('Hello world!')
|
||||
expect(textBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
|
||||
// 验证消息状态更新
|
||||
const message = state.messages.entities[mockAssistantMsgId]
|
||||
expect(message?.status).toBe(AssistantMessageStatus.SUCCESS)
|
||||
expect(message?.usage?.total_tokens).toBe(150)
|
||||
expect(persistedData?.status).toBe('success')
|
||||
expect(persistedData?.stats?.totalTokens).toBe(150)
|
||||
})
|
||||
|
||||
it('should handle thinking flow', async () => {
|
||||
@ -422,18 +426,20 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
// 验证持久化数据 (v2架构通过DataApiService持久化)
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
data?: { blocks?: Array<{ type: string; content?: string; thinking_millsec?: number }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
const thinkingBlock = blocks.find((block) => block.type === MessageBlockType.THINKING)
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
const thinkingBlock = blocks.find((block) => block.type === 'thinking')
|
||||
expect(thinkingBlock).toBeDefined()
|
||||
expect(thinkingBlock?.content).toBe('Final thoughts')
|
||||
expect(thinkingBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
// thinking_millsec 现在是本地计算的,只验证它存在且是一个合理的数字
|
||||
expect((thinkingBlock as any)?.thinking_millsec).toBeDefined()
|
||||
expect(typeof (thinkingBlock as any)?.thinking_millsec).toBe('number')
|
||||
expect((thinkingBlock as any)?.thinking_millsec).toBeGreaterThanOrEqual(0)
|
||||
expect(thinkingBlock?.thinking_millsec).toBeDefined()
|
||||
expect(typeof thinkingBlock?.thinking_millsec).toBe('number')
|
||||
expect(thinkingBlock?.thinking_millsec).toBeGreaterThanOrEqual(0)
|
||||
})
|
||||
|
||||
it('should handle tool call flow', async () => {
|
||||
@ -496,15 +502,17 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
// 验证持久化数据
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
data?: { blocks?: Array<{ type: string; content?: string; toolName?: string }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
const toolBlock = blocks.find((block) => block.type === MessageBlockType.TOOL)
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
const toolBlock = blocks.find((block) => block.type === 'tool')
|
||||
expect(toolBlock).toBeDefined()
|
||||
expect(toolBlock?.content).toBe('Tool result')
|
||||
expect(toolBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
expect((toolBlock as any)?.toolName).toBe('test-tool')
|
||||
expect(toolBlock?.toolName).toBe('test-tool')
|
||||
})
|
||||
|
||||
it('should handle image generation flow', async () => {
|
||||
@ -536,15 +544,18 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
const imageBlock = blocks.find((block) => block.type === MessageBlockType.IMAGE)
|
||||
// 验证持久化数据
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
data?: { blocks?: Array<{ type: string; url?: string }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
const imageBlock = blocks.find((block) => block.type === 'image')
|
||||
expect(imageBlock).toBeDefined()
|
||||
expect(imageBlock?.url).toBe(
|
||||
'data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAAQABADASIAAhEBAxEB/8QAFwAAAwEAAAAAAAAAAAAAAAAAAQMEB//EACMQAAIBAwMEAwAAAAAAAAAAAAECAwAEEQUSIQYxQVExUYH/xAAVAQEBAAAAAAAAAAAAAAAAAAAAAf/EABQRAQAAAAAAAAAAAAAAAAAAAAD/2gAMAwEAAhEDEQA/AM/8A//Z'
|
||||
)
|
||||
expect(imageBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
})
|
||||
|
||||
it('should handle web search flow', async () => {
|
||||
@ -564,13 +575,16 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
const citationBlock = blocks.find((block) => block.type === MessageBlockType.CITATION)
|
||||
// 验证持久化数据
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
data?: { blocks?: Array<{ type: string; response?: { source?: string } }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
const citationBlock = blocks.find((block) => block.type === 'citation')
|
||||
expect(citationBlock).toBeDefined()
|
||||
expect(citationBlock?.response?.source).toEqual(mockWebSearchResult.source)
|
||||
expect(citationBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
})
|
||||
|
||||
it('should handle mixed content flow (thinking + tool + text)', async () => {
|
||||
@ -656,23 +670,23 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
// 验证持久化数据
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
data?: { blocks?: Array<{ type: string; content?: string }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
expect(blocks.length).toBeGreaterThan(2) // 至少有思考块、工具块、文本块
|
||||
|
||||
const thinkingBlock = blocks.find((block) => block.type === MessageBlockType.THINKING)
|
||||
const thinkingBlock = blocks.find((block) => block.type === 'thinking')
|
||||
expect(thinkingBlock?.content).toBe('Let me calculate this..., I need to use a calculator')
|
||||
expect(thinkingBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
|
||||
const toolBlock = blocks.find((block) => block.type === MessageBlockType.TOOL)
|
||||
const toolBlock = blocks.find((block) => block.type === 'tool')
|
||||
expect(toolBlock?.content).toBe('42')
|
||||
expect(toolBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
|
||||
const textBlock = blocks.find((block) => block.type === MessageBlockType.MAIN_TEXT)
|
||||
const textBlock = blocks.find((block) => block.type === 'main_text')
|
||||
expect(textBlock?.content).toBe('The answer is 42')
|
||||
expect(textBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
})
|
||||
|
||||
it('should handle error flow', async () => {
|
||||
@ -689,20 +703,22 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
// 验证持久化数据
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
status?: string
|
||||
data?: { blocks?: Array<{ type: string; error?: { message: string } }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
expect(blocks.length).toBeGreaterThan(0)
|
||||
|
||||
const errorBlock = blocks.find((block) => block.type === MessageBlockType.ERROR)
|
||||
const errorBlock = blocks.find((block) => block.type === 'error')
|
||||
expect(errorBlock).toBeDefined()
|
||||
expect(errorBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
expect((errorBlock as any)?.error?.message).toBe('Test error')
|
||||
expect(errorBlock?.error?.message).toBe('Test error')
|
||||
|
||||
// 验证消息状态更新
|
||||
const message = state.messages.entities[mockAssistantMsgId]
|
||||
expect(message?.status).toBe(AssistantMessageStatus.ERROR)
|
||||
expect(persistedData?.status).toBe('error')
|
||||
})
|
||||
|
||||
it('should handle external tool flow', async () => {
|
||||
@ -732,15 +748,17 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
// 验证持久化数据
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
data?: { blocks?: Array<{ type: string; response?: unknown; knowledge?: unknown }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
const citationBlock = blocks.find((block) => block.type === MessageBlockType.CITATION)
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
const citationBlock = blocks.find((block) => block.type === 'citation')
|
||||
expect(citationBlock).toBeDefined()
|
||||
expect((citationBlock as any)?.response).toEqual(mockExternalToolResult.webSearch)
|
||||
expect((citationBlock as any)?.knowledge).toEqual(mockExternalToolResult.knowledge)
|
||||
expect(citationBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
expect(citationBlock?.response).toEqual(mockExternalToolResult.webSearch)
|
||||
expect(citationBlock?.knowledge).toEqual(mockExternalToolResult.knowledge)
|
||||
})
|
||||
|
||||
it('should handle abort error correctly', async () => {
|
||||
@ -759,19 +777,21 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
// 验证持久化数据
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
status?: string
|
||||
data?: { blocks?: Array<{ type: string }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
expect(blocks.length).toBeGreaterThan(0)
|
||||
|
||||
const errorBlock = blocks.find((block) => block.type === MessageBlockType.ERROR)
|
||||
const errorBlock = blocks.find((block) => block.type === 'error')
|
||||
expect(errorBlock).toBeDefined()
|
||||
expect(errorBlock?.status).toBe(MessageBlockStatus.SUCCESS)
|
||||
|
||||
// 验证消息状态更新为成功(因为是暂停,不是真正的错误)
|
||||
const message = state.messages.entities[mockAssistantMsgId]
|
||||
expect(message?.status).toBe(AssistantMessageStatus.SUCCESS)
|
||||
expect(persistedData?.status).toBe('success')
|
||||
})
|
||||
|
||||
it('should maintain block reference integrity during streaming', async () => {
|
||||
@ -788,23 +808,20 @@ describe('streamCallback Integration Tests', () => {
|
||||
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
const message = state.messages.entities[mockAssistantMsgId]
|
||||
// 验证持久化数据
|
||||
const persistedData = getPersistedDataForMessage(mockAssistantMsgId) as {
|
||||
data?: { blocks?: Array<{ type: string; content?: string }> }
|
||||
}
|
||||
expect(persistedData).toBeDefined()
|
||||
|
||||
// 验证消息的 blocks 数组包含正确的块ID
|
||||
expect(message?.blocks).toBeDefined()
|
||||
expect(message?.blocks?.length).toBeGreaterThan(0)
|
||||
|
||||
// 验证所有块都存在于 messageBlocks 状态中
|
||||
message?.blocks?.forEach((blockId) => {
|
||||
const block = state.messageBlocks.entities[blockId]
|
||||
expect(block).toBeDefined()
|
||||
expect(block?.messageId).toBe(mockAssistantMsgId)
|
||||
})
|
||||
const blocks = persistedData?.data?.blocks || []
|
||||
|
||||
// 验证blocks包含正确的内容
|
||||
expect(blocks.length).toBeGreaterThan(0)
|
||||
|
||||
// 验证有main_text block
|
||||
const textBlock = blocks.find((block) => block.type === 'main_text')
|
||||
expect(textBlock).toBeDefined()
|
||||
expect(textBlock?.content).toBe('First chunkSecond chunk')
|
||||
})
|
||||
})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user