Mirror of https://github.com/CherryHQ/cherry-studio.git
fix: thinking not display (#8222)
* feat(ThinkingTagExtraction): accumulate thinking content for improved processing
  - Introduced an `accumulatedThinkingContent` variable to gather content from multiple chunks before enqueuing.
  - Updated the `ThinkingDeltaChunk` to use the accumulated content instead of individual extraction results, enhancing the coherence of thinking messages.

* feat(OpenAIAPIClient): enhance chunk processing for reasoning and content extraction
  - Updated the OpenAIAPIClient to handle additional fields in response chunks, including `reasoning_content` and `reasoning`, improving the extraction of relevant information.
  - Introduced a new mock implementation for testing OpenAI completions, ensuring accurate handling of thinking and text chunks in the response.
  - Enhanced unit tests to validate the processing of OpenAI thinking chunks, ensuring expected behavior and output.
Parent: e7d38d340f
Commit: 720c5d6080
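The core of the display fix is easiest to see in isolation: instead of emitting each extracted thinking fragment on its own, the middleware keeps a running buffer and emits the accumulated text in every delta, so a fragment that splits mid-sentence is never rendered alone. A minimal sketch of the pattern, assuming fragments arrive as plain strings (the real middleware operates on GenericChunk streams; see the ThinkingTagExtractionMiddleware hunks below):

// Minimal sketch of the accumulation fix, assuming string fragments.
function* thinkingDeltas(fragments: string[]): Generator<string> {
  let accumulated = ''
  for (const fragment of fragments) {
    if (!fragment.trim()) continue // skip empty or whitespace-only fragments
    accumulated += fragment
    yield accumulated // emit the running total, not the bare fragment
  }
}

// Mirrors the fixture used in the unit test below:
// [...thinkingDeltas(['好的,用户打招呼说“你好', '”,我需要友好回应。'])]
// → ['好的,用户打招呼说“你好', '好的,用户打招呼说“你好”,我需要友好回应。']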
@@ -710,7 +710,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
             choice.delta &&
             Object.keys(choice.delta).length > 0 &&
             (!('content' in choice.delta) ||
-              (typeof choice.delta.content === 'string' && choice.delta.content !== ''))
+              (typeof choice.delta.content === 'string' && choice.delta.content !== '') ||
+              (typeof (choice.delta as any).reasoning_content === 'string' &&
+                (choice.delta as any).reasoning_content !== '') ||
+              (typeof (choice.delta as any).reasoning === 'string' && (choice.delta as any).reasoning !== ''))
           ) {
             contentSource = choice.delta
           } else if ('message' in choice) {
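The enlarged condition reads as a predicate: a streaming delta counts as a usable content source only when it is a non-empty object and carries either non-empty text content or one of the non-standard reasoning_content / reasoning fields that some OpenAI-compatible providers attach. Restated as a standalone helper (hypothetical name, not part of the change):

// Hypothetical helper restating the guard above. `reasoning_content` and
// `reasoning` are non-standard fields used by some OpenAI-compatible providers.
function isUsableDelta(delta: Record<string, unknown> | null | undefined): boolean {
  if (!delta || Object.keys(delta).length === 0) return false
  const nonEmpty = (v: unknown): v is string => typeof v === 'string' && v !== ''
  return (
    !('content' in delta) || // no content field at all: may still carry tool calls etc.
    nonEmpty(delta.content) ||
    nonEmpty((delta as any).reasoning_content) ||
    nonEmpty((delta as any).reasoning)
  )
}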
@@ -66,7 +66,7 @@ export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
     let thinkingStartTime = 0
 
     let isFirstTextChunk = true
-
+    let accumulatedThinkingContent = ''
     const processedStream = resultFromUpstream.pipeThrough(
       new TransformStream<GenericChunk, GenericChunk>({
         transform(chunk: GenericChunk, controller) {
@@ -101,9 +101,10 @@ export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
           }
 
           if (extractionResult.content?.trim()) {
+            accumulatedThinkingContent += extractionResult.content
            const thinkingDeltaChunk: ThinkingDeltaChunk = {
               type: ChunkType.THINKING_DELTA,
-              text: extractionResult.content,
+              text: accumulatedThinkingContent,
               thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
             }
             controller.enqueue(thinkingDeltaChunk)
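Because every THINKING_DELTA now carries the accumulated text, a consumer should overwrite its thinking buffer rather than append to it; the test expectations later in this diff show the same cumulative values. A sketch of a hypothetical consumer:

// Hypothetical consumer: THINKING_DELTA.text is already the full
// accumulated thinking text, so replace the buffer instead of appending.
let thinkingText = ''
function onThinkingDelta(chunk: ThinkingDeltaChunk) {
  thinkingText = chunk.text // replace, not +=
}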
@@ -1,6 +1,7 @@
 import { FinishReason, MediaModality } from '@google/genai'
 import { FunctionCall } from '@google/genai'
 import AiProvider from '@renderer/aiCore'
+import { OpenAIAPIClient, ResponseChunkTransformerContext } from '@renderer/aiCore/clients'
 import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
 import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
 import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
@@ -14,8 +15,10 @@ import {
   TextStartChunk,
   ThinkingStartChunk
 } from '@renderer/types/chunk'
-import { GeminiSdkRawChunk } from '@renderer/types/sdk'
+import { GeminiSdkRawChunk, OpenAISdkRawChunk, OpenAISdkRawContentSource } from '@renderer/types/sdk'
 import { cloneDeep } from 'lodash'
+import OpenAI from 'openai'
+import { ChatCompletionChunk } from 'openai/resources'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
 
 // Mock the ApiClientFactory
@@ -615,6 +618,117 @@ const geminiToolUseChunks: GeminiSdkRawChunk[] = [
   } as GeminiSdkRawChunk
 ]
 
+const openaiCompletionChunks: OpenAISdkRawChunk[] = [
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: null,
+          role: 'assistant',
+          reasoning_content: ''
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      } as ChatCompletionChunk.Choice
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: null,
+          role: 'assistant',
+          reasoning_content: '好的,用户打招呼说“你好'
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      }
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: null,
+          role: 'assistant',
+          reasoning_content: '”,我需要友好回应。'
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      }
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: '你好!有什么问题',
+          role: 'assistant',
+          reasoning_content: null
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      }
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: '或者需要我帮忙的吗?',
+          role: 'assistant',
+          reasoning_content: null
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: null
+      }
+    ]
+  },
+  {
+    id: 'cmpl-123',
+    created: 1715811200,
+    model: 'gpt-4o',
+    object: 'chat.completion.chunk',
+    choices: [
+      {
+        delta: {
+          content: null,
+          role: 'assistant',
+          reasoning_content: null
+        } as ChatCompletionChunk.Choice.Delta,
+        index: 0,
+        logprobs: null,
+        finish_reason: 'stop'
+      }
+    ]
+  }
+]
+
 // A proper async generator function
 async function* geminiChunkGenerator(): AsyncGenerator<GeminiSdkRawChunk> {
   for (const chunk of geminiChunks) {
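Note how the fixture's first chunk (content: null, reasoning_content: '') exercises the new guard in OpenAIAPIClient: both fields are empty, the delta is rejected, and since the choice has no message either, the chunk is skipped instead of producing an empty thinking or text delta. In terms of the hypothetical isUsableDelta helper sketched earlier:

// First fixture chunk: 'content' is present but not a non-empty string,
// and both reasoning fields are empty, so the delta is skipped.
isUsableDelta({ content: null, role: 'assistant', reasoning_content: '' }) // false

// Second fixture chunk: non-empty reasoning_content, so it is accepted.
isUsableDelta({ content: null, role: 'assistant', reasoning_content: '好的,用户打招呼说“你好' }) // true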
@@ -634,6 +748,276 @@ async function* geminiToolUseChunkGenerator(): AsyncGenerator<GeminiSdkRawChunk>
   }
 }
 
+async function* openaiThinkingChunkGenerator(): AsyncGenerator<OpenAISdkRawChunk> {
+  for (const chunk of openaiCompletionChunks) {
+    yield chunk
+  }
+}
+
+const mockOpenaiApiClient = {
+  createCompletions: vi.fn().mockImplementation(() => openaiThinkingChunkGenerator()),
+  getResponseChunkTransformer: vi.fn().mockImplementation(() => {
+    let hasBeenCollectedWebSearch = false
+    const collectWebSearchData = (
+      chunk: OpenAISdkRawChunk,
+      contentSource: OpenAISdkRawContentSource,
+      context: ResponseChunkTransformerContext
+    ) => {
+      if (hasBeenCollectedWebSearch) {
+        return
+      }
+      // OpenAI annotations
+      // @ts-ignore - annotations may not be in standard type definitions
+      const annotations = contentSource.annotations || chunk.annotations
+      if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') {
+        hasBeenCollectedWebSearch = true
+        return {
+          results: annotations,
+          source: WebSearchSource.OPENAI
+        }
+      }
+
+      // Grok citations
+      // @ts-ignore - citations may not be in standard type definitions
+      if (context.provider?.id === 'grok' && chunk.citations) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - citations may not be in standard type definitions
+          results: chunk.citations,
+          source: WebSearchSource.GROK
+        }
+      }
+
+      // Perplexity citations
+      // @ts-ignore - citations may not be in standard type definitions
+      if (context.provider?.id === 'perplexity' && chunk.search_results && chunk.search_results.length > 0) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - citations may not be in standard type definitions
+          results: chunk.search_results,
+          source: WebSearchSource.PERPLEXITY
+        }
+      }
+
+      // OpenRouter citations
+      // @ts-ignore - citations may not be in standard type definitions
+      if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - citations may not be in standard type definitions
+          results: chunk.citations,
+          source: WebSearchSource.OPENROUTER
+        }
+      }
+
+      // Zhipu web search
+      // @ts-ignore - web_search may not be in standard type definitions
+      if (context.provider?.id === 'zhipu' && chunk.web_search) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - web_search may not be in standard type definitions
+          results: chunk.web_search,
+          source: WebSearchSource.ZHIPU
+        }
+      }
+
+      // Hunyuan web search
+      // @ts-ignore - search_info may not be in standard type definitions
+      if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) {
+        hasBeenCollectedWebSearch = true
+        return {
+          // @ts-ignore - search_info may not be in standard type definitions
+          results: chunk.search_info.search_results,
+          source: WebSearchSource.HUNYUAN
+        }
+      }
+      return null
+    }
+
+    const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
+    let isFinished = false
+    let lastUsageInfo: any = null
+
+    /**
+     * Unified completion-signal emission logic, used when:
+     * - a finish_reason is present, or
+     * - there is no finish_reason but the stream ends normally
+     */
+    const emitCompletionSignals = (controller: TransformStreamDefaultController<GenericChunk>) => {
+      if (isFinished) return
+
+      if (toolCalls.length > 0) {
+        controller.enqueue({
+          type: ChunkType.MCP_TOOL_CREATED,
+          tool_calls: toolCalls
+        })
+      }
+
+      const usage = lastUsageInfo || {
+        prompt_tokens: 0,
+        completion_tokens: 0,
+        total_tokens: 0
+      }
+
+      controller.enqueue({
+        type: ChunkType.LLM_RESPONSE_COMPLETE,
+        response: { usage }
+      })
+
+      // Prevent duplicate emission
+      isFinished = true
+    }
+
+    let isFirstThinkingChunk = true
+    let isFirstTextChunk = true
+    return (context: ResponseChunkTransformerContext) => ({
+      async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
+        // Keep the usage info continuously updated
+        if (chunk.usage) {
+          lastUsageInfo = {
+            prompt_tokens: chunk.usage.prompt_tokens || 0,
+            completion_tokens: chunk.usage.completion_tokens || 0,
+            total_tokens: (chunk.usage.prompt_tokens || 0) + (chunk.usage.completion_tokens || 0)
+          }
+        }
+
+        // Process the chunk
+        if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
+          for (const choice of chunk.choices) {
+            if (!choice) continue
+
+            // For streaming responses use delta; for non-streaming responses use message.
+            // However, some OpenAI-compatible platforms incorrectly return a delta field
+            // containing an empty object on non-streaming requests. If delta is an empty
+            // object or its content is empty, ignore it and fall back to message so no
+            // content is lost.
+            let contentSource: OpenAISdkRawContentSource | null = null
+            if (
+              'delta' in choice &&
+              choice.delta &&
+              Object.keys(choice.delta).length > 0 &&
+              (!('content' in choice.delta) ||
+                (typeof choice.delta.content === 'string' && choice.delta.content !== '') ||
+                (typeof (choice.delta as any).reasoning_content === 'string' &&
+                  (choice.delta as any).reasoning_content !== '') ||
+                (typeof (choice.delta as any).reasoning === 'string' && (choice.delta as any).reasoning !== ''))
+            ) {
+              contentSource = choice.delta
+            } else if ('message' in choice) {
+              contentSource = choice.message
+            }
+
+            if (!contentSource) {
+              if ('finish_reason' in choice && choice.finish_reason) {
+                emitCompletionSignals(controller)
+              }
+              continue
+            }
+
+            const webSearchData = collectWebSearchData(chunk, contentSource, context)
+            if (webSearchData) {
+              controller.enqueue({
+                type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
+                llm_web_search: webSearchData
+              })
+            }
+
+            // Handle reasoning content (e.g. from OpenRouter DeepSeek-R1)
+            // @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it
+            const reasoningText = contentSource.reasoning_content || contentSource.reasoning
+            if (reasoningText) {
+              if (isFirstThinkingChunk) {
+                controller.enqueue({
+                  type: ChunkType.THINKING_START
+                } as ThinkingStartChunk)
+                isFirstThinkingChunk = false
+              }
+              controller.enqueue({
+                type: ChunkType.THINKING_DELTA,
+                text: reasoningText
+              })
+            }
+
+            // Handle text content
+            if (contentSource.content) {
+              if (isFirstTextChunk) {
+                controller.enqueue({
+                  type: ChunkType.TEXT_START
+                } as TextStartChunk)
+                isFirstTextChunk = false
+              }
+              controller.enqueue({
+                type: ChunkType.TEXT_DELTA,
+                text: contentSource.content
+              })
+            }
+
+            // Handle tool calls
+            if (contentSource.tool_calls) {
+              for (const toolCall of contentSource.tool_calls) {
+                if ('index' in toolCall) {
+                  const { id, index, function: fun } = toolCall
+                  if (fun?.name) {
+                    toolCalls[index] = {
+                      id: id || '',
+                      function: {
+                        name: fun.name,
+                        arguments: fun.arguments || ''
+                      },
+                      type: 'function'
+                    }
+                  } else if (fun?.arguments) {
+                    toolCalls[index].function.arguments += fun.arguments
+                  }
+                } else {
+                  toolCalls.push(toolCall)
+                }
+              }
+            }
+
+            // Handle finish_reason and emit the end-of-stream signal
+            if ('finish_reason' in choice && choice.finish_reason) {
+              const webSearchData = collectWebSearchData(chunk, contentSource, context)
+              if (webSearchData) {
+                controller.enqueue({
+                  type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
+                  llm_web_search: webSearchData
+                })
+              }
+              emitCompletionSignals(controller)
+            }
+          }
+        }
+      },
+
+      // When the stream ends normally, check whether completion signals still need to be sent
+      flush(controller) {
+        if (isFinished) return
+        emitCompletionSignals(controller)
+      }
+    })
+  }),
+  getSdkInstance: vi.fn(),
+  getRequestTransformer: vi.fn().mockImplementation(() => ({
+    async transform(params: any) {
+      return {
+        payload: {
+          model: params.assistant?.model?.id || 'gpt-4o',
+          messages: params.messages || [],
+          tools: params.tools || []
+        },
+        metadata: {}
+      }
+    }
+  })),
+  convertMcpToolsToSdkTools: vi.fn(() => []),
+  convertSdkToolCallToMcpToolResponse: vi.fn(),
+  buildSdkMessages: vi.fn(() => []),
+  extractMessagesFromSdkPayload: vi.fn(() => []),
+  provider: {} as Provider,
+  useSystemPromptForTools: true,
+  getBaseURL: vi.fn(() => 'https://api.openai.com'),
+  getApiKey: vi.fn(() => 'mock-api-key')
+} as unknown as OpenAIAPIClient
+
 // Create a mock GeminiAPIClient
 const mockGeminiApiClient = {
   createCompletions: vi.fn().mockImplementation(() => geminiChunkGenerator()),
@@ -1064,6 +1448,95 @@ describe('ApiService', () => {
     expect(filteredChunks).toEqual(expectedChunks)
   })
 
+  it('should handle openai thinking chunk correctly', async () => {
+    const mockCreate = vi.mocked(ApiClientFactory.create)
+    mockCreate.mockReturnValue(mockOpenaiApiClient as unknown as BaseApiClient)
+    const AI = new AiProvider(mockProvider as Provider)
+    const result = await AI.completions({
+      callType: 'test',
+      messages: [],
+      assistant: {
+        id: '1',
+        name: 'test',
+        prompt: 'test',
+        model: {
+          id: 'gpt-4o',
+          name: 'GPT-4o'
+        }
+      } as Assistant,
+      onChunk: mockOnChunk,
+      enableReasoning: true,
+      streamOutput: true
+    })
+
+    expect(result).toBeDefined()
+    expect(ApiClientFactory.create).toHaveBeenCalledWith(mockProvider)
+    expect(result.stream).toBeDefined()
+
+    const stream = result.stream! as ReadableStream<GenericChunk>
+    const reader = stream.getReader()
+
+    const chunks: GenericChunk[] = []
+
+    while (true) {
+      const { done, value } = await reader.read()
+      if (done) break
+      chunks.push(value)
+    }
+
+    reader.releaseLock()
+
+    const filteredChunks = chunks.map((chunk) => {
+      if (chunk.type === ChunkType.THINKING_DELTA || chunk.type === ChunkType.THINKING_COMPLETE) {
+        delete (chunk as any).thinking_millsec
+        return chunk
+      }
+      if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
+        delete (chunk as any).response.usage
+        return chunk
+      }
+      return chunk
+    })
+    const expectedChunks = [
+      {
+        type: ChunkType.THINKING_START
+      },
+      {
+        type: ChunkType.THINKING_DELTA,
+        text: '好的,用户打招呼说“你好'
+      },
+      {
+        type: ChunkType.THINKING_DELTA,
+        text: '好的,用户打招呼说“你好”,我需要友好回应。'
+      },
+      {
+        type: ChunkType.THINKING_COMPLETE,
+        text: '好的,用户打招呼说“你好”,我需要友好回应。'
+      },
+      {
+        type: ChunkType.TEXT_START
+      },
+      {
+        type: ChunkType.TEXT_DELTA,
+        text: '你好!有什么问题'
+      },
+      {
+        type: ChunkType.TEXT_DELTA,
+        text: '你好!有什么问题或者需要我帮忙的吗?'
+      },
+      {
+        type: ChunkType.TEXT_COMPLETE,
+        text: '你好!有什么问题或者需要我帮忙的吗?'
+      },
+      {
+        type: ChunkType.LLM_RESPONSE_COMPLETE,
+        response: {}
+      }
+    ]
+
+    expect(filteredChunks).toEqual(expectedChunks)
+  })
+
   // it('should extract tool use responses correctly', async () => {
   //   const mockCreate = vi.mocked(ApiClientFactory.create)
   //   mockCreate.mockReturnValue(mockGeminiToolUseApiClient as unknown as BaseApiClient)