Merge remote-tracking branch 'origin/main' into feat/proxy-api-server

This commit is contained in:
suyao 2025-12-18 18:19:10 +08:00
commit ae2712c963
No known key found for this signature in database
3 changed files with 99 additions and 2 deletions

View File

@ -10,7 +10,7 @@ import { defaultAppHeaders } from '@shared/utils'
import { isEmpty } from 'lodash'
import { routeToEndpoint } from '../api'
import { isOllamaProvider } from './detection'
import { isAzureOpenAIProvider, isOllamaProvider } from './detection'
import { getAiSdkProviderId } from './mapping'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
@ -210,6 +210,16 @@ export function providerToAiSdkConfig(
extraOptions.mode = 'chat'
}
if (isAzureOpenAIProvider(provider)) {
const apiVersion = provider.apiVersion?.trim()
if (apiVersion) {
extraOptions.apiVersion = apiVersion
if (!['preview', 'v1'].includes(apiVersion)) {
extraOptions.useDeploymentBasedUrls = true
}
}
}
// Handle AWS Bedrock
if (aiSdkProviderId === 'bedrock') {
const bedrockConfig = context.getAwsBedrockConfig?.()

View File

@ -92,7 +92,7 @@ vi.mock('@renderer/services/AssistantService', () => ({
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { isAzureOpenAIProvider, isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { formatApiHost } from '@shared/api'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
@ -172,6 +172,17 @@ const createPerplexityProvider = (): Provider => ({
isSystem: false
})
const createAzureProvider = (apiVersion: string): Provider => ({
id: 'azure-openai',
type: 'azure-openai',
name: 'Azure OpenAI',
apiKey: 'test-key',
apiHost: 'https://example.openai.azure.com/openai',
apiVersion,
models: [],
isSystem: true
})
describe('Copilot responses routing', () => {
beforeEach(() => {
;(globalThis as any).window = {
@ -454,3 +465,46 @@ describe('Stream options includeUsage configuration', () => {
expect(config.providerId).toBe('github-copilot-openai-compatible')
})
})
describe('Azure OpenAI traditional API routing', () => {
beforeEach(() => {
;(globalThis as any).window = {
...(globalThis as any).window,
keyv: createWindowKeyv()
}
mockGetState.mockReturnValue({
settings: {
openAI: {
streamOptions: {
includeUsage: undefined
}
}
}
})
vi.mocked(isAzureOpenAIProvider).mockImplementation((provider) => provider.type === 'azure-openai')
})
it('uses deployment-based URLs when apiVersion is a date version', () => {
const provider = createAzureProvider('2024-02-15-preview')
const config = providerToAiSdkConfig(provider, createModel('gpt-4o', 'GPT-4o', provider.id))
expect(config.providerId).toBe('azure')
expect(config.options.apiVersion).toBe('2024-02-15-preview')
expect(config.options.useDeploymentBasedUrls).toBe(true)
})
it('does not force deployment-based URLs for apiVersion v1/preview', () => {
const v1Provider = createAzureProvider('v1')
const v1Config = providerToAiSdkConfig(v1Provider, createModel('gpt-4o', 'GPT-4o', v1Provider.id))
expect(v1Config.providerId).toBe('azure-responses')
expect(v1Config.options.apiVersion).toBe('v1')
expect(v1Config.options.useDeploymentBasedUrls).toBeUndefined()
const previewProvider = createAzureProvider('preview')
const previewConfig = providerToAiSdkConfig(previewProvider, createModel('gpt-4o', 'GPT-4o', previewProvider.id))
expect(previewConfig.providerId).toBe('azure-responses')
expect(previewConfig.options.apiVersion).toBe('preview')
expect(previewConfig.options.useDeploymentBasedUrls).toBeUndefined()
})
})

View File

@ -1,8 +1,15 @@
import '@testing-library/jest-dom/vitest'
import { createRequire } from 'node:module'
import { styleSheetSerializer } from 'jest-styled-components/serializer'
import { expect, vi } from 'vitest'
const require = createRequire(import.meta.url)
const bufferModule = require('buffer')
if (!bufferModule.SlowBuffer) {
bufferModule.SlowBuffer = bufferModule.Buffer
}
expect.addSnapshotSerializer(styleSheetSerializer)
// Mock LoggerService globally for renderer tests
@ -48,3 +55,29 @@ vi.stubGlobal('api', {
writeWithId: vi.fn().mockResolvedValue(undefined)
}
})
if (typeof globalThis.localStorage === 'undefined' || typeof (globalThis.localStorage as any).getItem !== 'function') {
let store = new Map<string, string>()
const localStorageMock = {
getItem: (key: string) => store.get(key) ?? null,
setItem: (key: string, value: string) => {
store.set(key, String(value))
},
removeItem: (key: string) => {
store.delete(key)
},
clear: () => {
store.clear()
},
key: (index: number) => Array.from(store.keys())[index] ?? null,
get length() {
return store.size
}
}
vi.stubGlobal('localStorage', localStorageMock)
if (typeof window !== 'undefined') {
Object.defineProperty(window, 'localStorage', { value: localStorageMock })
}
}