mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-30 07:39:06 +08:00
fix(OpenAIResponseAPIClient): add self-referential compatibility type check to prevent circular calls (#8424)
fixOpenAIResponseAPIClient): add self-referential compatibility type check to prevent circular calls
This commit is contained in:
parent
5f5dfd13c7
commit
f5b6a4be49
@ -104,6 +104,10 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
}
|
||||
|
||||
const actualClient = this.getClient(model)
|
||||
// 避免循环调用:如果返回的是自己,直接返回自己的类型
|
||||
if (actualClient === this) {
|
||||
return [this.constructor.name]
|
||||
}
|
||||
return actualClient.getClientCompatibilityType(model)
|
||||
}
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/Anthropic
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
|
||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { isVisionModel } from '@renderer/config/models'
|
||||
import { Assistant, MCPCallToolResponse, MCPToolResponse, Model, Provider, WebSearchSource } from '@renderer/types'
|
||||
@ -1254,6 +1255,35 @@ const mockOpenaiApiClient = {
|
||||
getClientCompatibilityType: vi.fn(() => ['OpenAIAPIClient'])
|
||||
} as unknown as OpenAIAPIClient
|
||||
|
||||
// Mock OpenAIResponseAPIClient
|
||||
const mockOpenAIResponseAPIClient = {
|
||||
createCompletions: vi.fn().mockImplementation(() => openaiThinkingChunkGenerator()),
|
||||
getResponseChunkTransformer: mockOpenaiApiClient.getResponseChunkTransformer,
|
||||
getSdkInstance: vi.fn(),
|
||||
getRequestTransformer: vi.fn().mockImplementation(() => ({
|
||||
async transform(params: any) {
|
||||
return {
|
||||
payload: {
|
||||
model: params.assistant?.model?.id || 'gpt-4o',
|
||||
messages: params.messages || [],
|
||||
tools: params.tools || []
|
||||
},
|
||||
metadata: {}
|
||||
}
|
||||
}
|
||||
})),
|
||||
convertMcpToolsToSdkTools: vi.fn(() => []),
|
||||
convertSdkToolCallToMcpToolResponse: vi.fn(),
|
||||
buildSdkMessages: vi.fn(() => []),
|
||||
extractMessagesFromSdkPayload: vi.fn(() => []),
|
||||
provider: {} as Provider,
|
||||
useSystemPromptForTools: true,
|
||||
getBaseURL: vi.fn(() => 'https://api.openai.com'),
|
||||
getApiKey: vi.fn(() => 'mock-api-key'),
|
||||
getClient: vi.fn(() => mockOpenaiApiClient), // 模拟返回内部客户端
|
||||
getClientCompatibilityType: vi.fn(() => ['OpenAIResponseAPIClient'])
|
||||
} as unknown as OpenAIResponseAPIClient
|
||||
|
||||
const mockOpenaiNeedExtractContentApiClient = cloneDeep(mockOpenaiApiClient)
|
||||
mockOpenaiNeedExtractContentApiClient.createCompletions = vi
|
||||
.fn()
|
||||
@ -2252,6 +2282,49 @@ describe('ApiService', () => {
|
||||
expect(filteredChunks).toEqual(expectedChunks)
|
||||
})
|
||||
|
||||
it('should handle OpenAIResponseAPIClient compatibility type without circular call', async () => {
|
||||
const mockCreate = vi.mocked(ApiClientFactory.create)
|
||||
|
||||
// 创建一个模拟的 OpenAIResponseAPIClient,getClient 返回自身
|
||||
const mockSelfReturningClient = {
|
||||
...mockOpenAIResponseAPIClient,
|
||||
getClient: vi.fn(() => mockSelfReturningClient), // 返回自身,模拟循环调用场景
|
||||
getClientCompatibilityType: vi.fn((model) => {
|
||||
// 模拟真实的逻辑:检查是否返回自身
|
||||
const actualClient = mockSelfReturningClient.getClient()
|
||||
if (actualClient === mockSelfReturningClient) {
|
||||
return ['OpenAIResponseAPIClient']
|
||||
}
|
||||
return actualClient.getClientCompatibilityType(model)
|
||||
})
|
||||
}
|
||||
|
||||
mockCreate.mockReturnValue(mockSelfReturningClient as unknown as BaseApiClient)
|
||||
const AI = new AiProvider(mockProvider)
|
||||
|
||||
const result = await AI.completions({
|
||||
callType: 'test',
|
||||
messages: [],
|
||||
assistant: {
|
||||
id: '1',
|
||||
name: 'test',
|
||||
prompt: 'test',
|
||||
model: {
|
||||
id: 'gpt-4o',
|
||||
name: 'GPT-4o'
|
||||
}
|
||||
} as Assistant,
|
||||
onChunk: mockOnChunk,
|
||||
streamOutput: true
|
||||
})
|
||||
|
||||
expect(result).toBeDefined()
|
||||
expect(mockSelfReturningClient.getClientCompatibilityType).toHaveBeenCalled()
|
||||
|
||||
// 验证没有抛出堆栈溢出错误,表明没有无限循环
|
||||
expect(() => mockSelfReturningClient.getClientCompatibilityType({ id: 'gpt-4o' })).not.toThrow()
|
||||
})
|
||||
|
||||
it('should extract tool use responses correctly', async () => {
|
||||
const mockCreate = vi.mocked(ApiClientFactory.create)
|
||||
mockCreate.mockReturnValue(mockGeminiToolUseApiClient as unknown as BaseApiClient)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user