mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat: enhance Vertex AI provider integration and configuration
- Added support for Google Vertex AI credentials in the provider configuration. - Refactored the VertexAPIClient to handle both standard and VertexProvider types. - Implemented utility functions to check Vertex AI configuration completeness and create VertexProvider instances. - Updated provider mapping in index_new.ts to ensure proper handling of Vertex AI settings.
This commit is contained in:
parent
8ca6341609
commit
f934b479b2
@ -40,10 +40,15 @@ const configHandlers: {
|
||||
'google-vertex': (builder, provider) => {
|
||||
const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
|
||||
const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
|
||||
vertexBuilder.withGoogleVertexConfig({
|
||||
project: vertexProvider.project,
|
||||
location: vertexProvider.location
|
||||
})
|
||||
vertexBuilder
|
||||
.withGoogleVertexConfig({
|
||||
project: vertexProvider.project,
|
||||
location: vertexProvider.location
|
||||
})
|
||||
.withGoogleCredentials({
|
||||
clientEmail: vertexProvider.googleCredentials?.clientEmail || '',
|
||||
privateKey: vertexProvider.googleCredentials?.privateKey || ''
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -266,7 +271,6 @@ export class ProviderConfigFactory {
|
||||
baseURL: string
|
||||
apiVersion?: string
|
||||
resourceName?: string
|
||||
deploymentName?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('azure')
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { AihubmixAPIClient } from './AihubmixAPIClient'
|
||||
@ -46,6 +47,13 @@ export class ApiClientFactory {
|
||||
instance = new GeminiAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'vertexai':
|
||||
console.log(`[ApiClientFactory] Creating VertexAPIClient for provider: ${provider.id}`)
|
||||
// 检查 VertexAI 配置
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error(
|
||||
'VertexAI is not configured. Please configure project, location and service account credentials.'
|
||||
)
|
||||
}
|
||||
instance = new VertexAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'anthropic':
|
||||
|
||||
@ -1,15 +1,23 @@
|
||||
import { GoogleGenAI } from '@google/genai'
|
||||
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
|
||||
import { Provider } from '@renderer/types'
|
||||
import { createVertexProvider, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { Provider, VertexProvider } from '@renderer/types'
|
||||
|
||||
import { GeminiAPIClient } from './GeminiAPIClient'
|
||||
|
||||
export class VertexAPIClient extends GeminiAPIClient {
|
||||
private authHeaders?: Record<string, string>
|
||||
private authHeadersExpiry?: number
|
||||
private vertexProvider: VertexProvider
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
|
||||
// 如果传入的是普通 Provider,转换为 VertexProvider
|
||||
if (isVertexProvider(provider)) {
|
||||
this.vertexProvider = provider
|
||||
} else {
|
||||
this.vertexProvider = createVertexProvider(provider)
|
||||
}
|
||||
}
|
||||
|
||||
override async getSdkInstance() {
|
||||
@ -17,11 +25,9 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const location = getVertexAILocation()
|
||||
const { googleCredentials, project, location } = this.vertexProvider
|
||||
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
|
||||
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) {
|
||||
throw new Error('Vertex AI settings are not configured')
|
||||
}
|
||||
|
||||
@ -29,7 +35,7 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
|
||||
this.sdkInstance = new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project: projectId,
|
||||
project: project,
|
||||
location: location,
|
||||
httpOptions: {
|
||||
apiVersion: this.getApiVersion(),
|
||||
@ -44,11 +50,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
* 获取认证头,如果配置了 service account 则从主进程获取
|
||||
*/
|
||||
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const { googleCredentials, project } = this.vertexProvider
|
||||
|
||||
// 检查是否配置了 service account
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
|
||||
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
@ -61,10 +66,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
try {
|
||||
// 从主进程获取认证头
|
||||
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
|
||||
projectId,
|
||||
projectId: project,
|
||||
serviceAccount: {
|
||||
privateKey: serviceAccount.privateKey,
|
||||
clientEmail: serviceAccount.clientEmail
|
||||
privateKey: googleCredentials.privateKey,
|
||||
clientEmail: googleCredentials.clientEmail
|
||||
}
|
||||
})
|
||||
|
||||
@ -85,11 +90,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
this.authHeaders = undefined
|
||||
this.authHeadersExpiry = undefined
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const { googleCredentials, project } = this.vertexProvider
|
||||
|
||||
if (projectId && serviceAccount.clientEmail) {
|
||||
window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail)
|
||||
if (project && googleCredentials.clientEmail) {
|
||||
window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,7 +18,7 @@ import {
|
||||
StreamTextParams
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
|
||||
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
|
||||
@ -26,7 +26,6 @@ import LegacyAiProvider from './index'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsResult } from './middleware/schemas'
|
||||
import { getAiSdkProviderId } from './provider/factory'
|
||||
import { getTimeout } from './transformParameters'
|
||||
|
||||
/**
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
@ -35,21 +34,27 @@ function providerToAiSdkConfig(provider: Provider): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||
// 如果是 vertexai 类型且没有 googleCredentials,转换为 VertexProvider
|
||||
let actualProvider = provider
|
||||
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||
}
|
||||
actualProvider = createVertexProvider(provider)
|
||||
}
|
||||
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
|
||||
if (aiSdkProviderId !== 'openai-compatible') {
|
||||
const defaultModel = getDefaultModel()
|
||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, provider, {
|
||||
timeout: getTimeout(defaultModel)
|
||||
})
|
||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, actualProvider)
|
||||
|
||||
return {
|
||||
providerId: aiSdkProviderId as ProviderId,
|
||||
options
|
||||
}
|
||||
} else {
|
||||
console.log(`Using openai-compatible fallback for provider: ${provider.type}`)
|
||||
const options = ProviderConfigFactory.createOpenAICompatible(provider.apiHost, provider.apiKey)
|
||||
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
|
||||
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
|
||||
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
@ -70,6 +75,11 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
// 对于 vertexai,检查配置是否完整
|
||||
if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否为图像生成模型(暂时不支持)
|
||||
if (model && isDedicatedImageGenerationModel(model)) {
|
||||
return false
|
||||
|
||||
@ -5,6 +5,7 @@ import {
|
||||
setVertexAIServiceAccountClientEmail,
|
||||
setVertexAIServiceAccountPrivateKey
|
||||
} from '@renderer/store/llm'
|
||||
import { Provider, VertexProvider } from '@renderer/types'
|
||||
import { useDispatch } from 'react-redux'
|
||||
|
||||
export function useVertexAISettings() {
|
||||
@ -35,3 +36,43 @@ export function getVertexAIProjectId() {
|
||||
export function getVertexAIServiceAccount() {
|
||||
return store.getState().llm.settings.vertexai.serviceAccount
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:检查 Provider 是否为 VertexProvider
|
||||
*/
|
||||
export function isVertexProvider(provider: Provider): provider is VertexProvider {
|
||||
return provider.type === 'vertexai' && 'googleCredentials' in provider
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 VertexProvider 对象,整合单独的配置
|
||||
* @param baseProvider 基础的 provider 配置
|
||||
* @returns VertexProvider 对象
|
||||
*/
|
||||
export function createVertexProvider(baseProvider: Provider): VertexProvider {
|
||||
const settings = getVertexAISettings()
|
||||
|
||||
return {
|
||||
...baseProvider,
|
||||
type: 'vertexai' as const,
|
||||
googleCredentials: {
|
||||
clientEmail: settings.serviceAccount.clientEmail,
|
||||
privateKey: settings.serviceAccount.privateKey
|
||||
},
|
||||
project: settings.projectId,
|
||||
location: settings.location
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 VertexAI 配置是否完整
|
||||
*/
|
||||
export function isVertexAIConfigured(): boolean {
|
||||
const settings = getVertexAISettings()
|
||||
return !!(
|
||||
settings.serviceAccount.clientEmail &&
|
||||
settings.serviceAccount.privateKey &&
|
||||
settings.projectId &&
|
||||
settings.location
|
||||
)
|
||||
}
|
||||
|
||||
@ -146,7 +146,7 @@ export type User = {
|
||||
email: string
|
||||
}
|
||||
|
||||
export type Provider = {
|
||||
export interface BaseProvider {
|
||||
id: string
|
||||
type: ProviderType
|
||||
name: string
|
||||
@ -163,6 +163,18 @@ export type Provider = {
|
||||
notes?: string
|
||||
}
|
||||
|
||||
export type Provider = BaseProvider
|
||||
|
||||
export interface VertexProvider extends BaseProvider {
|
||||
type: 'vertexai'
|
||||
googleCredentials: {
|
||||
clientEmail: string
|
||||
privateKey: string
|
||||
}
|
||||
project: string
|
||||
location: string
|
||||
}
|
||||
|
||||
export type ProviderType =
|
||||
| 'openai'
|
||||
| 'openai-response'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user