diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts index 68f3ab662f..3477e2fe59 100644 --- a/src/renderer/src/aiCore/index.ts +++ b/src/renderer/src/aiCore/index.ts @@ -123,6 +123,8 @@ export default class AiProvider { logger.silly('ErrorHandlerMiddleware is removed') builder.remove(FinalChunkConsumerMiddlewareName) logger.silly('FinalChunkConsumerMiddleware is removed') + builder.insertBefore(ThinkChunkMiddlewareName, MiddlewareRegistry[ThinkingTagExtractionMiddlewareName]) + logger.silly('ThinkingTagExtractionMiddleware is inserted') } } diff --git a/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts b/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts index 9d7baccb8f..2082d0702c 100644 --- a/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts +++ b/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts @@ -84,7 +84,7 @@ export const ThinkingTagExtractionMiddleware: CompletionsMiddleware = // 生成 THINKING_COMPLETE 事件 const thinkingCompleteChunk: ThinkingCompleteChunk = { type: ChunkType.THINKING_COMPLETE, - text: extractionResult.tagContentExtracted, + text: extractionResult.tagContentExtracted.trim(), thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0 } controller.enqueue(thinkingCompleteChunk) @@ -104,7 +104,7 @@ export const ThinkingTagExtractionMiddleware: CompletionsMiddleware = } if (extractionResult.content?.trim()) { - accumulatedThinkingContent += extractionResult.content + accumulatedThinkingContent += extractionResult.content.trim() const thinkingDeltaChunk: ThinkingDeltaChunk = { type: ChunkType.THINKING_DELTA, text: accumulatedThinkingContent, diff --git a/src/renderer/src/services/__tests__/ApiService.test.ts b/src/renderer/src/services/__tests__/ApiService.test.ts index b4f43e0cd8..09d2aa7fbc 100644 --- a/src/renderer/src/services/__tests__/ApiService.test.ts +++ b/src/renderer/src/services/__tests__/ApiService.test.ts @@ -1,7 +1,16 @@ +import { ToolUseBlock } from '@anthropic-ai/sdk/resources' +import { + TextBlock, + TextDelta, + Usage, + WebSearchResultBlock, + WebSearchToolResultError +} from '@anthropic-ai/sdk/resources/messages' import { FinishReason, MediaModality } from '@google/genai' import { FunctionCall } from '@google/genai' import AiProvider from '@renderer/aiCore' import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/clients' +import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient' import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient' @@ -18,6 +27,7 @@ import { ThinkingStartChunk } from '@renderer/types/chunk' import { + AnthropicSdkRawChunk, GeminiSdkMessageParam, GeminiSdkRawChunk, GeminiSdkToolCall, @@ -760,6 +770,193 @@ const openaiCompletionChunks: OpenAISdkRawChunk[] = [ } ] +const openaiNeedExtractContentChunks: OpenAISdkRawChunk[] = [ + { + id: 'cmpl-123', + created: 1715811200, + model: 'gpt-4o', + object: 'chat.completion.chunk', + choices: [ + { + delta: { + content: null, + role: 'assistant', + reasoning_content: null + } as ChatCompletionChunk.Choice.Delta, + index: 0, + logprobs: null, + finish_reason: null + } + ] + }, + { + id: 'cmpl-123', + created: 1715811200, + model: 'gpt-4o', + object: 'chat.completion.chunk', + choices: [ + { + delta: { + content: '', + role: 'assistant', + reasoning_content: null + } as ChatCompletionChunk.Choice.Delta, + index: 0, + logprobs: null, + finish_reason: null + } + ] + }, + { + id: 'cmpl-123', + created: 1715811200, + model: 'gpt-4o', + object: 'chat.completion.chunk', + choices: [ + { + delta: { + content: '\n好的,用户发来“你好”,我需要友好回应\n', + role: 'assistant', + reasoning_content: null + } as ChatCompletionChunk.Choice.Delta, + index: 0, + logprobs: null, + finish_reason: null + } + ] + }, + { + id: 'cmpl-123', + created: 1715811200, + model: 'gpt-4o', + object: 'chat.completion.chunk', + choices: [ + { + delta: { + content: '你好!有什么我可以帮您的吗?', + role: 'assistant', + reasoning_content: null + } as ChatCompletionChunk.Choice.Delta, + index: 0, + logprobs: null, + finish_reason: null + } + ] + }, + { + id: 'cmpl-123', + created: 1715811200, + model: 'gpt-4o', + object: 'chat.completion.chunk', + choices: [ + { + delta: {} as ChatCompletionChunk.Choice.Delta, + index: 0, + logprobs: null, + finish_reason: 'stop' + } + ] + } +] + +const anthropicTextNonStreamChunks: AnthropicSdkRawChunk[] = [ + { + id: 'msg_bdrk_01HctMh5mCpuFRq49KFwTDU6', + type: 'message', + role: 'assistant', + content: [ + { + type: 'text', + text: '你好!有什么我可以帮助你的吗?' + } + ], + model: 'claude-3-7-sonnet-20250219', + stop_reason: 'end_turn', + usage: { + input_tokens: 15, + output_tokens: 21 + } + } as AnthropicSdkRawChunk +] + +const anthropicTextStreamChunks: AnthropicSdkRawChunk[] = [ + { + type: 'message_start', + message: { + id: 'msg_bdrk_013fneHZaGWgKFBzesGM4wu5', + type: 'message', + role: 'assistant', + model: 'claude-3-5-sonnet-20241022', + content: [], + stop_reason: null, + stop_sequence: null, + usage: { + input_tokens: 10, + output_tokens: 2 + } as Usage + } + }, + { + type: 'content_block_start', + index: 0, + content_block: { + type: 'text', + text: '' + } as TextBlock + }, + { + type: 'content_block_delta', + index: 0, + delta: { + type: 'text_delta', + text: '你好!很高兴见到你。有' + } as TextDelta + }, + { + type: 'content_block_delta', + index: 0, + delta: { + type: 'text_delta', + text: '什么我可以帮助你的吗?' + } as TextDelta + }, + { + type: 'content_block_stop', + index: 0 + }, + { + type: 'message_delta', + delta: { + stop_reason: 'end_turn', + stop_sequence: null + }, + usage: { + output_tokens: 28 + } as Usage + }, + { + type: 'message_stop' + } +] + // 正确的 async generator 函数 async function* geminiChunkGenerator(): AsyncGenerator { for (const chunk of geminiChunks) { @@ -785,6 +982,12 @@ async function* openaiThinkingChunkGenerator(): AsyncGenerator { + for (const chunk of openaiNeedExtractContentChunks) { + yield chunk + } +} + const mockOpenaiApiClient = { createCompletions: vi.fn().mockImplementation(() => openaiThinkingChunkGenerator()), getResponseChunkTransformer: vi.fn().mockImplementation(() => { @@ -1051,6 +1254,23 @@ const mockOpenaiApiClient = { getClientCompatibilityType: vi.fn(() => ['OpenAIAPIClient']) } as unknown as OpenAIAPIClient +const mockOpenaiNeedExtractContentApiClient = cloneDeep(mockOpenaiApiClient) +mockOpenaiNeedExtractContentApiClient.createCompletions = vi + .fn() + .mockImplementation(() => openaiNeedExtractContentChunkGenerator()) + +async function* anthropicTextNonStreamChunkGenerator(): AsyncGenerator { + for (const chunk of anthropicTextNonStreamChunks) { + yield chunk + } +} + +async function* anthropicTextStreamChunkGenerator(): AsyncGenerator { + for (const chunk of anthropicTextStreamChunks) { + yield chunk + } +} + // 创建 mock 的 GeminiAPIClient const mockGeminiApiClient = { createCompletions: vi.fn().mockImplementation(() => geminiChunkGenerator()), @@ -1170,6 +1390,224 @@ const mockGeminiApiClient = { getClientCompatibilityType: vi.fn(() => ['GeminiAPIClient']) } as unknown as GeminiAPIClient +const mockAnthropicApiClient = { + createCompletions: vi.fn().mockImplementation(() => anthropicTextNonStreamChunkGenerator()), + getResponseChunkTransformer: vi.fn().mockImplementation(() => { + return () => { + let accumulatedJson = '' + const toolCalls: Record = {} + return { + async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController) { + switch (rawChunk.type) { + case 'message': { + let i = 0 + let hasTextContent = false + let hasThinkingContent = false + + for (const content of rawChunk.content) { + switch (content.type) { + case 'text': { + if (!hasTextContent) { + controller.enqueue({ + type: ChunkType.TEXT_START + } as TextStartChunk) + hasTextContent = true + } + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: content.text + } as TextDeltaChunk) + break + } + case 'tool_use': { + toolCalls[i] = content + i++ + break + } + case 'thinking': { + if (!hasThinkingContent) { + controller.enqueue({ + type: ChunkType.THINKING_START + }) + hasThinkingContent = true + } + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: content.thinking + }) + break + } + case 'web_search_tool_result': { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: content.content, + source: WebSearchSource.ANTHROPIC + } + } as LLMWebSearchCompleteChunk) + break + } + } + } + if (i > 0) { + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: Object.values(toolCalls) + }) + } + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: rawChunk.usage.input_tokens || 0, + completion_tokens: rawChunk.usage.output_tokens || 0, + total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0) + } + } + }) + break + } + case 'content_block_start': { + const contentBlock = rawChunk.content_block + switch (contentBlock.type) { + case 'server_tool_use': { + if (contentBlock.name === 'web_search') { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS + }) + } + break + } + case 'web_search_tool_result': { + if ( + contentBlock.content && + (contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error' + ) { + controller.enqueue({ + type: ChunkType.ERROR, + error: { + code: (contentBlock.content as WebSearchToolResultError).error_code, + message: (contentBlock.content as WebSearchToolResultError).error_code + } + }) + } else { + controller.enqueue({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: contentBlock.content as Array, + source: WebSearchSource.ANTHROPIC + } + }) + } + break + } + case 'tool_use': { + toolCalls[rawChunk.index] = contentBlock + break + } + case 'text': { + controller.enqueue({ + type: ChunkType.TEXT_START + } as TextStartChunk) + break + } + case 'thinking': + case 'redacted_thinking': { + controller.enqueue({ + type: ChunkType.THINKING_START + } as ThinkingStartChunk) + break + } + } + break + } + case 'content_block_delta': { + const messageDelta = rawChunk.delta + switch (messageDelta.type) { + case 'text_delta': { + if (messageDelta.text) { + controller.enqueue({ + type: ChunkType.TEXT_DELTA, + text: messageDelta.text + } as TextDeltaChunk) + } + break + } + case 'thinking_delta': { + if (messageDelta.thinking) { + controller.enqueue({ + type: ChunkType.THINKING_DELTA, + text: messageDelta.thinking + }) + } + break + } + case 'input_json_delta': { + if (messageDelta.partial_json) { + accumulatedJson += messageDelta.partial_json + } + break + } + } + break + } + case 'content_block_stop': { + const toolCall = toolCalls[rawChunk.index] + if (toolCall) { + try { + toolCall.input = JSON.parse(accumulatedJson) + controller.enqueue({ + type: ChunkType.MCP_TOOL_CREATED, + tool_calls: [toolCall] + }) + } catch (error) { + console.error(`Error parsing tool call input: ${error}`) + } + } + break + } + case 'message_delta': { + controller.enqueue({ + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + prompt_tokens: rawChunk.usage.input_tokens || 0, + completion_tokens: rawChunk.usage.output_tokens || 0, + total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0) + } + } + }) + } + } + } + } + } + }), + getRequestTransformer: vi.fn().mockImplementation(() => ({ + async transform(params: any) { + return { + payload: { + model: params.assistant?.model?.id || 'claude-3-7-sonnet-20250219', + messages: params.messages || [], + tools: params.tools || [] + }, + metadata: {} + } + } + })), + convertMcpToolsToSdkTools: vi.fn(() => []), + convertSdkToolCallToMcpToolResponse: vi.fn(), + buildSdkMessages: vi.fn(() => []), + extractMessagesFromSdkPayload: vi.fn(() => []), + provider: {} as Provider, + useSystemPromptForTools: true, + getBaseURL: vi.fn(() => 'https://api.anthropic.com'), + getApiKey: vi.fn(() => 'mock-api-key') +} as unknown as AnthropicAPIClient + +const mockAnthropicApiClientStream = cloneDeep(mockAnthropicApiClient) +mockAnthropicApiClientStream.createCompletions = vi.fn().mockImplementation(() => anthropicTextStreamChunkGenerator()) + const mockGeminiThinkingApiClient = cloneDeep(mockGeminiApiClient) mockGeminiThinkingApiClient.createCompletions = vi.fn().mockImplementation(() => geminiThinkingChunkGenerator()) @@ -1203,7 +1641,7 @@ describe('ApiService', () => { collectedChunks.length = 0 }) - it('should return a stream of chunks with correct types and content', async () => { + it('should return a stream of chunks with correct types and content in gemini', async () => { const mockCreate = vi.mocked(ApiClientFactory.create) mockCreate.mockReturnValue(mockGeminiApiClient as unknown as BaseApiClient) const AI = new AiProvider(mockProvider) @@ -1321,6 +1759,155 @@ describe('ApiService', () => { expect(completionChunk.response?.usage?.completion_tokens).toBe(822) }) + it('should return a non-stream of chunks with correct types and content in anthropic', async () => { + const mockCreate = vi.mocked(ApiClientFactory.create) + mockCreate.mockReturnValue(mockAnthropicApiClient as unknown as BaseApiClient) + const AI = new AiProvider(mockProvider) + + const result = await AI.completions({ + callType: 'test', + messages: [], + assistant: { + id: '1', + name: 'test', + prompt: 'test', + + type: 'anthropic', + model: { + id: 'claude-3-7-sonnet-20250219', + name: 'Claude 3.7 Sonnet' + } + } as Assistant, + onChunk: mockOnChunk, + mcpTools: [], + maxTokens: 1000, + streamOutput: false + }) + + expect(result).toBeDefined() + expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider) + expect(result.stream).toBeDefined() + + const stream = result.stream! as ReadableStream + const reader = stream.getReader() + + const chunks: GenericChunk[] = [] + + while (true) { + const { done, value } = await reader.read() + if (done) break + chunks.push(value) + } + + reader.releaseLock() + + const expectedChunks: GenericChunk[] = [ + { + type: ChunkType.TEXT_START + }, + { + type: ChunkType.TEXT_DELTA, + text: '你好!有什么我可以帮助你的吗?' + }, + { + type: ChunkType.TEXT_COMPLETE, + text: '你好!有什么我可以帮助你的吗?' + }, + { + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + completion_tokens: 21, + prompt_tokens: 15, + total_tokens: 36 + } + } + } + ] + + expect(chunks).toEqual(expectedChunks) + + // 验证chunk的数量和类型 + expect(chunks.length).toBeGreaterThan(0) + + // 验证第一个chunk应该是TEXT_START + const firstChunk = chunks[0] + expect(firstChunk.type).toBe(ChunkType.TEXT_START) + }) + + it('should return a stream of chunks with correct types and content in anthropic', async () => { + const mockCreate = vi.mocked(ApiClientFactory.create) + mockCreate.mockReturnValue(mockAnthropicApiClientStream as unknown as BaseApiClient) + const AI = new AiProvider(mockProvider) + + const result = await AI.completions({ + callType: 'test', + messages: [], + assistant: { + id: '1', + name: 'test', + prompt: 'test', + + type: 'anthropic', + model: { + id: 'claude-3-7-sonnet-20250219', + name: 'Claude 3.7 Sonnet' + } + } as Assistant, + onChunk: mockOnChunk, + mcpTools: [], + maxTokens: 1000, + streamOutput: true + }) + + expect(result).toBeDefined() + expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider) + expect(result.stream).toBeDefined() + + const stream = result.stream! as ReadableStream + const reader = stream.getReader() + + const chunks: GenericChunk[] = [] + + while (true) { + const { done, value } = await reader.read() + if (done) break + chunks.push(value) + } + + reader.releaseLock() + + const expectedChunks: GenericChunk[] = [ + { + type: ChunkType.TEXT_START + }, + { + type: ChunkType.TEXT_DELTA, + text: '你好!很高兴见到你。有' + }, + { + type: ChunkType.TEXT_DELTA, + text: '你好!很高兴见到你。有什么我可以帮助你的吗?' + }, + { + type: ChunkType.TEXT_COMPLETE, + text: '你好!很高兴见到你。有什么我可以帮助你的吗?' + }, + { + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + completion_tokens: 28, + prompt_tokens: 0, + total_tokens: 28 + } + } + } + ] + + expect(chunks).toEqual(expectedChunks) + }) + it('should return a stream of thinking chunks with correct types and content', async () => { const mockCreate = vi.mocked(ApiClientFactory.create) mockCreate.mockReturnValue(mockGeminiThinkingApiClient as unknown as BaseApiClient) @@ -1579,6 +2166,91 @@ describe('ApiService', () => { expect(filteredChunks).toEqual(expectedChunks) }) + it('should handle openai need extract content chunk correctly', async () => { + const mockCreate = vi.mocked(ApiClientFactory.create) + // @ts-ignore mockOpenaiNeedExtractContentApiClient is a OpenAIAPIClient + mockCreate.mockReturnValue(mockOpenaiNeedExtractContentApiClient as unknown as OpenAIAPIClient) + const AI = new AiProvider(mockProvider as Provider) + const result = await AI.completions({ + callType: 'test', + messages: [], + assistant: { + id: '1', + name: 'test', + prompt: 'test', + model: { + id: 'gpt-4o', + name: 'GPT-4o' + } + } as Assistant, + onChunk: mockOnChunk, + enableReasoning: true, + streamOutput: true + }) + + expect(result).toBeDefined() + expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider) + expect(result.stream).toBeDefined() + + const stream = result.stream! as ReadableStream + const reader = stream.getReader() + + const chunks: GenericChunk[] = [] + + while (true) { + const { done, value } = await reader.read() + if (done) break + chunks.push(value) + } + + reader.releaseLock() + + const filteredChunks = chunks.map((chunk) => { + if (chunk.type === ChunkType.THINKING_DELTA || chunk.type === ChunkType.THINKING_COMPLETE) { + delete (chunk as any).thinking_millsec + return chunk + } + return chunk + }) + + const expectedChunks = [ + { + type: ChunkType.THINKING_START + }, + { + type: ChunkType.THINKING_DELTA, + text: '好的,用户发来“你好”,我需要友好回应' + }, + { + type: ChunkType.THINKING_COMPLETE, + text: '好的,用户发来“你好”,我需要友好回应' + }, + { + type: ChunkType.TEXT_START + }, + { + type: ChunkType.TEXT_DELTA, + text: '\n你好!有什么我可以帮您的吗?' + }, + { + type: ChunkType.TEXT_COMPLETE, + text: '\n你好!有什么我可以帮您的吗?' + }, + { + type: ChunkType.LLM_RESPONSE_COMPLETE, + response: { + usage: { + completion_tokens: 0, + prompt_tokens: 0, + total_tokens: 0 + } + } + } + ] + + expect(filteredChunks).toEqual(expectedChunks) + }) + it('should extract tool use responses correctly', async () => { const mockCreate = vi.mocked(ApiClientFactory.create) mockCreate.mockReturnValue(mockGeminiToolUseApiClient as unknown as BaseApiClient)