diff --git a/packages/shared/provider/sdk-config.ts b/packages/shared/provider/sdk-config.ts index b87a31ed25..350e686e61 100644 --- a/packages/shared/provider/sdk-config.ts +++ b/packages/shared/provider/sdk-config.ts @@ -10,7 +10,7 @@ import { defaultAppHeaders } from '@shared/utils' import { isEmpty } from 'lodash' import { routeToEndpoint } from '../api' -import { isOllamaProvider } from './detection' +import { isAzureOpenAIProvider, isOllamaProvider } from './detection' import { getAiSdkProviderId } from './mapping' import type { MinimalProvider } from './types' import { SystemProviderIds } from './types' @@ -210,6 +210,16 @@ export function providerToAiSdkConfig( extraOptions.mode = 'chat' } + if (isAzureOpenAIProvider(provider)) { + const apiVersion = provider.apiVersion?.trim() + if (apiVersion) { + extraOptions.apiVersion = apiVersion + if (!['preview', 'v1'].includes(apiVersion)) { + extraOptions.useDeploymentBasedUrls = true + } + } + } + // Handle AWS Bedrock if (aiSdkProviderId === 'bedrock') { const bedrockConfig = context.getAwsBedrockConfig?.() diff --git a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts index bbeedd69c9..59975c824c 100644 --- a/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts +++ b/src/renderer/src/aiCore/provider/__tests__/providerConfig.test.ts @@ -92,7 +92,7 @@ vi.mock('@renderer/services/AssistantService', () => ({ import { getProviderByModel } from '@renderer/services/AssistantService' import type { Model, Provider } from '@renderer/types' -import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' +import { isAzureOpenAIProvider, isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' import { formatApiHost } from '@shared/api' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' @@ -172,6 +172,17 @@ const createPerplexityProvider = (): Provider => ({ isSystem: false }) +const createAzureProvider = (apiVersion: string): Provider => ({ + id: 'azure-openai', + type: 'azure-openai', + name: 'Azure OpenAI', + apiKey: 'test-key', + apiHost: 'https://example.openai.azure.com/openai', + apiVersion, + models: [], + isSystem: true +}) + describe('Copilot responses routing', () => { beforeEach(() => { ;(globalThis as any).window = { @@ -454,3 +465,46 @@ describe('Stream options includeUsage configuration', () => { expect(config.providerId).toBe('github-copilot-openai-compatible') }) }) + +describe('Azure OpenAI traditional API routing', () => { + beforeEach(() => { + ;(globalThis as any).window = { + ...(globalThis as any).window, + keyv: createWindowKeyv() + } + mockGetState.mockReturnValue({ + settings: { + openAI: { + streamOptions: { + includeUsage: undefined + } + } + } + }) + + vi.mocked(isAzureOpenAIProvider).mockImplementation((provider) => provider.type === 'azure-openai') + }) + + it('uses deployment-based URLs when apiVersion is a date version', () => { + const provider = createAzureProvider('2024-02-15-preview') + const config = providerToAiSdkConfig(provider, createModel('gpt-4o', 'GPT-4o', provider.id)) + + expect(config.providerId).toBe('azure') + expect(config.options.apiVersion).toBe('2024-02-15-preview') + expect(config.options.useDeploymentBasedUrls).toBe(true) + }) + + it('does not force deployment-based URLs for apiVersion v1/preview', () => { + const v1Provider = createAzureProvider('v1') + const v1Config = providerToAiSdkConfig(v1Provider, createModel('gpt-4o', 'GPT-4o', v1Provider.id)) + expect(v1Config.providerId).toBe('azure-responses') + expect(v1Config.options.apiVersion).toBe('v1') + expect(v1Config.options.useDeploymentBasedUrls).toBeUndefined() + + const previewProvider = createAzureProvider('preview') + const previewConfig = providerToAiSdkConfig(previewProvider, createModel('gpt-4o', 'GPT-4o', previewProvider.id)) + expect(previewConfig.providerId).toBe('azure-responses') + expect(previewConfig.options.apiVersion).toBe('preview') + expect(previewConfig.options.useDeploymentBasedUrls).toBeUndefined() + }) +}) diff --git a/tests/renderer.setup.ts b/tests/renderer.setup.ts index 9e10e5363a..2f47a8d8ec 100644 --- a/tests/renderer.setup.ts +++ b/tests/renderer.setup.ts @@ -1,8 +1,15 @@ import '@testing-library/jest-dom/vitest' +import { createRequire } from 'node:module' import { styleSheetSerializer } from 'jest-styled-components/serializer' import { expect, vi } from 'vitest' +const require = createRequire(import.meta.url) +const bufferModule = require('buffer') +if (!bufferModule.SlowBuffer) { + bufferModule.SlowBuffer = bufferModule.Buffer +} + expect.addSnapshotSerializer(styleSheetSerializer) // Mock LoggerService globally for renderer tests @@ -48,3 +55,29 @@ vi.stubGlobal('api', { writeWithId: vi.fn().mockResolvedValue(undefined) } }) + +if (typeof globalThis.localStorage === 'undefined' || typeof (globalThis.localStorage as any).getItem !== 'function') { + let store = new Map() + + const localStorageMock = { + getItem: (key: string) => store.get(key) ?? null, + setItem: (key: string, value: string) => { + store.set(key, String(value)) + }, + removeItem: (key: string) => { + store.delete(key) + }, + clear: () => { + store.clear() + }, + key: (index: number) => Array.from(store.keys())[index] ?? null, + get length() { + return store.size + } + } + + vi.stubGlobal('localStorage', localStorageMock) + if (typeof window !== 'undefined') { + Object.defineProperty(window, 'localStorage', { value: localStorageMock }) + } +}