mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-10 15:49:29 +08:00
feat: enhance API version handling and cache functionality
- Updated reasoning cache to use tool-specific keys for better organization. - Added methods to list cache keys and entries. - Improved API version regex patterns for more accurate matching. - Refactored API host formatting to handle leading/trailing whitespace and slashes. - Added functions to extract and remove trailing API version segments from URLs.
This commit is contained in:
parent
c6c7c240a3
commit
3989229f61
@ -186,7 +186,7 @@ export class AiSdkToAnthropicSSE {
|
|||||||
// === Tool Events ===
|
// === Tool Events ===
|
||||||
case 'tool-call':
|
case 'tool-call':
|
||||||
if (this.reasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
|
if (this.reasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
|
||||||
this.reasoningCache.set('google', chunk.providerMetadata?.google?.thoughtSignature)
|
this.reasoningCache.set(`google-${chunk.toolName}`, chunk.providerMetadata?.google?.thoughtSignature)
|
||||||
}
|
}
|
||||||
// FIXME: 按toolcall id绑定
|
// FIXME: 按toolcall id绑定
|
||||||
if (
|
if (
|
||||||
@ -555,7 +555,6 @@ export class AiSdkToAnthropicSSE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
this.onEvent(messageStopEvent)
|
this.onEvent(messageStopEvent)
|
||||||
this.reasoningCache?.destroy?.()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -27,18 +27,35 @@ export function withoutTrailingSlash<T extends string>(url: T): T {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if the host path contains a version string (e.g., /v1, /v2beta).
|
* Matches a version segment in a path that starts with `/v<number>` and optionally
|
||||||
|
* continues with `alpha` or `beta`. The segment may be followed by `/` or the end
|
||||||
|
* of the string (useful for cases like `/v3alpha/resources`).
|
||||||
|
*/
|
||||||
|
const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Matches an API version at the end of a URL (with optional trailing slash).
|
||||||
|
* Used to detect and extract versions only from the trailing position.
|
||||||
|
*/
|
||||||
|
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 判断 host 的 path 中是否包含形如版本的字符串(例如 /v1、/v2beta 等),
|
||||||
|
*
|
||||||
|
* @param host - 要检查的 host 或 path 字符串
|
||||||
|
* @returns 如果 path 中包含版本字符串则返回 true,否则 false
|
||||||
*/
|
*/
|
||||||
export function hasAPIVersion(host?: string): boolean {
|
export function hasAPIVersion(host?: string): boolean {
|
||||||
if (!host) return false
|
if (!host) return false
|
||||||
|
|
||||||
const versionRegex = /\/v\d+(?:alpha|beta)?(?=\/|$)/i
|
const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const url = new URL(host)
|
const url = new URL(host)
|
||||||
return versionRegex.test(url.pathname)
|
return regex.test(url.pathname)
|
||||||
} catch {
|
} catch {
|
||||||
return versionRegex.test(host)
|
// 若无法作为完整 URL 解析,则当作路径直接检测
|
||||||
|
return regex.test(host)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,22 +88,26 @@ export function formatVertexApiHost(
|
|||||||
/**
|
/**
|
||||||
* Formats an API host URL by normalizing it and optionally appending an API version.
|
* Formats an API host URL by normalizing it and optionally appending an API version.
|
||||||
*
|
*
|
||||||
* @param host - The API host URL to format
|
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
|
||||||
* @param isSupportedAPIVersion - Whether the API version is supported. Defaults to `true`.
|
* @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
|
||||||
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
|
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
|
||||||
*
|
*
|
||||||
|
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
|
||||||
|
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
|
||||||
|
* Otherwise, returns the host with the API version appended.
|
||||||
|
*
|
||||||
* @example
|
* @example
|
||||||
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
|
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
|
||||||
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
|
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
|
||||||
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
|
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
|
||||||
*/
|
*/
|
||||||
export function formatApiHost(host?: string, isSupportedAPIVersion: boolean = true, apiVersion: string = 'v1'): string {
|
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
|
||||||
const normalizedHost = withoutTrailingSlash((host || '').trim())
|
const normalizedHost = withoutTrailingSlash(trim(host))
|
||||||
if (!normalizedHost) {
|
if (!normalizedHost) {
|
||||||
return ''
|
return ''
|
||||||
}
|
}
|
||||||
|
|
||||||
if (normalizedHost.endsWith('#') || !isSupportedAPIVersion || hasAPIVersion(normalizedHost)) {
|
if (normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) {
|
||||||
return normalizedHost
|
return normalizedHost
|
||||||
}
|
}
|
||||||
return `${normalizedHost}/${apiVersion}`
|
return `${normalizedHost}/${apiVersion}`
|
||||||
@ -175,3 +196,50 @@ export function validateApiHost(apiHost: string): boolean {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts the trailing API version segment from a URL path.
|
||||||
|
*
|
||||||
|
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
|
||||||
|
* Only versions at the end of the path are extracted, not versions in the middle.
|
||||||
|
* The returned version string does not include leading or trailing slashes.
|
||||||
|
*
|
||||||
|
* @param {string} url - The URL string to parse.
|
||||||
|
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
|
||||||
|
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
|
||||||
|
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
|
||||||
|
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
|
||||||
|
* getTrailingApiVersion('https://api.example.com') // undefined
|
||||||
|
*/
|
||||||
|
export function getTrailingApiVersion(url: string): string | undefined {
|
||||||
|
const match = url.match(TRAILING_VERSION_REGEX)
|
||||||
|
|
||||||
|
if (match) {
|
||||||
|
// Extract version without leading slash and trailing slash
|
||||||
|
return match[0].replace(/^\//, '').replace(/\/$/, '')
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes the trailing API version segment from a URL path.
|
||||||
|
*
|
||||||
|
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
|
||||||
|
* Only versions at the end of the path are removed, not versions in the middle.
|
||||||
|
*
|
||||||
|
* @param {string} url - The URL string to process.
|
||||||
|
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
|
||||||
|
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
|
||||||
|
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
|
||||||
|
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
|
||||||
|
*/
|
||||||
|
export function withoutTrailingApiVersion(url: string): string {
|
||||||
|
return url.replace(TRAILING_VERSION_REGEX, '')
|
||||||
|
}
|
||||||
|
|||||||
@ -54,6 +54,18 @@ export class ReasoningCache<T> {
|
|||||||
return entry.details
|
return entry.details
|
||||||
}
|
}
|
||||||
|
|
||||||
|
listKeys(): string[] {
|
||||||
|
return Array.from(this.cache.keys())
|
||||||
|
}
|
||||||
|
|
||||||
|
listEntries(): Array<{ key: string; entry: CacheEntry<T> }> {
|
||||||
|
const entries: Array<{ key: string; entry: CacheEntry<T> }> = []
|
||||||
|
for (const [key, entry] of this.cache.entries()) {
|
||||||
|
entries.push({ key, entry })
|
||||||
|
}
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clear expired entries
|
* Clear expired entries
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -237,7 +237,6 @@ function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Reco
|
|||||||
inputSchema: zodSchema(schema)
|
inputSchema: zodSchema(schema)
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.debug('Converted Anthropic tool to AI SDK tool', aiTool)
|
|
||||||
aiSdkTools[toolDef.name] = aiTool
|
aiSdkTools[toolDef.name] = aiTool
|
||||||
}
|
}
|
||||||
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
|
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
|
||||||
@ -302,8 +301,9 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
|
|||||||
}
|
}
|
||||||
} else if (block.type === 'tool_use') {
|
} else if (block.type === 'tool_use') {
|
||||||
const options: ProviderOptions = {}
|
const options: ProviderOptions = {}
|
||||||
|
|
||||||
if (isGemini3ModelId(params.model)) {
|
if (isGemini3ModelId(params.model)) {
|
||||||
if (reasoningCache.get('google')) {
|
if (reasoningCache.get(`google-${block.name}`)) {
|
||||||
options.google = {
|
options.google = {
|
||||||
thoughtSignature: MAGIC_STRING
|
thoughtSignature: MAGIC_STRING
|
||||||
}
|
}
|
||||||
@ -394,11 +394,6 @@ async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider>
|
|||||||
|
|
||||||
const provider = await createProviderCore(providerId, config.options)
|
const provider = await createProviderCore(providerId, config.options)
|
||||||
|
|
||||||
logger.debug('AI SDK provider created', {
|
|
||||||
providerId,
|
|
||||||
hasOptions: !!config.options
|
|
||||||
})
|
|
||||||
|
|
||||||
return provider
|
return provider
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -424,7 +419,6 @@ async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkCon
|
|||||||
...headers,
|
...headers,
|
||||||
...existingHeaders
|
...existingHeaders
|
||||||
}
|
}
|
||||||
logger.debug('Copilot token retrieved successfully')
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Failed to get Copilot token', error as Error)
|
logger.error('Failed to get Copilot token', error as Error)
|
||||||
throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
|
throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
|
||||||
@ -450,7 +444,6 @@ async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkCon
|
|||||||
baseURL: 'https://api.anthropic.com/v1',
|
baseURL: 'https://api.anthropic.com/v1',
|
||||||
apiKey: ''
|
apiKey: ''
|
||||||
}
|
}
|
||||||
logger.debug('Anthropic OAuth token retrieved successfully')
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Failed to get Anthropic OAuth token', error as Error)
|
logger.error('Failed to get Anthropic OAuth token', error as Error)
|
||||||
throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
|
throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
|
||||||
@ -479,7 +472,6 @@ async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkCon
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
logger.debug('CherryAI signed fetch configured')
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -498,12 +490,6 @@ async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthro
|
|||||||
// Prepare special provider config (Copilot, Anthropic OAuth, etc.)
|
// Prepare special provider config (Copilot, Anthropic OAuth, etc.)
|
||||||
sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
|
sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
|
||||||
|
|
||||||
logger.debug('Created AI SDK config', {
|
|
||||||
providerId: sdkConfig.providerId,
|
|
||||||
hasOptions: !!sdkConfig.options,
|
|
||||||
message: params.messages
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create provider instance and get language model
|
// Create provider instance and get language model
|
||||||
const aiSdkProvider = await createAiSdkProvider(sdkConfig)
|
const aiSdkProvider = await createAiSdkProvider(sdkConfig)
|
||||||
const baseModel = aiSdkProvider.languageModel(modelId)
|
const baseModel = aiSdkProvider.languageModel(modelId)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user