mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-12 00:49:14 +08:00
fix: refactor provider middleware (#7164)
This commit is contained in:
parent
f2d4255193
commit
eeb504d447
@ -11,6 +11,7 @@ import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
|
|||||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||||
|
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
||||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||||
@ -62,6 +63,7 @@ export default class AiProvider {
|
|||||||
builder.clear()
|
builder.clear()
|
||||||
builder
|
builder
|
||||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||||
|
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
||||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import { Chunk } from '@renderer/types/chunk'
|
import { Chunk } from '@renderer/types/chunk'
|
||||||
import { isAbortError } from '@renderer/utils/error'
|
|
||||||
|
|
||||||
import { CompletionsResult } from '../schemas'
|
import { CompletionsResult } from '../schemas'
|
||||||
import { CompletionsContext } from '../types'
|
import { CompletionsContext } from '../types'
|
||||||
@ -26,30 +25,27 @@ export const ErrorHandlerMiddleware =
|
|||||||
// 尝试执行下一个中间件
|
// 尝试执行下一个中间件
|
||||||
return await next(ctx, params)
|
return await next(ctx, params)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
let errorStream: ReadableStream<Chunk> | undefined
|
console.log('ErrorHandlerMiddleware_error', error)
|
||||||
// 有些sdk的abort error 是直接抛出的
|
// 1. 使用通用的工具函数将错误解析为标准格式
|
||||||
if (!isAbortError(error)) {
|
const errorChunk = createErrorChunk(error)
|
||||||
// 1. 使用通用的工具函数将错误解析为标准格式
|
// 2. 调用从外部传入的 onError 回调
|
||||||
const errorChunk = createErrorChunk(error)
|
if (params.onError) {
|
||||||
// 2. 调用从外部传入的 onError 回调
|
params.onError(error)
|
||||||
if (params.onError) {
|
|
||||||
params.onError(error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
|
|
||||||
if (shouldThrow) {
|
|
||||||
throw error
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
|
|
||||||
errorStream = new ReadableStream<Chunk>({
|
|
||||||
start(controller) {
|
|
||||||
controller.enqueue(errorChunk)
|
|
||||||
controller.close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
|
||||||
|
if (shouldThrow) {
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
|
||||||
|
const errorStream = new ReadableStream<Chunk>({
|
||||||
|
start(controller) {
|
||||||
|
controller.enqueue(errorChunk)
|
||||||
|
controller.close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
rawOutput: undefined,
|
rawOutput: undefined,
|
||||||
stream: errorStream, // 将包含错误的流传递下去
|
stream: errorStream, // 将包含错误的流传递下去
|
||||||
|
|||||||
@ -17,7 +17,6 @@ export const ImageGenerationMiddleware: CompletionsMiddleware =
|
|||||||
const { assistant, messages } = params
|
const { assistant, messages } = params
|
||||||
const client = context.apiClientInstance as BaseApiClient<OpenAI>
|
const client = context.apiClientInstance as BaseApiClient<OpenAI>
|
||||||
const signal = context._internal?.flowControl?.abortSignal
|
const signal = context._internal?.flowControl?.abortSignal
|
||||||
|
|
||||||
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
|
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
|
||||||
return next(context, params)
|
return next(context, params)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user