mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
fix: enhance provider handling and API key rotation logic in AiProvider (#11586)
* fix: enhance provider handling and API key rotation logic in AiProvider * fix * fix(api): enhance API key handling and logging for providers
This commit is contained in:
parent
a566cd65f4
commit
92bb05950d
@ -120,9 +120,12 @@ export default class ModernAiProvider {
|
||||
throw new Error('Model is required for completions. Please use constructor with model parameter.')
|
||||
}
|
||||
|
||||
// 每次请求时重新生成配置以确保API key轮换生效
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
||||
logger.debug('Generated provider config for completions', this.config)
|
||||
// Config is now set in constructor, ApiService handles key rotation before passing provider
|
||||
if (!this.config) {
|
||||
// If config wasn't set in constructor (when provider only), generate it now
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model!)
|
||||
}
|
||||
logger.debug('Using provider config for completions', this.config)
|
||||
|
||||
// 检查 config 是否存在
|
||||
if (!this.config) {
|
||||
|
||||
@ -37,32 +37,6 @@ import { azureAnthropicProviderCreator } from './config/azure-anthropic'
|
||||
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
||||
import { getAiSdkProviderId } from './factory'
|
||||
|
||||
/**
|
||||
* 获取轮询的API key
|
||||
* 复用legacy架构的多key轮询逻辑
|
||||
*/
|
||||
function getRotatedApiKey(provider: Provider): string {
|
||||
const keys = provider.apiKey.split(',').map((key) => key.trim())
|
||||
const keyName = `provider:${provider.id}:last_used_key`
|
||||
|
||||
if (keys.length === 1) {
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const lastUsedKey = window.keyv.get(keyName)
|
||||
if (!lastUsedKey) {
|
||||
window.keyv.set(keyName, keys[0])
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const currentIndex = keys.indexOf(lastUsedKey)
|
||||
const nextIndex = (currentIndex + 1) % keys.length
|
||||
const nextKey = keys[nextIndex]
|
||||
window.keyv.set(keyName, nextKey)
|
||||
|
||||
return nextKey
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理特殊provider的转换逻辑
|
||||
*/
|
||||
@ -171,7 +145,7 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
|
||||
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
|
||||
const baseConfig = {
|
||||
baseURL: baseURL,
|
||||
apiKey: getRotatedApiKey(actualProvider)
|
||||
apiKey: actualProvider.apiKey
|
||||
}
|
||||
|
||||
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
|
||||
|
||||
@ -8,8 +8,8 @@ import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingMod
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import store from '@renderer/store'
|
||||
import type { FetchChatCompletionParams } from '@renderer/types'
|
||||
import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
|
||||
import { type FetchChatCompletionParams, isSystemProvider } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { Message, ResponseError } from '@renderer/types/newMessage'
|
||||
@ -21,7 +21,8 @@ import { purifyMarkdownImages } from '@renderer/utils/markdown'
|
||||
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
import { NOT_SUPPORT_API_KEY_PROVIDER_TYPES, NOT_SUPPORT_API_KEY_PROVIDERS } from '@renderer/utils/provider'
|
||||
import { cloneDeep, isEmpty, takeRight } from 'lodash'
|
||||
|
||||
import type { ModernAiProviderConfig } from '../aiCore/index_new'
|
||||
import AiProviderNew from '../aiCore/index_new'
|
||||
@ -42,6 +43,8 @@ import {
|
||||
// } from './MessagesService'
|
||||
// import WebSearchService from './WebSearchService'
|
||||
|
||||
// FIXME: 这里太多重复逻辑,需要重构
|
||||
|
||||
const logger = loggerService.withContext('ApiService')
|
||||
|
||||
export async function fetchMcpTools(assistant: Assistant) {
|
||||
@ -94,7 +97,15 @@ export async function fetchChatCompletion({
|
||||
modelId: assistant.model?.id,
|
||||
modelName: assistant.model?.name
|
||||
})
|
||||
const AI = new AiProviderNew(assistant.model || getDefaultModel())
|
||||
|
||||
// Get base provider and apply API key rotation
|
||||
const baseProvider = getProviderByModel(assistant.model || getDefaultModel())
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(baseProvider),
|
||||
apiKey: getRotatedApiKey(baseProvider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(assistant.model || getDefaultModel(), providerWithRotatedKey)
|
||||
const provider = AI.getActualProvider()
|
||||
|
||||
const mcpTools: MCPTool[] = []
|
||||
@ -171,7 +182,13 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
|
||||
return null
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(model)
|
||||
// Apply API key rotation
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(provider),
|
||||
apiKey: getRotatedApiKey(provider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(model, providerWithRotatedKey)
|
||||
|
||||
const topicId = messages?.find((message) => message.topicId)?.topicId || ''
|
||||
|
||||
@ -270,7 +287,13 @@ export async function fetchNoteSummary({ content, assistant }: { content: string
|
||||
return null
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(model)
|
||||
// Apply API key rotation
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(provider),
|
||||
apiKey: getRotatedApiKey(provider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(model, providerWithRotatedKey)
|
||||
|
||||
// only 2000 char and no images
|
||||
const truncatedContent = content.substring(0, 2000)
|
||||
@ -358,7 +381,13 @@ export async function fetchGenerate({
|
||||
return ''
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(model)
|
||||
// Apply API key rotation
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(provider),
|
||||
apiKey: getRotatedApiKey(provider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(model, providerWithRotatedKey)
|
||||
|
||||
const assistant = getDefaultAssistant()
|
||||
assistant.model = model
|
||||
@ -403,43 +432,91 @@ export async function fetchGenerate({
|
||||
|
||||
export function hasApiKey(provider: Provider) {
|
||||
if (!provider) return false
|
||||
if (['ollama', 'lmstudio', 'vertexai', 'cherryai'].includes(provider.id)) return true
|
||||
if (provider.id === 'cherryai') return true
|
||||
if (
|
||||
(isSystemProvider(provider) && NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) ||
|
||||
NOT_SUPPORT_API_KEY_PROVIDER_TYPES.includes(provider.type)
|
||||
)
|
||||
return true
|
||||
return !isEmpty(provider.apiKey)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the first available embedding model from enabled providers
|
||||
* Get rotated API key for providers that support multiple keys
|
||||
* Returns empty string for providers that don't require API keys
|
||||
*/
|
||||
// function getFirstEmbeddingModel() {
|
||||
// const providers = store.getState().llm.providers.filter((p) => p.enabled)
|
||||
function getRotatedApiKey(provider: Provider): string {
|
||||
// Handle providers that don't require API keys
|
||||
if (!provider.apiKey || provider.apiKey.trim() === '') {
|
||||
return ''
|
||||
}
|
||||
|
||||
// for (const provider of providers) {
|
||||
// const embeddingModel = provider.models.find((model) => isEmbeddingModel(model))
|
||||
// if (embeddingModel) {
|
||||
// return embeddingModel
|
||||
// }
|
||||
// }
|
||||
const keys = provider.apiKey
|
||||
.split(',')
|
||||
.map((key) => key.trim())
|
||||
.filter(Boolean)
|
||||
|
||||
// return undefined
|
||||
// }
|
||||
if (keys.length === 0) {
|
||||
return ''
|
||||
}
|
||||
|
||||
const keyName = `provider:${provider.id}:last_used_key`
|
||||
|
||||
// If only one key, return it directly
|
||||
if (keys.length === 1) {
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const lastUsedKey = window.keyv.get(keyName)
|
||||
if (!lastUsedKey) {
|
||||
window.keyv.set(keyName, keys[0])
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const currentIndex = keys.indexOf(lastUsedKey)
|
||||
|
||||
// Log when the last used key is no longer in the list
|
||||
if (currentIndex === -1) {
|
||||
logger.debug('Last used API key no longer found in provider keys, falling back to first key', {
|
||||
providerId: provider.id,
|
||||
lastUsedKey: lastUsedKey.substring(0, 8) + '...' // Only log first 8 chars for security
|
||||
})
|
||||
}
|
||||
|
||||
const nextIndex = (currentIndex + 1) % keys.length
|
||||
const nextKey = keys[nextIndex]
|
||||
window.keyv.set(keyName, nextKey)
|
||||
|
||||
return nextKey
|
||||
}
|
||||
|
||||
export async function fetchModels(provider: Provider): Promise<Model[]> {
|
||||
const AI = new AiProviderNew(provider)
|
||||
// Apply API key rotation
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(provider),
|
||||
apiKey: getRotatedApiKey(provider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(providerWithRotatedKey)
|
||||
|
||||
try {
|
||||
return await AI.models()
|
||||
} catch (error) {
|
||||
logger.error('Failed to fetch models from provider', {
|
||||
providerId: provider.id,
|
||||
providerName: provider.name,
|
||||
error: error as Error
|
||||
})
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
export function checkApiProvider(provider: Provider): void {
|
||||
if (
|
||||
provider.id !== 'ollama' &&
|
||||
provider.id !== 'lmstudio' &&
|
||||
provider.type !== 'vertexai' &&
|
||||
provider.id !== 'copilot'
|
||||
) {
|
||||
const isExcludedProvider =
|
||||
(isSystemProvider(provider) && NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) ||
|
||||
NOT_SUPPORT_API_KEY_PROVIDER_TYPES.includes(provider.type)
|
||||
|
||||
if (!isExcludedProvider) {
|
||||
if (!provider.apiKey) {
|
||||
window.toast.error(i18n.t('message.error.enter.api.label'))
|
||||
throw new Error(i18n.t('message.error.enter.api.label'))
|
||||
@ -460,8 +537,7 @@ export function checkApiProvider(provider: Provider): void {
|
||||
export async function checkApi(provider: Provider, model: Model, timeout = 15000): Promise<void> {
|
||||
checkApiProvider(provider)
|
||||
|
||||
// Don't pass in provider parameter. We need auto-format URL
|
||||
const ai = new AiProviderNew(model)
|
||||
const ai = new AiProviderNew(model, provider)
|
||||
|
||||
const assistant = getDefaultAssistant()
|
||||
assistant.model = model
|
||||
|
||||
@ -187,3 +187,13 @@ export const isSupportAPIVersionProvider = (provider: Provider) => {
|
||||
}
|
||||
return provider.apiOptions?.isNotSupportAPIVersion !== false
|
||||
}
|
||||
|
||||
export const NOT_SUPPORT_API_KEY_PROVIDERS: readonly SystemProviderId[] = [
|
||||
'ollama',
|
||||
'lmstudio',
|
||||
'vertexai',
|
||||
'aws-bedrock',
|
||||
'copilot'
|
||||
]
|
||||
|
||||
export const NOT_SUPPORT_API_KEY_PROVIDER_TYPES: readonly ProviderType[] = ['vertexai', 'aws-bedrock']
|
||||
|
||||
Loading…
Reference in New Issue
Block a user