mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 06:30:10 +08:00
* feat(utils): add isWithTrailingSharp URL helper function Add new utility function to check if URLs end with trailing '#' character Includes comprehensive test cases covering various URL patterns and edge cases * fix(api): check whether to auto append api version or not when formatting api host - extract api version to variable in GeminiAPIClient for consistency - simplify getBaseURL in OpenAIBaseClient by removing formatApiHost - modify provider api host formatting to respect trailing # - add tests for url parsing with trailing # characters * fix: update provider config tests for new isWithTrailingSharp function - Add isWithTrailingSharp to vi.mock in providerConfig tests - Update test expectations to match new formatApiHost calling behavior - All tests now pass with the new trailing # delimiter functionality 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> * fix(anthropic): prevent duplicate api version in base url The Anthropic SDK automatically appends /v1 to endpoints, so we need to avoid duplication by removing the version from baseURL and explicitly setting the path in listModels --------- Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
7507443d8b
commit
058a2c763b
@ -88,16 +88,11 @@ export function getSdkClient(
|
||||
}
|
||||
})
|
||||
}
|
||||
let baseURL =
|
||||
const baseURL =
|
||||
provider.type === 'anthropic'
|
||||
? provider.apiHost
|
||||
: (provider.anthropicApiHost && provider.anthropicApiHost.trim()) || provider.apiHost
|
||||
|
||||
// Anthropic SDK automatically appends /v1 to all endpoints (like /v1/messages, /v1/models)
|
||||
// We need to strip api version from baseURL to avoid duplication (e.g., /v3/v1/models)
|
||||
// formatProviderApiHost adds /v1 for AI SDK compatibility, but Anthropic SDK needs it removed
|
||||
baseURL = baseURL.replace(/\/v\d+(?:alpha|beta)?(?=\/|$)/i, '')
|
||||
|
||||
logger.debug('Anthropic API baseURL', { baseURL, providerId: provider.id })
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
|
||||
@ -124,7 +124,8 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
|
||||
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||
const sdk = (await this.getSdkInstance()) as Anthropic
|
||||
const response = await sdk.models.list()
|
||||
// prevent auto appended /v1. It's included in baseUrl.
|
||||
const response = await sdk.models.list({ path: '/models' })
|
||||
return response.data
|
||||
}
|
||||
|
||||
|
||||
@ -173,13 +173,15 @@ export class GeminiAPIClient extends BaseApiClient<
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
const apiVersion = this.getApiVersion()
|
||||
|
||||
this.sdkInstance = new GoogleGenAI({
|
||||
vertexai: false,
|
||||
apiKey: this.apiKey,
|
||||
apiVersion: this.getApiVersion(),
|
||||
apiVersion,
|
||||
httpOptions: {
|
||||
baseUrl: this.getBaseURL(),
|
||||
apiVersion: this.getApiVersion(),
|
||||
apiVersion,
|
||||
headers: {
|
||||
...this.provider.extra_headers
|
||||
}
|
||||
@ -200,7 +202,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
||||
return trailingVersion
|
||||
}
|
||||
|
||||
return 'v1beta'
|
||||
return ''
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -25,7 +25,7 @@ import type {
|
||||
OpenAISdkRawOutput,
|
||||
ReasoningEffortOptionalParams
|
||||
} from '@renderer/types/sdk'
|
||||
import { formatApiHost, withoutTrailingSlash } from '@renderer/utils/api'
|
||||
import { withoutTrailingSlash } from '@renderer/utils/api'
|
||||
import { isOllamaProvider } from '@renderer/utils/provider'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
@ -49,8 +49,9 @@ export abstract class OpenAIBaseClient<
|
||||
}
|
||||
|
||||
// 仅适用于openai
|
||||
override getBaseURL(isSupportedAPIVerion: boolean = true): string {
|
||||
return formatApiHost(this.provider.apiHost, isSupportedAPIVerion)
|
||||
override getBaseURL(): string {
|
||||
// apiHost is formatted when called by AiProvider
|
||||
return this.provider.apiHost
|
||||
}
|
||||
|
||||
override async generateImage({
|
||||
@ -129,7 +130,7 @@ export abstract class OpenAIBaseClient<
|
||||
}
|
||||
|
||||
if (isOllamaProvider(this.provider)) {
|
||||
const baseUrl = withoutTrailingSlash(this.getBaseURL(false))
|
||||
const baseUrl = withoutTrailingSlash(this.getBaseURL())
|
||||
.replace(/\/v1$/, '')
|
||||
.replace(/\/api$/, '')
|
||||
const response = await fetch(`${baseUrl}/api/tags`, {
|
||||
@ -184,6 +185,7 @@ export abstract class OpenAIBaseClient<
|
||||
|
||||
let apiKeyForSdkInstance = this.apiKey
|
||||
let baseURLForSdkInstance = this.getBaseURL()
|
||||
logger.debug('baseURLForSdkInstance', { baseURLForSdkInstance })
|
||||
let headersForSdkInstance = {
|
||||
...this.defaultHeaders(),
|
||||
...this.provider.extra_headers
|
||||
@ -195,7 +197,7 @@ export abstract class OpenAIBaseClient<
|
||||
// this.provider.apiKey不允许修改
|
||||
// this.provider.apiKey = token
|
||||
apiKeyForSdkInstance = token
|
||||
baseURLForSdkInstance = this.getBaseURL(false)
|
||||
baseURLForSdkInstance = this.getBaseURL()
|
||||
headersForSdkInstance = {
|
||||
...headersForSdkInstance,
|
||||
...COPILOT_DEFAULT_HEADERS
|
||||
|
||||
@ -122,6 +122,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
const baseUrl = this.getBaseURL()
|
||||
|
||||
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
|
||||
return new AzureOpenAI({
|
||||
@ -134,7 +135,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
return new OpenAI({
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: this.apiKey,
|
||||
baseURL: this.getBaseURL(),
|
||||
baseURL: baseUrl,
|
||||
defaultHeaders: {
|
||||
...this.defaultHeaders(),
|
||||
...this.provider.extra_headers
|
||||
|
||||
@ -42,7 +42,8 @@ vi.mock('@renderer/utils/api', () => ({
|
||||
routeToEndpoint: vi.fn((host) => ({
|
||||
baseURL: host,
|
||||
endpoint: '/chat/completions'
|
||||
}))
|
||||
})),
|
||||
isWithTrailingSharp: vi.fn((host) => host?.endsWith('#') || false)
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/provider', async (importOriginal) => {
|
||||
@ -227,12 +228,19 @@ describe('CherryAI provider configuration', () => {
|
||||
// Mock the functions to simulate non-CherryAI provider
|
||||
vi.mocked(isCherryAIProvider).mockReturnValue(false)
|
||||
vi.mocked(getProviderByModel).mockReturnValue(provider)
|
||||
// Mock isWithTrailingSharp to return false for this test
|
||||
vi.mocked(formatApiHost as any).mockImplementation((host, isSupportedAPIVersion = true) => {
|
||||
if (isSupportedAPIVersion === false) {
|
||||
return host
|
||||
}
|
||||
return `${host}/v1`
|
||||
})
|
||||
|
||||
// Call getActualProvider
|
||||
const actualProvider = getActualProvider(model)
|
||||
|
||||
// Verify that formatApiHost was called with default parameters (true)
|
||||
expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com')
|
||||
// Verify that formatApiHost was called with appendApiVersion parameter
|
||||
expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com', true)
|
||||
expect(actualProvider.apiHost).toBe('https://api.openai.com/v1')
|
||||
})
|
||||
|
||||
@ -303,12 +311,19 @@ describe('Perplexity provider configuration', () => {
|
||||
vi.mocked(isCherryAIProvider).mockReturnValue(false)
|
||||
vi.mocked(isPerplexityProvider).mockReturnValue(false)
|
||||
vi.mocked(getProviderByModel).mockReturnValue(provider)
|
||||
// Mock isWithTrailingSharp to return false for this test
|
||||
vi.mocked(formatApiHost as any).mockImplementation((host, isSupportedAPIVersion = true) => {
|
||||
if (isSupportedAPIVersion === false) {
|
||||
return host
|
||||
}
|
||||
return `${host}/v1`
|
||||
})
|
||||
|
||||
// Call getActualProvider
|
||||
const actualProvider = getActualProvider(model)
|
||||
|
||||
// Verify that formatApiHost was called with default parameters (true)
|
||||
expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com')
|
||||
// Verify that formatApiHost was called with appendApiVersion parameter
|
||||
expect(formatApiHost).toHaveBeenCalledWith('https://api.openai.com', true)
|
||||
expect(actualProvider.apiHost).toBe('https://api.openai.com/v1')
|
||||
})
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ import {
|
||||
formatAzureOpenAIApiHost,
|
||||
formatOllamaApiHost,
|
||||
formatVertexApiHost,
|
||||
isWithTrailingSharp,
|
||||
routeToEndpoint
|
||||
} from '@renderer/utils/api'
|
||||
import {
|
||||
@ -69,14 +70,15 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
|
||||
*/
|
||||
export function formatProviderApiHost(provider: Provider): Provider {
|
||||
const formatted = { ...provider }
|
||||
const appendApiVersion = !isWithTrailingSharp(provider.apiHost)
|
||||
if (formatted.anthropicApiHost) {
|
||||
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
|
||||
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost, appendApiVersion)
|
||||
}
|
||||
|
||||
if (isAnthropicProvider(provider)) {
|
||||
const baseHost = formatted.anthropicApiHost || formatted.apiHost
|
||||
// AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
|
||||
formatted.apiHost = formatApiHost(baseHost)
|
||||
formatted.apiHost = formatApiHost(baseHost, appendApiVersion)
|
||||
if (!formatted.anthropicApiHost) {
|
||||
formatted.anthropicApiHost = formatted.apiHost
|
||||
}
|
||||
@ -85,7 +87,7 @@ export function formatProviderApiHost(provider: Provider): Provider {
|
||||
} else if (isOllamaProvider(formatted)) {
|
||||
formatted.apiHost = formatOllamaApiHost(formatted.apiHost)
|
||||
} else if (isGeminiProvider(formatted)) {
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion, 'v1beta')
|
||||
} else if (isAzureOpenAIProvider(formatted)) {
|
||||
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
|
||||
} else if (isVertexProvider(formatted)) {
|
||||
@ -95,7 +97,7 @@ export function formatProviderApiHost(provider: Provider): Provider {
|
||||
} else if (isPerplexityProvider(formatted)) {
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, false)
|
||||
} else {
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost)
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, appendApiVersion)
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
|
||||
@ -10,6 +10,7 @@ import {
|
||||
formatVertexApiHost,
|
||||
getTrailingApiVersion,
|
||||
hasAPIVersion,
|
||||
isWithTrailingSharp,
|
||||
maskApiKey,
|
||||
routeToEndpoint,
|
||||
splitApiKeyString,
|
||||
@ -450,6 +451,43 @@ describe('api', () => {
|
||||
it('returns undefined for empty string', () => {
|
||||
expect(getTrailingApiVersion('')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined when URL ends with # regardless of version', () => {
|
||||
expect(getTrailingApiVersion('https://api.example.com/v1#')).toBeUndefined()
|
||||
expect(getTrailingApiVersion('https://api.example.com/v2beta#')).toBeUndefined()
|
||||
expect(getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1#')).toBeUndefined()
|
||||
expect(getTrailingApiVersion('https://api.example.com/service/v1#')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('handles URLs with # and trailing slash correctly', () => {
|
||||
expect(getTrailingApiVersion('https://api.example.com/v1/#')).toBeUndefined()
|
||||
expect(getTrailingApiVersion('https://api.example.com/v2beta/#')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('handles URLs with version followed by # and additional path', () => {
|
||||
expect(getTrailingApiVersion('https://api.example.com/v1#endpoint')).toBeUndefined()
|
||||
expect(getTrailingApiVersion('https://api.example.com/v2beta#chat/completions')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('handles complex URLs with multiple # characters', () => {
|
||||
expect(getTrailingApiVersion('https://api.example.com/v1#path#')).toBeUndefined()
|
||||
expect(getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v2beta#')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('handles URLs ending with # when version is not at the end', () => {
|
||||
expect(getTrailingApiVersion('https://api.example.com/v1/service#')).toBeUndefined()
|
||||
expect(getTrailingApiVersion('https://api.example.com/v1/api/chat#')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('distinguishes between URLs with and without trailing #', () => {
|
||||
// Without # - should extract version
|
||||
expect(getTrailingApiVersion('https://api.example.com/v1')).toBe('v1')
|
||||
expect(getTrailingApiVersion('https://api.example.com/v2beta')).toBe('v2beta')
|
||||
|
||||
// With # - should return undefined
|
||||
expect(getTrailingApiVersion('https://api.example.com/v1#')).toBeUndefined()
|
||||
expect(getTrailingApiVersion('https://api.example.com/v2beta#')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('withoutTrailingApiVersion', () => {
|
||||
@ -495,6 +533,70 @@ describe('api', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('isWithTrailingSharp', () => {
|
||||
it('returns true when URL ends with #', () => {
|
||||
expect(isWithTrailingSharp('https://api.example.com#')).toBe(true)
|
||||
expect(isWithTrailingSharp('http://localhost:3000#')).toBe(true)
|
||||
expect(isWithTrailingSharp('#')).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false when URL does not end with #', () => {
|
||||
expect(isWithTrailingSharp('https://api.example.com')).toBe(false)
|
||||
expect(isWithTrailingSharp('http://localhost:3000')).toBe(false)
|
||||
expect(isWithTrailingSharp('')).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false when URL has # in the middle but not at the end', () => {
|
||||
expect(isWithTrailingSharp('https://api.example.com#path')).toBe(false)
|
||||
expect(isWithTrailingSharp('https://api.example.com#section/path')).toBe(false)
|
||||
expect(isWithTrailingSharp('https://api.example.com#path#other')).toBe(false)
|
||||
})
|
||||
|
||||
it('handles URLs with multiple # characters', () => {
|
||||
expect(isWithTrailingSharp('https://api.example.com##')).toBe(true)
|
||||
expect(isWithTrailingSharp('https://api.example.com#path#')).toBe(true)
|
||||
expect(isWithTrailingSharp('https://api.example.com###')).toBe(true)
|
||||
})
|
||||
|
||||
it('handles URLs with trailing whitespace after #', () => {
|
||||
expect(isWithTrailingSharp('https://api.example.com# ')).toBe(false)
|
||||
expect(isWithTrailingSharp('https://api.example.com#\t')).toBe(false)
|
||||
expect(isWithTrailingSharp('https://api.example.com#\n')).toBe(false)
|
||||
})
|
||||
|
||||
it('handles URLs with whitespace before trailing #', () => {
|
||||
expect(isWithTrailingSharp(' https://api.example.com#')).toBe(true)
|
||||
expect(isWithTrailingSharp('\thttps://localhost:3000#')).toBe(true)
|
||||
})
|
||||
|
||||
it('preserves type safety with generic parameter', () => {
|
||||
const url1: string = 'https://api.example.com#'
|
||||
const url2 = 'https://example.com' as const
|
||||
|
||||
expect(isWithTrailingSharp(url1)).toBe(true)
|
||||
expect(isWithTrailingSharp(url2)).toBe(false)
|
||||
})
|
||||
|
||||
it('handles complex real-world URLs', () => {
|
||||
expect(isWithTrailingSharp('https://open.cherryin.net/v1/chat/completions#')).toBe(true)
|
||||
expect(isWithTrailingSharp('https://api.openai.com/v1/engines/gpt-4#')).toBe(true)
|
||||
expect(isWithTrailingSharp('https://gateway.ai.cloudflare.com/v1/xxx/v1beta#')).toBe(true)
|
||||
|
||||
expect(isWithTrailingSharp('https://open.cherryin.net/v1/chat/completions')).toBe(false)
|
||||
expect(isWithTrailingSharp('https://api.openai.com/v1/engines/gpt-4')).toBe(false)
|
||||
expect(isWithTrailingSharp('https://gateway.ai.cloudflare.com/v1/xxx/v1beta')).toBe(false)
|
||||
})
|
||||
|
||||
it('handles edge cases', () => {
|
||||
expect(isWithTrailingSharp('#')).toBe(true)
|
||||
expect(isWithTrailingSharp(' #')).toBe(true)
|
||||
expect(isWithTrailingSharp('# ')).toBe(false)
|
||||
expect(isWithTrailingSharp('path#')).toBe(true)
|
||||
expect(isWithTrailingSharp('/path/with/trailing/#')).toBe(true)
|
||||
expect(isWithTrailingSharp('/path/without/trailing/')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('withoutTrailingSharp', () => {
|
||||
it('removes trailing # from URL', () => {
|
||||
expect(withoutTrailingSharp('https://api.example.com#')).toBe('https://api.example.com')
|
||||
|
||||
@ -62,6 +62,23 @@ export function withoutTrailingSlash<T extends string>(url: T): T {
|
||||
return url.replace(/\/$/, '') as T
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a URL string ends with a trailing '#' character.
|
||||
*
|
||||
* @template T - The string type to preserve type safety
|
||||
* @param {T} url - The URL string to check
|
||||
* @returns {boolean} True if the URL ends with '#', false otherwise
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* isWithTrailingSharp('https://example.com#') // true
|
||||
* isWithTrailingSharp('https://example.com') // false
|
||||
* ```
|
||||
*/
|
||||
export function isWithTrailingSharp<T extends string>(url: T): boolean {
|
||||
return url.endsWith('#')
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes the trailing '#' from a URL string if it exists.
|
||||
*
|
||||
|
||||
Loading…
Reference in New Issue
Block a user