From d3027852415a41ba6c088210670b95bb6e31711a Mon Sep 17 00:00:00 2001 From: SuYao Date: Thu, 24 Jul 2025 17:47:00 +0800 Subject: [PATCH] fix(RawStreamListenerMiddleware): update model check (#8433) * fix(RawStreamListenerMiddleware): update model check for Anthropic API integration - Replaced provider type check with model ID check to enhance compatibility with Claude models. - Improved clarity in the middleware logic for handling raw output from the SDK. * refactor(RawStreamListenerMiddleware): enhance model identification for Anthropic integration - Introduced a new utility function `isAnthropicModel` to streamline model checks across the codebase. - Updated middleware logic to utilize the new function for improved clarity and maintainability. - Adjusted related tests to ensure compatibility with the updated model identification approach. * test(ApiService.test): add mock for isAnthropicModel to enhance test coverage for model identification --- .../middleware/core/RawStreamListenerMiddleware.ts | 5 +++-- src/renderer/src/config/models.ts | 10 +++++++++- src/renderer/src/services/__tests__/ApiService.test.ts | 6 +++++- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts b/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts index 3c5df05b28..25d0e358c6 100644 --- a/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts +++ b/src/renderer/src/aiCore/middleware/core/RawStreamListenerMiddleware.ts @@ -1,4 +1,5 @@ import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient' +import { isAnthropicModel } from '@renderer/config/models' import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk' import { AnthropicStreamListener } from '../../clients/types' @@ -15,9 +16,9 @@ export const RawStreamListenerMiddleware: CompletionsMiddleware = // 在这里可以监听到从SDK返回的最原始流 if (result.rawOutput) { - const providerType = ctx.apiClientInstance.provider.type + const model = params.assistant.model // TODO: 后面下放到AnthropicAPIClient - if (providerType === 'anthropic') { + if (isAnthropicModel(model)) { const anthropicListener: AnthropicStreamListener = { onMessage: (message) => { if (ctx._internal?.toolProcessingState) { diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index 2301092fa5..76c002ce29 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -2761,7 +2761,7 @@ export function isWebSearchModel(model: Model): boolean { const baseName = getLowerBaseModelName(model.id, '/') // 不管哪个供应商都判断了 - if (model.id.includes('claude')) { + if (isAnthropicModel(model)) { return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(baseName) } @@ -3017,3 +3017,11 @@ export const isVisionModels = (models: Model[]) => { export const isGenerateImageModels = (models: Model[]) => { return models.every((model) => isGenerateImageModel(model)) } + +export const isAnthropicModel = (model?: Model): boolean => { + if (!model) { + return false + } + + return getLowerBaseModelName(model.id).startsWith('claude') +} diff --git a/src/renderer/src/services/__tests__/ApiService.test.ts b/src/renderer/src/services/__tests__/ApiService.test.ts index dfb3b2add8..fd80abd3ad 100644 --- a/src/renderer/src/services/__tests__/ApiService.test.ts +++ b/src/renderer/src/services/__tests__/ApiService.test.ts @@ -65,7 +65,8 @@ vi.mock('@renderer/config/models', () => ({ id: 'gemini-2.5-pro', name: 'Gemini 2.5 Pro' } - } + }, + isAnthropicModel: vi.fn(() => false) })) // Mock uuid @@ -1422,6 +1423,9 @@ const mockGeminiApiClient = { const mockAnthropicApiClient = { createCompletions: vi.fn().mockImplementation(() => anthropicTextNonStreamChunkGenerator()), + attachRawStreamListener: vi.fn().mockImplementation((rawOutput: any) => { + return rawOutput + }), getResponseChunkTransformer: vi.fn().mockImplementation(() => { return () => { let accumulatedJson = ''