mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-18 22:20:56 +08:00
feat: enhance support for AWS Bedrock and Azure OpenAI providers (#11510)
* feat: enhance support for AWS Bedrock and Azure OpenAI providers * fix: resolve PR review issues for AWS Bedrock support - Fix header.ts logic bug: change && to || for Vertex/Bedrock provider check - Fix regex in reasoning.ts to match AWS Bedrock model format (anthropic.claude-*) - Add test coverage for AWS Bedrock format in isClaude4SeriesModel - Add Bedrock provider tests including anthropicBeta parameter 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
7ce1590eaf
commit
b18c64b725
@ -17,7 +17,7 @@ export function addAnthropicHeaders(assistant: Assistant, model: Model): string[
|
|||||||
if (
|
if (
|
||||||
isClaude45ReasoningModel(model) &&
|
isClaude45ReasoningModel(model) &&
|
||||||
isToolUseModeFunction(assistant) &&
|
isToolUseModeFunction(assistant) &&
|
||||||
!(isVertexProvider(provider) && isAwsBedrockProvider(provider))
|
!(isVertexProvider(provider) || isAwsBedrockProvider(provider))
|
||||||
) {
|
) {
|
||||||
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
|
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,6 +28,7 @@ import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
|
|||||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||||
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
|
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
|
||||||
import { replacePromptVariables } from '@renderer/utils/prompt'
|
import { replacePromptVariables } from '@renderer/utils/prompt'
|
||||||
|
import { isAwsBedrockProvider } from '@renderer/utils/provider'
|
||||||
import type { ModelMessage, Tool } from 'ai'
|
import type { ModelMessage, Tool } from 'ai'
|
||||||
import { stepCountIs } from 'ai'
|
import { stepCountIs } from 'ai'
|
||||||
|
|
||||||
@ -175,7 +176,7 @@ export async function buildStreamTextParams(
|
|||||||
|
|
||||||
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
|
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
|
||||||
|
|
||||||
if (isAnthropicModel(model)) {
|
if (isAnthropicModel(model) && !isAwsBedrockProvider(provider)) {
|
||||||
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
|
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
|
||||||
headers = combineHeaders(headers, newBetaHeaders)
|
headers = combineHeaders(headers, newBetaHeaders)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import type { Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
import { describe, expect, it, vi } from 'vitest'
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
import { getAiSdkProviderId } from '../factory'
|
import { getAiSdkProviderId } from '../factory'
|
||||||
@ -68,6 +68,18 @@ function createTestProvider(id: string, type: string): Provider {
|
|||||||
} as Provider
|
} as Provider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function createAzureProvider(id: string, apiVersion?: string, model?: string): Provider {
|
||||||
|
return {
|
||||||
|
id,
|
||||||
|
type: 'azure-openai',
|
||||||
|
name: `Azure Test ${id}`,
|
||||||
|
apiKey: 'azure-test-key',
|
||||||
|
apiHost: 'azure-test-host',
|
||||||
|
apiVersion,
|
||||||
|
models: [{ id: model || 'gpt-4' } as Model]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
describe('Integrated Provider Registry', () => {
|
describe('Integrated Provider Registry', () => {
|
||||||
describe('Provider ID Resolution', () => {
|
describe('Provider ID Resolution', () => {
|
||||||
it('should resolve openrouter provider correctly', () => {
|
it('should resolve openrouter provider correctly', () => {
|
||||||
@ -111,6 +123,24 @@ describe('Integrated Provider Registry', () => {
|
|||||||
const result = getAiSdkProviderId(unknownProvider)
|
const result = getAiSdkProviderId(unknownProvider)
|
||||||
expect(result).toBe('unknown-provider')
|
expect(result).toBe('unknown-provider')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('should handle Azure OpenAI providers correctly', () => {
|
||||||
|
const azureProvider = createAzureProvider('azure-test', '2024-02-15', 'gpt-4o')
|
||||||
|
const result = getAiSdkProviderId(azureProvider)
|
||||||
|
expect(result).toBe('azure')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Azure OpenAI providers response endpoint correctly', () => {
|
||||||
|
const azureProvider = createAzureProvider('azure-test', 'v1', 'gpt-4o')
|
||||||
|
const result = getAiSdkProviderId(azureProvider)
|
||||||
|
expect(result).toBe('azure-responses')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Azure provider Claude Models', () => {
|
||||||
|
const provider = createTestProvider('azure-anthropic', 'anthropic')
|
||||||
|
const result = getAiSdkProviderId(provider)
|
||||||
|
expect(result).toBe('azure-anthropic')
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('Backward Compatibility', () => {
|
describe('Backward Compatibility', () => {
|
||||||
|
|||||||
@ -154,6 +154,10 @@ vi.mock('../websearch', () => ({
|
|||||||
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
|
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../prepareParams/header', () => ({
|
||||||
|
addAnthropicHeaders: vi.fn(() => ['context-1m-2025-08-07'])
|
||||||
|
}))
|
||||||
|
|
||||||
const ensureWindowApi = () => {
|
const ensureWindowApi = () => {
|
||||||
const globalWindow = window as any
|
const globalWindow = window as any
|
||||||
globalWindow.api = globalWindow.api || {}
|
globalWindow.api = globalWindow.api || {}
|
||||||
@ -633,5 +637,64 @@ describe('options utils', () => {
|
|||||||
expect(result.providerOptions).toHaveProperty('anthropic')
|
expect(result.providerOptions).toHaveProperty('anthropic')
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('AWS Bedrock provider', () => {
|
||||||
|
const bedrockProvider = {
|
||||||
|
id: 'bedrock',
|
||||||
|
name: 'AWS Bedrock',
|
||||||
|
type: 'aws-bedrock',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://bedrock.us-east-1.amazonaws.com',
|
||||||
|
models: [] as Model[]
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const bedrockModel: Model = {
|
||||||
|
id: 'anthropic.claude-sonnet-4-20250514-v1:0',
|
||||||
|
name: 'Claude Sonnet 4',
|
||||||
|
provider: 'bedrock'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
it('should build basic Bedrock options', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.providerOptions).toHaveProperty('bedrock')
|
||||||
|
expect(result.providerOptions.bedrock).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include anthropicBeta when Anthropic headers are needed', async () => {
|
||||||
|
const { addAnthropicHeaders } = await import('../../prepareParams/header')
|
||||||
|
vi.mocked(addAnthropicHeaders).mockReturnValue(['interleaved-thinking-2025-05-14', 'context-1m-2025-08-07'])
|
||||||
|
|
||||||
|
const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.providerOptions.bedrock).toHaveProperty('anthropicBeta')
|
||||||
|
expect(result.providerOptions.bedrock.anthropicBeta).toEqual([
|
||||||
|
'interleaved-thinking-2025-05-14',
|
||||||
|
'context-1m-2025-08-07'
|
||||||
|
])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include reasoning parameters when enabled', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
|
||||||
|
enableReasoning: true,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.providerOptions.bedrock).toHaveProperty('reasoningConfig')
|
||||||
|
expect(result.providerOptions.bedrock.reasoningConfig).toEqual({
|
||||||
|
type: 'enabled',
|
||||||
|
budgetTokens: 5000
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -36,6 +36,7 @@ import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@rende
|
|||||||
import type { JSONValue } from 'ai'
|
import type { JSONValue } from 'ai'
|
||||||
import { t } from 'i18next'
|
import { t } from 'i18next'
|
||||||
|
|
||||||
|
import { addAnthropicHeaders } from '../prepareParams/header'
|
||||||
import { getAiSdkProviderId } from '../provider/factory'
|
import { getAiSdkProviderId } from '../provider/factory'
|
||||||
import { buildGeminiGenerateImageParams } from './image'
|
import { buildGeminiGenerateImageParams } from './image'
|
||||||
import {
|
import {
|
||||||
@ -469,6 +470,11 @@ function buildBedrockProviderOptions(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const betaHeaders = addAnthropicHeaders(assistant, model)
|
||||||
|
if (betaHeaders.length > 0) {
|
||||||
|
providerOptions.anthropicBeta = betaHeaders
|
||||||
|
}
|
||||||
|
|
||||||
return providerOptions
|
return providerOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -309,11 +309,14 @@ describe('Ling Models', () => {
|
|||||||
describe('Claude & regional providers', () => {
|
describe('Claude & regional providers', () => {
|
||||||
it('identifies claude 4.5 variants', () => {
|
it('identifies claude 4.5 variants', () => {
|
||||||
expect(isClaude45ReasoningModel(createModel({ id: 'claude-sonnet-4.5-preview' }))).toBe(true)
|
expect(isClaude45ReasoningModel(createModel({ id: 'claude-sonnet-4.5-preview' }))).toBe(true)
|
||||||
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-sonnet-4-5@20250929' }))).toBe(true)
|
||||||
expect(isClaude45ReasoningModel(createModel({ id: 'claude-3-sonnet' }))).toBe(false)
|
expect(isClaude45ReasoningModel(createModel({ id: 'claude-3-sonnet' }))).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('identifies claude 4 variants', () => {
|
it('identifies claude 4 variants', () => {
|
||||||
expect(isClaude4SeriesModel(createModel({ id: 'claude-opus-4' }))).toBe(true)
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-opus-4' }))).toBe(true)
|
||||||
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-sonnet-4@20250514' }))).toBe(true)
|
||||||
|
expect(isClaude4SeriesModel(createModel({ id: 'anthropic.claude-sonnet-4-20250514-v1:0' }))).toBe(true)
|
||||||
expect(isClaude4SeriesModel(createModel({ id: 'claude-4.2-sonnet-variant' }))).toBe(false)
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-4.2-sonnet-variant' }))).toBe(false)
|
||||||
expect(isClaude4SeriesModel(createModel({ id: 'claude-3-haiku' }))).toBe(false)
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-3-haiku' }))).toBe(false)
|
||||||
})
|
})
|
||||||
|
|||||||
@ -396,7 +396,11 @@ export function isClaude45ReasoningModel(model: Model): boolean {
|
|||||||
|
|
||||||
export function isClaude4SeriesModel(model: Model): boolean {
|
export function isClaude4SeriesModel(model: Model): boolean {
|
||||||
const modelId = getLowerBaseModelName(model.id, '/')
|
const modelId = getLowerBaseModelName(model.id, '/')
|
||||||
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:-[\w-]+)?$/i
|
// Supports various formats including:
|
||||||
|
// - Direct API: claude-sonnet-4, claude-opus-4-20250514
|
||||||
|
// - GCP Vertex AI: claude-sonnet-4@20250514
|
||||||
|
// - AWS Bedrock: anthropic.claude-sonnet-4-20250514-v1:0
|
||||||
|
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:[@\-:][\w\-:]+)?$/i
|
||||||
return regex.test(modelId)
|
return regex.test(modelId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -108,6 +108,7 @@ const SUPPORT_URL_CONTEXT_PROVIDER_TYPES = [
|
|||||||
'gemini',
|
'gemini',
|
||||||
'vertexai',
|
'vertexai',
|
||||||
'anthropic',
|
'anthropic',
|
||||||
|
'azure-openai',
|
||||||
'new-api'
|
'new-api'
|
||||||
] as const satisfies ProviderType[]
|
] as const satisfies ProviderType[]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user