diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts index 20aa78dcb..b1d8e34fc 100644 --- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts +++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts @@ -79,7 +79,7 @@ vi.mock('@renderer/services/AssistantService', () => ({ import { getProviderByModel } from '@renderer/services/AssistantService' import type { Model, Provider } from '@renderer/types' import { formatApiHost } from '@renderer/utils/api' -import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' +import { isAzureOpenAIProvider, isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' import { getActualProvider, providerToAiSdkConfig } from '../providerConfig' @@ -133,6 +133,17 @@ const createPerplexityProvider = (): Provider => ({ isSystem: false }) +const createAzureProvider = (apiVersion: string): Provider => ({ + id: 'azure-openai', + type: 'azure-openai', + name: 'Azure OpenAI', + apiKey: 'test-key', + apiHost: 'https://example.openai.azure.com/openai', + apiVersion, + models: [], + isSystem: true +}) + describe('Copilot responses routing', () => { beforeEach(() => { ;(globalThis as any).window = { @@ -504,3 +515,46 @@ describe('Stream options includeUsage configuration', () => { expect(config.providerId).toBe('github-copilot-openai-compatible') }) }) + +describe('Azure OpenAI traditional API routing', () => { + beforeEach(() => { + ;(globalThis as any).window = { + ...(globalThis as any).window, + keyv: createWindowKeyv() + } + mockGetState.mockReturnValue({ + settings: { + openAI: { + streamOptions: { + includeUsage: undefined + } + } + } + }) + + vi.mocked(isAzureOpenAIProvider).mockImplementation((provider) => provider.type === 'azure-openai') + }) + + it('uses deployment-based URLs when apiVersion is a date version', () => { + const provider = createAzureProvider('2024-02-15-preview') + const config = providerToAiSdkConfig(provider, createModel('gpt-4o', 'GPT-4o', provider.id)) + + expect(config.providerId).toBe('azure') + expect(config.options.apiVersion).toBe('2024-02-15-preview') + expect(config.options.useDeploymentBasedUrls).toBe(true) + }) + + it('does not force deployment-based URLs for apiVersion v1/preview', () => { + const v1Provider = createAzureProvider('v1') + const v1Config = providerToAiSdkConfig(v1Provider, createModel('gpt-4o', 'GPT-4o', v1Provider.id)) + expect(v1Config.providerId).toBe('azure-responses') + expect(v1Config.options.apiVersion).toBe('v1') + expect(v1Config.options.useDeploymentBasedUrls).toBeUndefined() + + const previewProvider = createAzureProvider('preview') + const previewConfig = providerToAiSdkConfig(previewProvider, createModel('gpt-4o', 'GPT-4o', previewProvider.id)) + expect(previewConfig.providerId).toBe('azure-responses') + expect(previewConfig.options.apiVersion).toBe('preview') + expect(previewConfig.options.useDeploymentBasedUrls).toBeUndefined() + }) +}) diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index 556b870e5..0ad15ea89 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -214,6 +214,15 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A } else if (aiSdkProviderId === 'azure') { extraOptions.mode = 'chat' } + if (isAzureOpenAIProvider(actualProvider)) { + const apiVersion = actualProvider.apiVersion?.trim() + if (apiVersion) { + extraOptions.apiVersion = apiVersion + if (!['preview', 'v1'].includes(apiVersion)) { + extraOptions.useDeploymentBasedUrls = true + } + } + } // bedrock if (aiSdkProviderId === 'bedrock') { diff --git a/tests/renderer.setup.ts b/tests/renderer.setup.ts index 9e10e5363..2f47a8d8e 100644 --- a/tests/renderer.setup.ts +++ b/tests/renderer.setup.ts @@ -1,8 +1,15 @@ import '@testing-library/jest-dom/vitest' +import { createRequire } from 'node:module' import { styleSheetSerializer } from 'jest-styled-components/serializer' import { expect, vi } from 'vitest' +const require = createRequire(import.meta.url) +const bufferModule = require('buffer') +if (!bufferModule.SlowBuffer) { + bufferModule.SlowBuffer = bufferModule.Buffer +} + expect.addSnapshotSerializer(styleSheetSerializer) // Mock LoggerService globally for renderer tests @@ -48,3 +55,29 @@ vi.stubGlobal('api', { writeWithId: vi.fn().mockResolvedValue(undefined) } }) + +if (typeof globalThis.localStorage === 'undefined' || typeof (globalThis.localStorage as any).getItem !== 'function') { + let store = new Map() + + const localStorageMock = { + getItem: (key: string) => store.get(key) ?? null, + setItem: (key: string, value: string) => { + store.set(key, String(value)) + }, + removeItem: (key: string) => { + store.delete(key) + }, + clear: () => { + store.clear() + }, + key: (index: number) => Array.from(store.keys())[index] ?? null, + get length() { + return store.size + } + } + + vi.stubGlobal('localStorage', localStorageMock) + if (typeof window !== 'undefined') { + Object.defineProperty(window, 'localStorage', { value: localStorageMock }) + } +}