cherry-studio/src/main/services/VertexAIService.ts
SuYao 4c0167cc03
Feat/vertex-claude-support (#7564)
* feat(migrate): add default settings for assistants during migration

- Introduced a new migration step to assign default settings for assistants that lack configuration.
- Default settings include temperature, context count, and other parameters to ensure consistent behavior across the application.

* chore(store): increment version number to 115 for persisted reducer

* feat(vertex-sdk): integrate Anthropic Vertex SDK and add access token retrieval

- Added support for the new `@anthropic-ai/vertex-sdk` in the project.
- Introduced a new IPC channel `VertexAI_GetAccessToken` to retrieve access tokens.
- Implemented `getAccessToken` method in `VertexAIService` to handle service account authentication.
- Updated the `IpcChannel` enum and related IPC handlers to support the new functionality.
- Enhanced the `VertexAPIClient` to utilize the `AnthropicVertexClient` for model handling.
- Refactored existing code to accommodate the integration of the Vertex SDK and improve modularity.

* feat(vertex-ai): enhance VertexAI settings and API host management

- Added a new method to format the API host URL in both AnthropicVertexClient and VertexAPIClient.
- Updated getBaseURL methods to utilize the new formatting logic.
- Enhanced VertexAISettings component to include an input for API host configuration, with help text for user guidance.
- Updated localization files to include new help text for the API host field in multiple languages.

* fix(vertex-sdk): update baseURL handling and patch dependencies

- Refactored baseURL assignment in AnthropicVertexClient to ensure it defaults to undefined when the URL is empty.
- Updated yarn.lock to reflect changes in dependency resolution and checksum for @anthropic-ai/vertex-sdk patch.

* refactor(VertexAISetting): use provider.id rather than provider

* refactor: improve API host formatting in AnthropicVertexClient

- Updated the `formatApiHost` method to streamline host URL handling.
- Introduced a helper function to determine if the original host should be used based on its format.
- Ensured consistent appending of the `/v1/` path for valid API requests.

* fix: handle empty host in AnthropicVertexClient

- Added a check in the `getBaseURL` method to return the host if it is empty, preventing potential errors.
- Included a console log for the base URL to aid in debugging and verification of the URL formatting.

* feat(AnthropicVertexClient): add logging for authentication errors and mock client in tests

- Replaced console.error in AnthropicVertexClient with the logger service for better error tracking.
- Added mock implementation for AnthropicVertexClient in tests to enhance testing capabilities.
- Updated package.json to include the @aws-sdk/client-s3 dependency.

* feat(tests): add comprehensive tests for client compatibility types

- Introduced a new test file to validate compatibility types for various API clients including OpenAI, Anthropic, Gemini, Aihubmix, NewAPI, and Vertex.
- Implemented mock services to facilitate testing and ensure isolation of client behavior.
- Added tests for both direct API clients and decorator pattern clients, ensuring correct compatibility type returns.
- Enhanced middleware compatibility logic tests to verify correct identification of compatible clients.

---------

Co-authored-by: one <wangan.cs@gmail.com>
2025-07-24 23:46:32 +08:00

174 lines
4.8 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { GoogleAuth } from 'google-auth-library'
/** Service-account credentials used to build a Google auth client. */
type ServiceAccountCredentials = {
  /** PEM private key; escaped "\n" sequences (as stored in a JSON key file) are accepted. */
  privateKey: string
  /** Service-account email, passed to Google auth as `client_email`. */
  clientEmail: string
}

/** Parameters accepted by the VertexAIService auth methods. */
type VertexAIAuthParams = {
  /** Google Cloud project the auth client is created for. */
  projectId: string
  /** Optional in the type, but the auth methods throw when it (or either of its fields) is missing. */
  serviceAccount?: ServiceAccountCredentials
}

/** OAuth2 scope requested by every auth client this service creates. */
const REQUIRED_VERTEX_AI_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'
/**
 * Builds and caches Google service-account auth clients for Vertex AI requests.
 *
 * Clients are cached per (project ID, service-account email). A cached client is rebuilt when the
 * private key supplied for that pair changes. A client is evicted when it fails to produce headers
 * or a token, so the next call starts fresh.
 */
class VertexAIService {
  private static instance: VertexAIService
  /**
   * Cached clients, keyed by project ID and then by service-account email.
   * The maps are nested rather than joined into a single "project-email" string key because project
   * IDs may contain '-', which made such keys ambiguous and broke prefix-based clearing.
   * The raw private key is kept alongside each client so that a changed key replaces the stale client.
   */
  private readonly authClients = new Map<string, Map<string, { auth: GoogleAuth; privateKey: string }>>()

  static getInstance(): VertexAIService {
    if (!VertexAIService.instance) {
      VertexAIService.instance = new VertexAIService()
    }
    return VertexAIService.instance
  }

  /**
   * Normalizes a private key into a PEM string with a header, footer and 64-character body lines.
   *
   * @param privateKey Key text, either PEM-armored or a bare base64 body, optionally containing
   *   escaped "\n" sequences.
   * @returns The re-wrapped PEM. An existing PEM label (e.g. "RSA PRIVATE KEY") is preserved;
   *   bare bodies are labelled "PRIVATE KEY".
   * @throws Error if the input is not a non-empty string or nothing remains after normalization.
   */
  private formatPrivateKey(privateKey: string): string {
    if (!privateKey || typeof privateKey !== 'string') {
      throw new Error('Private key must be a non-empty string')
    }
    // Keys copied out of a JSON key file contain escaped "\n" sequences instead of real newlines.
    const unescaped = privateKey.replace(/\\n/g, '\n')

    // When the key is already armored, keep its label so the body is never relabelled with the wrong
    // format. The body is still re-wrapped, because a PEM whose line breaks were lost (e.g. pasted into
    // a single-line input) cannot be parsed even though its header and footer are present.
    const armored = /-----BEGIN ([A-Z0-9 ]+)-----([\s\S]*?)-----END \1-----/.exec(unescaped)
    let label = 'PRIVATE KEY'
    let body: string
    if (armored) {
      label = armored[1] ?? label
      body = (armored[2] ?? '').replace(/\s+/g, '')
    } else {
      // Strip all whitespace, then any leftover (possibly mangled) header or footer markers.
      body = unescaped
        .replace(/\s+/g, '')
        .replace(/-----BEGIN[^-]*-----/g, '')
        .replace(/-----END[^-]*-----/g, '')
    }

    if (!body) {
      throw new Error('Private key is empty after formatting')
    }

    // PEM bodies are wrapped at 64 characters per line (RFC 7468).
    const wrapped = body.match(/.{1,64}/g)?.join('\n') ?? body
    return `-----BEGIN ${label}-----\n${wrapped}\n-----END ${label}-----`
  }

  /**
   * Gets the authentication headers for a Vertex AI request.
   *
   * @throws Error if credentials are missing, the private key is malformed
   *   ("Invalid private key format: …"), or authentication fails
   *   ("Failed to authenticate with service account: …"). On an authentication failure the
   *   cached client is evicted.
   */
  async getAuthHeaders(params: VertexAIAuthParams): Promise<Record<string, string>> {
    const { projectId, clientEmail, privateKey } = VertexAIService.requireCredentials(params)
    const auth = this.getAuthClient(projectId, clientEmail, privateKey)
    try {
      return VertexAIService.toPlainHeaders(await auth.getRequestHeaders())
    } catch (error: unknown) {
      // Drop the client so the next call rebuilds it instead of reusing a failing instance.
      this.evictAuthClient(projectId, clientEmail, auth)
      throw new Error(`Failed to authenticate with service account: ${VertexAIService.describeError(error)}`)
    }
  }

  /**
   * Gets an OAuth2 access token for the service account.
   *
   * @returns The token, or '' if the auth library yields no token.
   * @throws Error if credentials are missing or the private key is malformed. Errors raised while
   *   fetching the token are rethrown unchanged after the cached client is evicted.
   */
  async getAccessToken(params: VertexAIAuthParams): Promise<string> {
    const { projectId, clientEmail, privateKey } = VertexAIService.requireCredentials(params)
    const auth = this.getAuthClient(projectId, clientEmail, privateKey)
    try {
      const accessToken = await auth.getAccessToken()
      return accessToken ?? ''
    } catch (error: unknown) {
      // Without this, a client built from bad credentials would be reused until the cache is cleared.
      this.evictAuthClient(projectId, clientEmail, auth)
      throw error
    }
  }

  /**
   * Clears cached auth clients for a project.
   *
   * @param projectId Project whose clients are cleared.
   * @param clientEmail If given (non-empty), only this service account's client is cleared;
   *   otherwise every client for the project is cleared.
   */
  clearAuthCache(projectId: string, clientEmail?: string): void {
    if (!clientEmail) {
      this.authClients.delete(projectId)
      return
    }
    const projectClients = this.authClients.get(projectId)
    if (!projectClients) {
      return
    }
    projectClients.delete(clientEmail)
    if (projectClients.size === 0) {
      this.authClients.delete(projectId)
    }
  }

  /**
   * Clears all cached auth clients.
   */
  clearAllAuthCache(): void {
    this.authClients.clear()
  }

  /** Validates that service-account credentials are present and returns them as plain strings. */
  private static requireCredentials(params: VertexAIAuthParams): {
    projectId: string
    clientEmail: string
    privateKey: string
  } {
    const { projectId, serviceAccount } = params
    if (!serviceAccount?.privateKey || !serviceAccount?.clientEmail) {
      throw new Error('Service account credentials are required')
    }
    return { projectId, clientEmail: serviceAccount.clientEmail, privateKey: serviceAccount.privateKey }
  }

  /**
   * Returns the cached client for (projectId, clientEmail), creating or replacing it when there is none
   * or when the supplied private key differs from the one the cached client was built with.
   *
   * @throws Error "Invalid private key format: …" if the key cannot be formatted or the client
   *   cannot be constructed.
   */
  private getAuthClient(projectId: string, clientEmail: string, privateKey: string): GoogleAuth {
    const cached = this.authClients.get(projectId)?.get(clientEmail)
    if (cached && cached.privateKey === privateKey) {
      return cached.auth
    }

    let auth: GoogleAuth
    try {
      auth = new GoogleAuth({
        credentials: {
          private_key: this.formatPrivateKey(privateKey),
          client_email: clientEmail
        },
        projectId,
        scopes: [REQUIRED_VERTEX_AI_SCOPE]
      })
    } catch (formatError: unknown) {
      throw new Error(`Invalid private key format: ${VertexAIService.describeError(formatError)}`)
    }

    let projectClients = this.authClients.get(projectId)
    if (!projectClients) {
      projectClients = new Map()
      this.authClients.set(projectId, projectClients)
    }
    projectClients.set(clientEmail, { auth, privateKey })
    return auth
  }

  /** Evicts the cached client for (projectId, clientEmail), but only if it is still `auth`. */
  private evictAuthClient(projectId: string, clientEmail: string, auth: GoogleAuth): void {
    // A concurrent call may already have replaced the failing client with a fresh one; keep that one.
    if (this.authClients.get(projectId)?.get(clientEmail)?.auth === auth) {
      this.clearAuthCache(projectId, clientEmail)
    }
  }

  /**
   * Copies string-valued header entries into a plain object.
   * Accepts a plain header object or an iterable of [name, value] pairs.
   * NOTE(review): newer google-auth-library releases appear to return a Fetch `Headers` instance, whose
   * entries are iterable but are not own enumerable properties, so `Object.entries` alone would yield
   * nothing. TODO confirm against the installed version.
   */
  private static toPlainHeaders(raw: unknown): Record<string, string> {
    const headers: Record<string, string> = {}
    if (typeof raw !== 'object' || raw === null) {
      return headers
    }
    const entries: Iterable<[string, unknown]> =
      Symbol.iterator in raw ? (raw as Iterable<[string, unknown]>) : Object.entries(raw)
    for (const [key, value] of entries) {
      if (typeof value === 'string') {
        headers[key] = value
      }
    }
    return headers
  }

  /** Extracts a readable message from an arbitrary thrown value. */
  private static describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error)
  }
}
export default VertexAIService