fix: thinking not display (#8222)

* feat(ThinkingTagExtraction): accumulate thinking content for improved processing

- Introduced an `accumulatedThinkingContent` variable to gather content from multiple chunks before enqueuing.
- Updated the `ThinkingDeltaChunk` to carry the accumulated content instead of each individual extraction result, so thinking messages render coherently (see the sketch below).
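
A minimal sketch of the accumulation pattern this commit introduces, with simplified stand-in types (the `Chunk` union below is hypothetical; the real middleware works on `GenericChunk`/`ThinkingDeltaChunk` from `@renderer/types/chunk`):

// Each THINKING_DELTA is re-emitted with the full accumulated text so far,
// so a renderer that replaces (rather than appends to) its thinking view
// always shows a coherent message.
type Chunk = { type: 'THINKING_DELTA'; text: string } | { type: 'TEXT_DELTA'; text: string }

function accumulateThinking(): TransformStream<Chunk, Chunk> {
  let accumulated = ''
  return new TransformStream<Chunk, Chunk>({
    transform(chunk, controller) {
      if (chunk.type === 'THINKING_DELTA' && chunk.text.trim()) {
        accumulated += chunk.text
        controller.enqueue({ type: 'THINKING_DELTA', text: accumulated })
      } else {
        controller.enqueue(chunk)
      }
    }
  })
}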

* feat(OpenAIAPIClient): enhance chunk processing for reasoning and content extraction

- Updated the OpenAIAPIClient to handle additional fields in response chunks, including `reasoning_content` and `reasoning`, improving extraction of reasoning text (see the sketch after this list).
- Introduced a new mock implementation for testing OpenAI completions, ensuring accurate handling of thinking and text chunks in the response.
- Enhanced unit tests to validate the processing of OpenAI thinking chunks, ensuring expected behavior and output.
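
The delta-acceptance check the client now applies can be sketched as a small predicate (the `Delta` shape below is a simplified stand-in for the SDK type; `reasoning_content` and `reasoning` are non-standard fields emitted by some OpenAI-compatible providers):

// A delta is usable if it omits content entirely, or carries non-empty
// content, reasoning_content, or reasoning text.
interface Delta {
  content?: string | null
  reasoning_content?: string | null
  reasoning?: string | null
}

function isUsableDelta(delta: Delta | undefined): boolean {
  if (!delta || Object.keys(delta).length === 0) return false
  return (
    !('content' in delta) ||
    (typeof delta.content === 'string' && delta.content !== '') ||
    (typeof delta.reasoning_content === 'string' && delta.reasoning_content !== '') ||
    (typeof delta.reasoning === 'string' && delta.reasoning !== '')
  )
}
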
SuYao 2025-07-17 11:40:15 +08:00 committed by GitHub
parent e7d38d340f
commit 720c5d6080
3 changed files with 481 additions and 4 deletions


@@ -710,7 +710,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
choice.delta &&
Object.keys(choice.delta).length > 0 &&
(!('content' in choice.delta) ||
(typeof choice.delta.content === 'string' && choice.delta.content !== ''))
(typeof choice.delta.content === 'string' && choice.delta.content !== '') ||
(typeof (choice.delta as any).reasoning_content === 'string' &&
(choice.delta as any).reasoning_content !== '') ||
(typeof (choice.delta as any).reasoning === 'string' && (choice.delta as any).reasoning !== ''))
) {
contentSource = choice.delta
} else if ('message' in choice) {


@@ -66,7 +66,7 @@ export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
let thinkingStartTime = 0
let isFirstTextChunk = true
let accumulatedThinkingContent = ''
const processedStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
@@ -101,9 +101,10 @@
}
if (extractionResult.content?.trim()) {
accumulatedThinkingContent += extractionResult.content
const thinkingDeltaChunk: ThinkingDeltaChunk = {
type: ChunkType.THINKING_DELTA,
text: extractionResult.content,
text: accumulatedThinkingContent,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingDeltaChunk)


@@ -1,6 +1,7 @@
import { FinishReason, MediaModality } from '@google/genai'
import { FunctionCall } from '@google/genai'
import AiProvider from '@renderer/aiCore'
import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/clients'
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
@@ -14,8 +15,10 @@ import {
TextStartChunk,
ThinkingStartChunk
} from '@renderer/types/chunk'
import { GeminiSdkRawChunk } from '@renderer/types/sdk'
import { GeminiSdkRawChunk, OpenAISdkRawChunk, OpenAISdkRawContentSource } from '@renderer/types/sdk'
import { cloneDeep } from 'lodash'
import OpenAI from 'openai'
import { ChatCompletionChunk } from 'openai/resources'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock the ApiClientFactory
@@ -615,6 +618,117 @@ const geminiToolUseChunks: GeminiSdkRawChunk[] = [
} as GeminiSdkRawChunk
]
const openaiCompletionChunks: OpenAISdkRawChunk[] = [
{
id: 'cmpl-123',
created: 1715811200,
model: 'gpt-4o',
object: 'chat.completion.chunk',
choices: [
{
delta: {
content: null,
role: 'assistant',
reasoning_content: ''
} as ChatCompletionChunk.Choice.Delta,
index: 0,
logprobs: null,
finish_reason: null
} as ChatCompletionChunk.Choice
]
},
{
id: 'cmpl-123',
created: 1715811200,
model: 'gpt-4o',
object: 'chat.completion.chunk',
choices: [
{
delta: {
content: null,
role: 'assistant',
reasoning_content: '好的,用户打招呼说“你好'
} as ChatCompletionChunk.Choice.Delta,
index: 0,
logprobs: null,
finish_reason: null
}
]
},
{
id: 'cmpl-123',
created: 1715811200,
model: 'gpt-4o',
object: 'chat.completion.chunk',
choices: [
{
delta: {
content: null,
role: 'assistant',
reasoning_content: '”,我需要友好回应。'
} as ChatCompletionChunk.Choice.Delta,
index: 0,
logprobs: null,
finish_reason: null
}
]
},
{
id: 'cmpl-123',
created: 1715811200,
model: 'gpt-4o',
object: 'chat.completion.chunk',
choices: [
{
delta: {
content: '你好!有什么问题',
role: 'assistant',
reasoning_content: null
} as ChatCompletionChunk.Choice.Delta,
index: 0,
logprobs: null,
finish_reason: null
}
]
},
{
id: 'cmpl-123',
created: 1715811200,
model: 'gpt-4o',
object: 'chat.completion.chunk',
choices: [
{
delta: {
content: '或者需要我帮忙的吗?',
role: 'assistant',
reasoning_content: null
} as ChatCompletionChunk.Choice.Delta,
index: 0,
logprobs: null,
finish_reason: null
}
]
},
{
id: 'cmpl-123',
created: 1715811200,
model: 'gpt-4o',
object: 'chat.completion.chunk',
choices: [
{
delta: {
content: null,
role: 'assistant',
reasoning_content: null
} as ChatCompletionChunk.Choice.Delta,
index: 0,
logprobs: null,
finish_reason: 'stop'
}
]
}
]
// A proper async generator function
async function* geminiChunkGenerator(): AsyncGenerator<GeminiSdkRawChunk> {
for (const chunk of geminiChunks) {
@@ -634,6 +748,276 @@ async function* geminiToolUseChunkGenerator(): AsyncGenerator<GeminiSdkRawChunk>
}
}
async function* openaiThinkingChunkGenerator(): AsyncGenerator<OpenAISdkRawChunk> {
for (const chunk of openaiCompletionChunks) {
yield chunk
}
}
const mockOpenaiApiClient = {
createCompletions: vi.fn().mockImplementation(() => openaiThinkingChunkGenerator()),
getResponseChunkTransformer: vi.fn().mockImplementation(() => {
let hasBeenCollectedWebSearch = false
const collectWebSearchData = (
chunk: OpenAISdkRawChunk,
contentSource: OpenAISdkRawContentSource,
context: ResponseChunkTransformerContext
) => {
if (hasBeenCollectedWebSearch) {
return
}
// OpenAI annotations
// @ts-ignore - annotations may not be in standard type definitions
const annotations = contentSource.annotations || chunk.annotations
if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') {
hasBeenCollectedWebSearch = true
return {
results: annotations,
source: WebSearchSource.OPENAI
}
}
// Grok citations
// @ts-ignore - citations may not be in standard type definitions
if (context.provider?.id === 'grok' && chunk.citations) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - citations may not be in standard type definitions
results: chunk.citations,
source: WebSearchSource.GROK
}
}
// Perplexity citations
// @ts-ignore - citations may not be in standard type definitions
if (context.provider?.id === 'perplexity' && chunk.search_results && chunk.search_results.length > 0) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - citations may not be in standard type definitions
results: chunk.search_results,
source: WebSearchSource.PERPLEXITY
}
}
// OpenRouter citations
// @ts-ignore - citations may not be in standard type definitions
if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - citations may not be in standard type definitions
results: chunk.citations,
source: WebSearchSource.OPENROUTER
}
}
// Zhipu web search
// @ts-ignore - web_search may not be in standard type definitions
if (context.provider?.id === 'zhipu' && chunk.web_search) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - web_search may not be in standard type definitions
results: chunk.web_search,
source: WebSearchSource.ZHIPU
}
}
// Hunyuan web search
// @ts-ignore - search_info may not be in standard type definitions
if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - search_info may not be in standard type definitions
results: chunk.search_info.search_results,
source: WebSearchSource.HUNYUAN
}
}
return null
}
const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
let isFinished = false
let lastUsageInfo: any = null
/**
 * Emit the completion signals exactly once:
 * - when a choice carries a finish_reason
 * - on flush, if the stream ended without ever seeing a finish_reason
 */
const emitCompletionSignals = (controller: TransformStreamDefaultController<GenericChunk>) => {
if (isFinished) return
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
const usage = lastUsageInfo || {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: { usage }
})
// Prevent duplicate emission
isFinished = true
}
let isFirstThinkingChunk = true
let isFirstTextChunk = true
return (context: ResponseChunkTransformerContext) => ({
async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
// Keep the usage info up to date as chunks arrive
if (chunk.usage) {
lastUsageInfo = {
prompt_tokens: chunk.usage.prompt_tokens || 0,
completion_tokens: chunk.usage.completion_tokens || 0,
total_tokens: (chunk.usage.prompt_tokens || 0) + (chunk.usage.completion_tokens || 0)
}
}
// Process the chunk
if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
for (const choice of chunk.choices) {
if (!choice) continue
// For streaming responses use delta; for non-streaming responses use message.
// However, some OpenAI-compatible platforms incorrectly return an empty-object delta field for non-streaming requests.
// If delta is an empty object or its content is empty, ignore it and fall back to message to avoid losing content.
let contentSource: OpenAISdkRawContentSource | null = null
if (
'delta' in choice &&
choice.delta &&
Object.keys(choice.delta).length > 0 &&
(!('content' in choice.delta) ||
(typeof choice.delta.content === 'string' && choice.delta.content !== '') ||
(typeof (choice.delta as any).reasoning_content === 'string' &&
(choice.delta as any).reasoning_content !== '') ||
(typeof (choice.delta as any).reasoning === 'string' && (choice.delta as any).reasoning !== ''))
) {
contentSource = choice.delta
} else if ('message' in choice) {
contentSource = choice.message
}
if (!contentSource) {
if ('finish_reason' in choice && choice.finish_reason) {
emitCompletionSignals(controller)
}
continue
}
const webSearchData = collectWebSearchData(chunk, contentSource, context)
if (webSearchData) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: webSearchData
})
}
// Handle reasoning content (e.g. from OpenRouter DeepSeek-R1)
// @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it
const reasoningText = contentSource.reasoning_content || contentSource.reasoning
if (reasoningText) {
if (isFirstThinkingChunk) {
controller.enqueue({
type: ChunkType.THINKING_START
} as ThinkingStartChunk)
isFirstThinkingChunk = false
}
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: reasoningText
})
}
// Handle text content
if (contentSource.content) {
if (isFirstTextChunk) {
controller.enqueue({
type: ChunkType.TEXT_START
} as TextStartChunk)
isFirstTextChunk = false
}
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: contentSource.content
})
}
// Handle tool calls
if (contentSource.tool_calls) {
for (const toolCall of contentSource.tool_calls) {
if ('index' in toolCall) {
const { id, index, function: fun } = toolCall
if (fun?.name) {
toolCalls[index] = {
id: id || '',
function: {
name: fun.name,
arguments: fun.arguments || ''
},
type: 'function'
}
} else if (fun?.arguments) {
toolCalls[index].function.arguments += fun.arguments
}
} else {
toolCalls.push(toolCall)
}
}
}
// Handle finish_reason and emit the end-of-stream signals
if ('finish_reason' in choice && choice.finish_reason) {
const webSearchData = collectWebSearchData(chunk, contentSource, context)
if (webSearchData) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: webSearchData
})
}
emitCompletionSignals(controller)
}
}
}
},
// When the stream ends normally, check whether the completion signals still need to be sent
flush(controller) {
if (isFinished) return
emitCompletionSignals(controller)
}
})
}),
getSdkInstance: vi.fn(),
getRequestTransformer: vi.fn().mockImplementation(() => ({
async transform(params: any) {
return {
payload: {
model: params.assistant?.model?.id || 'gpt-4o',
messages: params.messages || [],
tools: params.tools || []
},
metadata: {}
}
}
})),
convertMcpToolsToSdkTools: vi.fn(() => []),
convertSdkToolCallToMcpToolResponse: vi.fn(),
buildSdkMessages: vi.fn(() => []),
extractMessagesFromSdkPayload: vi.fn(() => []),
provider: {} as Provider,
useSystemPromptForTools: true,
getBaseURL: vi.fn(() => 'https://api.openai.com'),
getApiKey: vi.fn(() => 'mock-api-key')
} as unknown as OpenAIAPIClient
// Create the mock GeminiAPIClient
const mockGeminiApiClient = {
createCompletions: vi.fn().mockImplementation(() => geminiChunkGenerator()),
@@ -1064,6 +1448,95 @@ describe('ApiService', () => {
expect(filteredChunks).toEqual(expectedChunks)
})
it('should handle openai thinking chunk correctly', async () => {
const mockCreate = vi.mocked(ApiClientFactory.create)
mockCreate.mockReturnValue(mockOpenaiApiClient as unknown as BaseApiClient)
const AI = new AiProvider(mockProvider as Provider)
const result = await AI.completions({
callType: 'test',
messages: [],
assistant: {
id: '1',
name: 'test',
prompt: 'test',
model: {
id: 'gpt-4o',
name: 'GPT-4o'
}
} as Assistant,
onChunk: mockOnChunk,
enableReasoning: true,
streamOutput: true
})
expect(result).toBeDefined()
expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider)
expect(result.stream).toBeDefined()
const stream = result.stream! as ReadableStream<GenericChunk>
const reader = stream.getReader()
const chunks: GenericChunk[] = []
while (true) {
const { done, value } = await reader.read()
if (done) break
chunks.push(value)
}
reader.releaseLock()
const filteredChunks = chunks.map((chunk) => {
if (chunk.type === ChunkType.THINKING_DELTA || chunk.type === ChunkType.THINKING_COMPLETE) {
delete (chunk as any).thinking_millsec
return chunk
}
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
delete (chunk as any).response.usage
return chunk
}
return chunk
})
const expectedChunks = [
{
type: ChunkType.THINKING_START
},
{
type: ChunkType.THINKING_DELTA,
text: '好的,用户打招呼说“你好'
},
{
type: ChunkType.THINKING_DELTA,
text: '好的,用户打招呼说“你好”,我需要友好回应。'
},
{
type: ChunkType.THINKING_COMPLETE,
text: '好的,用户打招呼说“你好”,我需要友好回应。'
},
{
type: ChunkType.TEXT_START
},
{
type: ChunkType.TEXT_DELTA,
text: '你好!有什么问题'
},
{
type: ChunkType.TEXT_DELTA,
text: '你好!有什么问题或者需要我帮忙的吗?'
},
{
type: ChunkType.TEXT_COMPLETE,
text: '你好!有什么问题或者需要我帮忙的吗?'
},
{
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {}
}
]
expect(filteredChunks).toEqual(expectedChunks)
})
// it('should extract tool use responses correctly', async () => {
// const mockCreate = vi.mocked(ApiClientFactory.create)
// mockCreate.mockReturnValue(mockGeminiToolUseApiClient as unknown as BaseApiClient)