mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-26 11:44:28 +08:00
fix: openai response tool use (#7332)
* fix: openai response tool use - Added OpenAIResponseStreamListener interface for handling OpenAI response streams. - Implemented attachRawStreamListener method in OpenAIResponseAPIClient to manage raw output. - Updated RawStreamListenerMiddleware to integrate OpenAI response handling. - Refactored BaseApiClient to remove unused attachRawStreamListener method. - Improved buildSdkMessages to handle OpenAI response formats. * fix: remove logging from StreamAdapterMiddleware - Removed Logger.info call from StreamAdapterMiddleware to streamline output and reduce unnecessary logging. * fix: update attachRawStreamListener to return a Promise - Changed attachRawStreamListener method in OpenAIResponseAPIClient to be asynchronous, returning a Promise for better handling of raw output. - Updated RawStreamListenerMiddleware to await the result of attachRawStreamListener, ensuring proper flow of data handling. * refactor: enhance attachRawStreamListener to return a ReadableStream - Updated the attachRawStreamListener method in OpenAIResponseAPIClient to return a ReadableStream, allowing for more efficient handling of streamed responses. - Modified RawStreamListenerMiddleware to accommodate the new return type, ensuring proper integration of the transformed stream into the middleware flow. * refactor: update getResponseChunkTransformer to accept CompletionsContext - Modified the getResponseChunkTransformer method in BaseApiClient and its implementations to accept a CompletionsContext parameter, enhancing the flexibility of response handling. - Adjusted related middleware and client classes to ensure compatibility with the new method signature, improving the overall integration of response transformations. * refactor: update getResponseChunkTransformer to accept CompletionsContext - Modified the getResponseChunkTransformer method in AihubmixAPIClient to accept a CompletionsContext parameter, enhancing the flexibility of response handling. - Ensured compatibility with the updated method signature across related client classes.
This commit is contained in:
parent
f42054ed03
commit
48016d7620
@ -20,6 +20,7 @@ import {
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
|
||||
import { CompletionsContext } from '../middleware/types'
|
||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||
import { BaseApiClient } from './BaseApiClient'
|
||||
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||
@ -163,8 +164,8 @@ export class AihubmixAPIClient extends BaseApiClient {
|
||||
return this.currentClient.getRequestTransformer()
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<SdkRawChunk> {
|
||||
return this.currentClient.getResponseChunkTransformer()
|
||||
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
|
||||
return this.currentClient.getResponseChunkTransformer(ctx)
|
||||
}
|
||||
|
||||
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
|
||||
|
||||
@ -42,7 +42,8 @@ import { defaultTimeout } from '@shared/config/constant'
|
||||
import Logger from 'electron-log/renderer'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { ApiClient, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from './types'
|
||||
import { CompletionsContext } from '../middleware/types'
|
||||
import { ApiClient, RequestTransformer, ResponseChunkTransformer } from './types'
|
||||
|
||||
/**
|
||||
* Abstract base class for API clients.
|
||||
@ -95,7 +96,7 @@ export abstract class BaseApiClient<
|
||||
// 在 CoreRequestToSdkParamsMiddleware中使用
|
||||
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
|
||||
// 在RawSdkChunkToGenericChunkMiddleware中使用
|
||||
abstract getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
|
||||
abstract getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
|
||||
|
||||
/**
|
||||
* 工具转换
|
||||
@ -129,17 +130,6 @@ export abstract class BaseApiClient<
|
||||
*/
|
||||
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
|
||||
|
||||
/**
|
||||
* 附加原始流监听器
|
||||
*/
|
||||
public attachRawStreamListener<TListener extends RawStreamListener<TRawChunk>>(
|
||||
rawOutput: TRawOutput,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
_listener: TListener
|
||||
): TRawOutput {
|
||||
return rawOutput
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用函数
|
||||
**/
|
||||
|
||||
@ -367,7 +367,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
* Anthropic专用的原始流监听器
|
||||
* 处理MessageStream对象的特定事件
|
||||
*/
|
||||
override attachRawStreamListener(
|
||||
attachRawStreamListener(
|
||||
rawOutput: AnthropicSdkRawOutput,
|
||||
listener: RawStreamListener<AnthropicSdkRawChunk>
|
||||
): AnthropicSdkRawOutput {
|
||||
|
||||
@ -494,7 +494,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
}
|
||||
|
||||
// 在RawSdkChunkToGenericChunkMiddleware中使用
|
||||
getResponseChunkTransformer = (): ResponseChunkTransformer<OpenAISdkRawChunk> => {
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAISdkRawChunk> {
|
||||
let hasBeenCollectedWebSearch = false
|
||||
const collectWebSearchData = (
|
||||
chunk: OpenAISdkRawChunk,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
|
||||
import {
|
||||
isOpenAIChatCompletionOnlyModel,
|
||||
isSupportedReasoningEffortOpenAIModel,
|
||||
@ -38,6 +39,7 @@ import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { MB } from '@shared/config/constant'
|
||||
import { isEmpty } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
import { ResponseInput } from 'openai/resources/responses/responses'
|
||||
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
import { OpenAIAPIClient } from './OpenAIApiClient'
|
||||
@ -225,9 +227,15 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
return
|
||||
}
|
||||
|
||||
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
|
||||
const content: OpenAI.Responses.ResponseInput = []
|
||||
content.push(...response.output)
|
||||
return content
|
||||
}
|
||||
|
||||
public buildSdkMessages(
|
||||
currentReqMessages: OpenAIResponseSdkMessageParam[],
|
||||
output: string | undefined,
|
||||
output: OpenAI.Responses.Response | undefined,
|
||||
toolResults: OpenAIResponseSdkMessageParam[],
|
||||
toolCalls: OpenAIResponseSdkToolCall[]
|
||||
): OpenAIResponseSdkMessageParam[] {
|
||||
@ -239,11 +247,9 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
return [...currentReqMessages, ...(toolCalls || []), ...(toolResults || [])]
|
||||
}
|
||||
|
||||
const assistantMessage: OpenAIResponseSdkMessageParam = {
|
||||
role: 'assistant',
|
||||
content: [{ type: 'input_text', text: output }]
|
||||
}
|
||||
const newReqMessages = [...currentReqMessages, assistantMessage, ...(toolCalls || []), ...(toolResults || [])]
|
||||
const content = this.convertResponseToMessageContent(output)
|
||||
|
||||
const newReqMessages = [...currentReqMessages, ...content, ...(toolResults || [])]
|
||||
return newReqMessages
|
||||
}
|
||||
|
||||
@ -415,13 +421,17 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
}
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
|
||||
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
|
||||
const toolCalls: OpenAIResponseSdkToolCall[] = []
|
||||
const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
|
||||
let hasBeenCollectedToolCalls = false
|
||||
return () => ({
|
||||
async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
// 处理chunk
|
||||
if ('output' in chunk) {
|
||||
if (ctx._internal?.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = chunk
|
||||
}
|
||||
for (const output of chunk.output) {
|
||||
switch (output.type) {
|
||||
case 'message':
|
||||
@ -463,6 +473,22 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
})
|
||||
}
|
||||
}
|
||||
if (toolCalls.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: chunk.usage?.input_tokens || 0,
|
||||
completion_tokens: chunk.usage?.output_tokens || 0,
|
||||
total_tokens: chunk.usage?.total_tokens || 0
|
||||
}
|
||||
}
|
||||
})
|
||||
} else {
|
||||
switch (chunk.type) {
|
||||
case 'response.output_item.added':
|
||||
@ -510,7 +536,8 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
if (outputItem.type === 'function_call') {
|
||||
toolCalls.push({
|
||||
...outputItem,
|
||||
arguments: chunk.arguments
|
||||
arguments: chunk.arguments,
|
||||
status: 'completed'
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -526,15 +553,26 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
}
|
||||
})
|
||||
}
|
||||
if (toolCalls.length > 0) {
|
||||
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
hasBeenCollectedToolCalls = true
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'response.completed': {
|
||||
if (ctx._internal?.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = chunk.response
|
||||
}
|
||||
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
hasBeenCollectedToolCalls = true
|
||||
}
|
||||
const completion_tokens = chunk.response.usage?.output_tokens || 0
|
||||
const total_tokens = chunk.response.usage?.total_tokens || 0
|
||||
controller.enqueue({
|
||||
|
||||
@ -3,6 +3,8 @@ import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@r
|
||||
import { Provider } from '@renderer/types'
|
||||
import {
|
||||
AnthropicSdkRawChunk,
|
||||
OpenAIResponseSdkRawChunk,
|
||||
OpenAIResponseSdkRawOutput,
|
||||
OpenAISdkRawChunk,
|
||||
SdkMessageParam,
|
||||
SdkParams,
|
||||
@ -14,6 +16,7 @@ import {
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CompletionsParams, GenericChunk } from '../middleware/schemas'
|
||||
import { CompletionsContext } from '../middleware/types'
|
||||
|
||||
/**
|
||||
* 原始流监听器接口
|
||||
@ -33,6 +36,14 @@ export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChun
|
||||
onFinishReason?: (reason: string) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI Response 专用的流监听器
|
||||
*/
|
||||
export interface OpenAIResponseStreamListener<TChunk extends OpenAIResponseSdkRawChunk = OpenAIResponseSdkRawChunk>
|
||||
extends RawStreamListener<TChunk> {
|
||||
onMessage?: (response: OpenAIResponseSdkRawOutput) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* Anthropic 专用的流监听器
|
||||
*/
|
||||
@ -101,7 +112,7 @@ export interface ApiClient<
|
||||
// SDK相关方法
|
||||
getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
|
||||
getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
|
||||
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
|
||||
|
||||
// 原始流监听方法
|
||||
attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput
|
||||
|
||||
@ -76,7 +76,7 @@ export default class AiProvider {
|
||||
if (!(this.apiClient instanceof OpenAIAPIClient)) {
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
}
|
||||
if (!(this.apiClient instanceof AnthropicAPIClient)) {
|
||||
if (!(this.apiClient instanceof AnthropicAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
|
||||
builder.remove(RawStreamListenerMiddlewareName)
|
||||
}
|
||||
if (!params.enableWebSearch) {
|
||||
|
||||
@ -15,8 +15,6 @@ export const RawStreamListenerMiddleware: CompletionsMiddleware =
|
||||
|
||||
// 在这里可以监听到从SDK返回的最原始流
|
||||
if (result.rawOutput) {
|
||||
console.log(`[${MIDDLEWARE_NAME}] 检测到原始SDK输出,准备附加监听器`)
|
||||
|
||||
const providerType = ctx.apiClientInstance.provider.type
|
||||
// TODO: 后面下放到AnthropicAPIClient
|
||||
if (providerType === 'anthropic') {
|
||||
|
||||
@ -37,7 +37,7 @@ export const ResponseTransformMiddleware: CompletionsMiddleware =
|
||||
}
|
||||
|
||||
// 获取响应转换器
|
||||
const responseChunkTransformer = apiClient.getResponseChunkTransformer?.()
|
||||
const responseChunkTransformer = apiClient.getResponseChunkTransformer(ctx)
|
||||
if (!responseChunkTransformer) {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
|
||||
return result
|
||||
|
||||
@ -25,7 +25,6 @@ export const StreamAdapterMiddleware: CompletionsMiddleware =
|
||||
// 但是这个中间件的职责是流适配,是否在这调用优待商榷
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
if (
|
||||
result.rawOutput &&
|
||||
!(result.rawOutput instanceof ReadableStream) &&
|
||||
|
||||
@ -14,8 +14,6 @@ export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
Logger.debug(`🔄 [${MIDDLEWARE_NAME}] Starting core to SDK params transformation:`, ctx)
|
||||
|
||||
const internal = ctx._internal
|
||||
|
||||
// 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息
|
||||
|
||||
Loading…
Reference in New Issue
Block a user