fix(RawStreamListenerMiddleware): update model check (#8433)

* fix(RawStreamListenerMiddleware): update model check for Anthropic API integration

- Replaced provider type check with model ID check to enhance compatibility with Claude models.
- Improved clarity in the middleware logic for handling raw output from the SDK.

* refactor(RawStreamListenerMiddleware): enhance model identification for Anthropic integration

- Introduced a new utility function `isAnthropicModel` to streamline model checks across the codebase.
- Updated middleware logic to utilize the new function for improved clarity and maintainability.
- Adjusted related tests to ensure compatibility with the updated model identification approach.

* test(ApiService.test): add mock for isAnthropicModel to enhance test coverage for model identification
This commit is contained in:
SuYao 2025-07-24 17:47:00 +08:00 committed by GitHub
parent 2721930294
commit d302785241
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 4 deletions

View File

@ -1,4 +1,5 @@
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { isAnthropicModel } from '@renderer/config/models'
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
import { AnthropicStreamListener } from '../../clients/types'
@ -15,9 +16,9 @@ export const RawStreamListenerMiddleware: CompletionsMiddleware =
// 在这里可以监听到从SDK返回的最原始流
if (result.rawOutput) {
const providerType = ctx.apiClientInstance.provider.type
const model = params.assistant.model
// TODO: 后面下放到AnthropicAPIClient
if (providerType === 'anthropic') {
if (isAnthropicModel(model)) {
const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
onMessage: (message) => {
if (ctx._internal?.toolProcessingState) {

View File

@ -2761,7 +2761,7 @@ export function isWebSearchModel(model: Model): boolean {
const baseName = getLowerBaseModelName(model.id, '/')
// 不管哪个供应商都判断了
if (model.id.includes('claude')) {
if (isAnthropicModel(model)) {
return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(baseName)
}
@ -3017,3 +3017,11 @@ export const isVisionModels = (models: Model[]) => {
export const isGenerateImageModels = (models: Model[]) => {
return models.every((model) => isGenerateImageModel(model))
}
export const isAnthropicModel = (model?: Model): boolean => {
if (!model) {
return false
}
return getLowerBaseModelName(model.id).startsWith('claude')
}

View File

@ -65,7 +65,8 @@ vi.mock('@renderer/config/models', () => ({
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro'
}
}
},
isAnthropicModel: vi.fn(() => false)
}))
// Mock uuid
@ -1422,6 +1423,9 @@ const mockGeminiApiClient = {
const mockAnthropicApiClient = {
createCompletions: vi.fn().mockImplementation(() => anthropicTextNonStreamChunkGenerator()),
attachRawStreamListener: vi.fn().mockImplementation((rawOutput: any) => {
return rawOutput
}),
getResponseChunkTransformer: vi.fn().mockImplementation(() => {
return () => {
let accumulatedJson = ''