diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts
index 2034b3951e..80a611493d 100644
--- a/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts
+++ b/src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts
@@ -710,7 +710,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
           choice.delta &&
           Object.keys(choice.delta).length > 0 &&
           (!('content' in choice.delta) ||
-            (typeof choice.delta.content === 'string' && choice.delta.content !== ''))
+            (typeof choice.delta.content === 'string' && choice.delta.content !== '') ||
+            (typeof (choice.delta as any).reasoning_content === 'string' &&
+              (choice.delta as any).reasoning_content !== '') ||
+            (typeof (choice.delta as any).reasoning === 'string' && (choice.delta as any).reasoning !== ''))
         ) {
           contentSource = choice.delta
         } else if ('message' in choice) {
diff --git a/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts b/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts
index cf9bd918e1..425f29c705 100644
--- a/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts
+++ b/src/renderer/src/aiCore/middleware/feat/ThinkingTagExtractionMiddleware.ts
@@ -66,7 +66,7 @@ export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
       let thinkingStartTime = 0
       let isFirstTextChunk = true
-
+      let accumulatedThinkingContent = ''
       const processedStream = resultFromUpstream.pipeThrough(
         new TransformStream({
           transform(chunk: GenericChunk, controller) {
@@ -101,9 +101,10 @@
             }
 
             if (extractionResult.content?.trim()) {
+              accumulatedThinkingContent += extractionResult.content
               const thinkingDeltaChunk: ThinkingDeltaChunk = {
                 type: ChunkType.THINKING_DELTA,
-                text: extractionResult.content,
+                text: accumulatedThinkingContent,
                 thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
               }
               controller.enqueue(thinkingDeltaChunk)
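
Illustration (not from the patch): the middleware hunk above changes the THINKING_DELTA contract so that `text` carries the full thinking content accumulated so far, not just the latest increment; the new test's expectedChunks below assert exactly that. A minimal TypeScript sketch of the contract, using a simplified stand-in for ThinkingDeltaChunk:

// Simplified stand-in for ThinkingDeltaChunk; the real chunk also carries
// a thinking_millsec timing field.
type Delta = { type: 'THINKING_DELTA'; text: string }

// Accumulates thinking increments and re-emits the running total, mirroring
// what accumulatedThinkingContent does inside the middleware's TransformStream.
function accumulateThinking(): TransformStream<string, Delta> {
  let accumulated = ''
  return new TransformStream({
    transform(increment, controller) {
      accumulated += increment
      controller.enqueue({ type: 'THINKING_DELTA', text: accumulated })
    }
  })
}

// Feeding 'Hello' then ' world' yields:
//   { type: 'THINKING_DELTA', text: 'Hello' }
//   { type: 'THINKING_DELTA', text: 'Hello world' }
// so downstream consumers can render the latest chunk without concatenating.
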
diff --git a/src/renderer/src/services/__tests__/ApiService.test.ts b/src/renderer/src/services/__tests__/ApiService.test.ts
index 239e73b2a6..6fd37370d0 100644
--- a/src/renderer/src/services/__tests__/ApiService.test.ts
+++ b/src/renderer/src/services/__tests__/ApiService.test.ts
@@ -1,6 +1,7 @@
 import { FinishReason, MediaModality } from '@google/genai'
 import { FunctionCall } from '@google/genai'
 import AiProvider from '@renderer/aiCore'
+import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/clients'
 import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
 import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
 import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
@@ -14,8 +15,10 @@ import {
   TextStartChunk,
   ThinkingStartChunk
 } from '@renderer/types/chunk'
-import { GeminiSdkRawChunk } from '@renderer/types/sdk'
+import { GeminiSdkRawChunk, OpenAISdkRawChunk, OpenAISdkRawContentSource } from '@renderer/types/sdk'
 import { cloneDeep } from 'lodash'
+import OpenAI from 'openai'
+import { ChatCompletionChunk } from 'openai/resources'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
 
 // Mock the ApiClientFactory
@@ -615,6 +618,117 @@ const geminiToolUseChunks: GeminiSdkRawChunk[] = [
   } as GeminiSdkRawChunk
 ]
 
+const openaiCompletionChunks: OpenAISdkRawChunk[] = [
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: null,
+          role: 'assistant',
+          reasoning_content: ''
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      } as ChatCompletionChunk.Choice
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: null,
+          role: 'assistant',
+          reasoning_content: '好的,用户打招呼说“你好'
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      }
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: null,
+          role: 'assistant',
+          reasoning_content: '”,我需要友好回应。'
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      }
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: '你好!有什么问题',
+          role: 'assistant',
+          reasoning_content: null
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      }
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: '或者需要我帮忙的吗?',
+          role: 'assistant',
+          reasoning_content: null
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      }
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: null,
+          role: 'assistant',
+          reasoning_content: null
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: 'stop'
+      }
+    ]
+  }
+]
+
 // Proper async generator function
 async function* geminiChunkGenerator(): AsyncGenerator {
   for (const chunk of geminiChunks) {
     yield chunk
   }
 }
@@ -634,6 +748,276 @@ async function* geminiToolUseChunkGenerator(): AsyncGenerator
   }
 }
 
+async function* openaiThinkingChunkGenerator(): AsyncGenerator {
+  for (const chunk of openaiCompletionChunks) {
+    yield chunk
+  }
+}
+
+const mockOpenaiApiClient = {
+  createCompletions: vi.fn().mockImplementation(() => openaiThinkingChunkGenerator()),
+  getResponseChunkTransformer: vi.fn().mockImplementation(() => {
+    let hasBeenCollectedWebSearch = false
+    const collectWebSearchData = (
+      chunk: OpenAISdkRawChunk,
+      contentSource: OpenAISdkRawContentSource,
+      context: ResponseChunkTransformerContext
+    ) => {
+      if (hasBeenCollectedWebSearch) {
+        return
+      }
+      // OpenAI annotations
+      // @ts-ignore - annotations may not be in standard type definitions
+      const annotations = contentSource.annotations || chunk.annotations
+      if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') {
+        hasBeenCollectedWebSearch = true
+        return {
+          results: annotations,
+          source: WebSearchSource.OPENAI
+        }
+      }
+
+      // Grok citations
+      // @ts-ignore - citations may not be in standard type definitions
+      if (context.provider?.id === 'grok' && chunk.citations) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - citations may not be in standard type definitions
+          results: chunk.citations,
+          source: WebSearchSource.GROK
+        }
+      }
+
+      // Perplexity citations
+      // @ts-ignore - citations may not be in standard type definitions
+      if (context.provider?.id === 'perplexity' && chunk.search_results && chunk.search_results.length > 0) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - citations may not be in standard type definitions
+          results: chunk.search_results,
+          source: WebSearchSource.PERPLEXITY
+        }
+      }
+
+      // OpenRouter citations
+      // @ts-ignore - citations may not be in standard type definitions
+      if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - citations may not be in standard type definitions
+          results: chunk.citations,
+          source: WebSearchSource.OPENROUTER
+        }
+      }
+
+      // Zhipu web search
+      // @ts-ignore - web_search may not be in standard type definitions
+      if (context.provider?.id === 'zhipu' && chunk.web_search) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - web_search may not be in standard type definitions
+          results: chunk.web_search,
+          source: WebSearchSource.ZHIPU
+        }
+      }
+
+      // Hunyuan web search
+      // @ts-ignore - search_info may not be in standard type definitions
+      if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - search_info may not be in standard type definitions
+          results: chunk.search_info.search_results,
+          source: WebSearchSource.HUNYUAN
+        }
+      }
+      return null
+    }
+
+    const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
+    let isFinished = false
+    let lastUsageInfo: any = null
+
+    /**
+     * Unified completion-signal emission logic:
+     * - when a finish_reason is present
+     * - when there is no finish_reason but the stream ends normally
+     */
+    const emitCompletionSignals = (controller: TransformStreamDefaultController) => {
+      if (isFinished) return
+
+      if (toolCalls.length > 0) {
+        controller.enqueue({
+          type: ChunkType.MCP_TOOL_CREATED,
+          tool_calls: toolCalls
+        })
+      }
+
+      const usage = lastUsageInfo || {
+        prompt_tokens: 0,
+        completion_tokens: 0,
+        total_tokens: 0
+      }
+
+      controller.enqueue({
+        type: ChunkType.LLM_RESPONSE_COMPLETE,
+        response: { usage }
+      })
+
+      // Prevent duplicate emission
+      isFinished = true
+    }
+
+    let isFirstThinkingChunk = true
+    let isFirstTextChunk = true
+    return (context: ResponseChunkTransformerContext) => ({
+      async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController) {
+        // Keep the usage info up to date
+        if (chunk.usage) {
+          lastUsageInfo = {
+            prompt_tokens: chunk.usage.prompt_tokens || 0,
+            completion_tokens: chunk.usage.completion_tokens || 0,
+            total_tokens: (chunk.usage.prompt_tokens || 0) + (chunk.usage.completion_tokens || 0)
+          }
+        }
+
+        // Process the chunk
+        if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
+          for (const choice of chunk.choices) {
+            if (!choice) continue
+
+            // For streaming responses use delta; for non-streaming responses use message.
+            // However, some OpenAI-compatible platforms incorrectly return an empty-object
+            // delta field on non-streaming requests. If delta is an empty object or its
+            // content is empty, ignore it and fall back to message to avoid losing content.
+            let contentSource: OpenAISdkRawContentSource | null = null
+            if (
+              'delta' in choice &&
+              choice.delta &&
+              Object.keys(choice.delta).length > 0 &&
+              (!('content' in choice.delta) ||
+                (typeof choice.delta.content === 'string' && choice.delta.content !== '') ||
+                (typeof (choice.delta as any).reasoning_content === 'string' &&
+                  (choice.delta as any).reasoning_content !== '') ||
+                (typeof (choice.delta as any).reasoning === 'string' && (choice.delta as any).reasoning !== ''))
+            ) {
+              contentSource = choice.delta
+            } else if ('message' in choice) {
+              contentSource = choice.message
+            }
+
+            if (!contentSource) {
+              if ('finish_reason' in choice && choice.finish_reason) {
+                emitCompletionSignals(controller)
+              }
+              continue
+            }
+
+            const webSearchData = collectWebSearchData(chunk, contentSource, context)
+            if (webSearchData) {
+              controller.enqueue({
+                type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
+                llm_web_search: webSearchData
+              })
+            }
+
+            // Handle reasoning content (e.g. from OpenRouter DeepSeek-R1)
+            // @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it
+            const reasoningText = contentSource.reasoning_content || contentSource.reasoning
+            if (reasoningText) {
+              if (isFirstThinkingChunk) {
+                controller.enqueue({
+                  type: ChunkType.THINKING_START
+                } as ThinkingStartChunk)
+                isFirstThinkingChunk = false
+              }
+              controller.enqueue({
+                type: ChunkType.THINKING_DELTA,
+                text: reasoningText
+              })
+            }
+
+            // Handle text content
+            if (contentSource.content) {
+              if (isFirstTextChunk) {
+                controller.enqueue({
+                  type: ChunkType.TEXT_START
+                } as TextStartChunk)
+                isFirstTextChunk = false
+              }
+              controller.enqueue({
+                type: ChunkType.TEXT_DELTA,
+                text: contentSource.content
+              })
+            }
+
+            // Handle tool calls
+            if (contentSource.tool_calls) {
+              for (const toolCall of contentSource.tool_calls) {
+                if ('index' in toolCall) {
+                  const { id, index, function: fun } = toolCall
+                  if (fun?.name) {
+                    toolCalls[index] = {
+                      id: id || '',
+                      function: {
+                        name: fun.name,
+                        arguments: fun.arguments || ''
+                      },
+                      type: 'function'
+                    }
+                  } else if (fun?.arguments) {
+                    toolCalls[index].function.arguments += fun.arguments
+                  }
+                } else {
+                  toolCalls.push(toolCall)
+                }
+              }
+            }
+
+            // Handle finish_reason and emit end-of-stream signals
+            if ('finish_reason' in choice && choice.finish_reason) {
+              const webSearchData = collectWebSearchData(chunk, contentSource, context)
+              if (webSearchData) {
+                controller.enqueue({
+                  type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
+                  llm_web_search: webSearchData
+                })
+              }
+              emitCompletionSignals(controller)
+            }
+          }
+        }
+      },
+
+      // When the stream ends normally, check whether completion signals still need to be emitted
+      flush(controller) {
+        if (isFinished) return
+        emitCompletionSignals(controller)
+      }
+    })
+  }),
+  getSdkInstance: vi.fn(),
+  getRequestTransformer: vi.fn().mockImplementation(() => ({
+    async transform(params: any) {
+      return {
+        payload: {
+          model: params.assistant?.model?.id || 'gpt-4o',
+          messages: params.messages || [],
+          tools: params.tools || []
+        },
+        metadata: {}
+      }
+    }
+  })),
+  convertMcpToolsToSdkTools: vi.fn(() => []),
+  convertSdkToolCallToMcpToolResponse: vi.fn(),
+  buildSdkMessages: vi.fn(() => []),
+  extractMessagesFromSdkPayload: vi.fn(() => []),
+  provider: {} as Provider,
+  useSystemPromptForTools: true,
+  getBaseURL: vi.fn(() => 'https://api.openai.com'),
+  getApiKey: vi.fn(() => 'mock-api-key')
+} as unknown as OpenAIAPIClient
+
 // Create the mock GeminiAPIClient
 const mockGeminiApiClient = {
   createCompletions: vi.fn().mockImplementation(() => geminiChunkGenerator()),
@@ -1064,6 +1448,95 @@ describe('ApiService', () => {
     expect(filteredChunks).toEqual(expectedChunks)
   })
 
+  it('should handle openai thinking chunk correctly', async () => {
+    const mockCreate = vi.mocked(ApiClientFactory.create)
+    mockCreate.mockReturnValue(mockOpenaiApiClient as unknown as BaseApiClient)
+    const AI = new AiProvider(mockProvider as Provider)
+    const result = await AI.completions({
+      callType: 'test',
+      messages: [],
+      assistant: {
+        id: '1',
+        name: 'test',
+        prompt: 'test',
+        model: {
+          id: 'gpt-4o',
+          name: 'GPT-4o'
+        }
+      } as Assistant,
+      onChunk: mockOnChunk,
+      enableReasoning: true,
+      streamOutput: true
+    })
+
+    expect(result).toBeDefined()
+    expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider)
+    expect(result.stream).toBeDefined()
+
+    const stream = result.stream! as ReadableStream
+    const reader = stream.getReader()
+
+    const chunks: GenericChunk[] = []
+
+    while (true) {
+      const { done, value } = await reader.read()
+      if (done) break
+      chunks.push(value)
+    }
+
+    reader.releaseLock()
+
+    const filteredChunks = chunks.map((chunk) => {
+      if (chunk.type === ChunkType.THINKING_DELTA || chunk.type === ChunkType.THINKING_COMPLETE) {
+        delete (chunk as any).thinking_millsec
+        return chunk
+      }
+      if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
+        delete (chunk as any).response.usage
+        return chunk
+      }
+      return chunk
+    })
+    const expectedChunks = [
+      {
+        type: ChunkType.THINKING_START
+      },
+      {
+        type: ChunkType.THINKING_DELTA,
+        text: '好的,用户打招呼说“你好'
+      },
+      {
+        type: ChunkType.THINKING_DELTA,
+        text: '好的,用户打招呼说“你好”,我需要友好回应。'
+      },
+      {
+        type: ChunkType.THINKING_COMPLETE,
+        text: '好的,用户打招呼说“你好”,我需要友好回应。'
+      },
+      {
+        type: ChunkType.TEXT_START
+      },
+      {
+        type: ChunkType.TEXT_DELTA,
+        text: '你好!有什么问题'
+      },
+      {
+        type: ChunkType.TEXT_DELTA,
+        text: '你好!有什么问题或者需要我帮忙的吗?'
+      },
+      {
+        type: ChunkType.TEXT_COMPLETE,
+        text: '你好!有什么问题或者需要我帮忙的吗?'
+      },
+      {
+        type: ChunkType.LLM_RESPONSE_COMPLETE,
+        response: {}
+      }
+    ]
+
+    expect(filteredChunks).toEqual(expectedChunks)
+  })
+
   // it('should extract tool use responses correctly', async () => {
   //   const mockCreate = vi.mocked(ApiClientFactory.create)
   //   mockCreate.mockReturnValue(mockGeminiToolUseApiClient as unknown as BaseApiClient)
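
Illustration (not from the patch): the OpenAIApiClient hunk at the top widens the delta-acceptance rule so a chunk whose delta carries only reasoning text is still treated as a content source. Below is a standalone sketch of that predicate; the OpenAIDelta type and isUsableDelta name are assumptions made for the example, and reasoning_content / reasoning are provider-specific extensions of the standard delta type.

type OpenAIDelta = { content?: string | null; [key: string]: unknown }

// A delta is usable when it is a non-empty object and either has no content
// field at all, or carries non-empty content, reasoning_content, or reasoning.
function isUsableDelta(delta: OpenAIDelta | undefined): boolean {
  if (!delta || Object.keys(delta).length === 0) return false
  return (
    !('content' in delta) ||
    (typeof delta.content === 'string' && delta.content !== '') ||
    (typeof (delta as any).reasoning_content === 'string' && (delta as any).reasoning_content !== '') ||
    (typeof (delta as any).reasoning === 'string' && (delta as any).reasoning !== '')
  )
}

// Against the fixtures above: a reasoning-only delta is now accepted, while the
// initial all-empty chunk is still rejected and skipped.
isUsableDelta({ content: null, role: 'assistant', reasoning_content: '好的,用户打招呼说“你好' }) // true
isUsableDelta({ content: null, role: 'assistant', reasoning_content: '' }) // false
isUsableDelta({ content: '你好!有什么问题', role: 'assistant', reasoning_content: null }) // true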