diff --git a/packages/aiCore/src/providers/factory.ts b/packages/aiCore/src/providers/factory.ts index 12c721bdd7..1a90531b4d 100644 --- a/packages/aiCore/src/providers/factory.ts +++ b/packages/aiCore/src/providers/factory.ts @@ -40,10 +40,15 @@ const configHandlers: { 'google-vertex': (builder, provider) => { const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'> const vertexProvider = provider as CompleteProviderConfig<'google-vertex'> - vertexBuilder.withGoogleVertexConfig({ - project: vertexProvider.project, - location: vertexProvider.location - }) + vertexBuilder + .withGoogleVertexConfig({ + project: vertexProvider.project, + location: vertexProvider.location + }) + .withGoogleCredentials({ + clientEmail: vertexProvider.googleCredentials?.clientEmail || '', + privateKey: vertexProvider.googleCredentials?.privateKey || '' + }) } } @@ -266,7 +271,6 @@ export class ProviderConfigFactory { baseURL: string apiVersion?: string resourceName?: string - deploymentName?: string } ) { return this.builder('azure') diff --git a/src/renderer/src/aiCore/clients/ApiClientFactory.ts b/src/renderer/src/aiCore/clients/ApiClientFactory.ts index adc97e70e0..124c7d8a7a 100644 --- a/src/renderer/src/aiCore/clients/ApiClientFactory.ts +++ b/src/renderer/src/aiCore/clients/ApiClientFactory.ts @@ -1,3 +1,4 @@ +import { isVertexAIConfigured } from '@renderer/hooks/useVertexAI' import { Provider } from '@renderer/types' import { AihubmixAPIClient } from './AihubmixAPIClient' @@ -46,6 +47,13 @@ export class ApiClientFactory { instance = new GeminiAPIClient(provider) as BaseApiClient break case 'vertexai': + console.log(`[ApiClientFactory] Creating VertexAPIClient for provider: ${provider.id}`) + // 检查 VertexAI 配置 + if (!isVertexAIConfigured()) { + throw new Error( + 'VertexAI is not configured. Please configure project, location and service account credentials.' + ) + } instance = new VertexAPIClient(provider) as BaseApiClient break case 'anthropic': diff --git a/src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts b/src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts index 713d2585d3..c74142dfba 100644 --- a/src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts +++ b/src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts @@ -1,15 +1,23 @@ import { GoogleGenAI } from '@google/genai' -import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI' -import { Provider } from '@renderer/types' +import { createVertexProvider, isVertexProvider } from '@renderer/hooks/useVertexAI' +import { Provider, VertexProvider } from '@renderer/types' import { GeminiAPIClient } from './GeminiAPIClient' export class VertexAPIClient extends GeminiAPIClient { private authHeaders?: Record private authHeadersExpiry?: number + private vertexProvider: VertexProvider constructor(provider: Provider) { super(provider) + + // 如果传入的是普通 Provider,转换为 VertexProvider + if (isVertexProvider(provider)) { + this.vertexProvider = provider + } else { + this.vertexProvider = createVertexProvider(provider) + } } override async getSdkInstance() { @@ -17,11 +25,9 @@ export class VertexAPIClient extends GeminiAPIClient { return this.sdkInstance } - const serviceAccount = getVertexAIServiceAccount() - const projectId = getVertexAIProjectId() - const location = getVertexAILocation() + const { googleCredentials, project, location } = this.vertexProvider - if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) { + if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) { throw new Error('Vertex AI settings are not configured') } @@ -29,7 +35,7 @@ export class VertexAPIClient extends GeminiAPIClient { this.sdkInstance = new GoogleGenAI({ vertexai: true, - project: projectId, + project: project, location: location, httpOptions: { apiVersion: this.getApiVersion(), @@ -44,11 +50,10 @@ export class VertexAPIClient extends GeminiAPIClient { * 获取认证头,如果配置了 service account 则从主进程获取 */ private async getServiceAccountAuthHeaders(): Promise | undefined> { - const serviceAccount = getVertexAIServiceAccount() - const projectId = getVertexAIProjectId() + const { googleCredentials, project } = this.vertexProvider // 检查是否配置了 service account - if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) { + if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) { return undefined } @@ -61,10 +66,10 @@ export class VertexAPIClient extends GeminiAPIClient { try { // 从主进程获取认证头 this.authHeaders = await window.api.vertexAI.getAuthHeaders({ - projectId, + projectId: project, serviceAccount: { - privateKey: serviceAccount.privateKey, - clientEmail: serviceAccount.clientEmail + privateKey: googleCredentials.privateKey, + clientEmail: googleCredentials.clientEmail } }) @@ -85,11 +90,10 @@ export class VertexAPIClient extends GeminiAPIClient { this.authHeaders = undefined this.authHeadersExpiry = undefined - const serviceAccount = getVertexAIServiceAccount() - const projectId = getVertexAIProjectId() + const { googleCredentials, project } = this.vertexProvider - if (projectId && serviceAccount.clientEmail) { - window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail) + if (project && googleCredentials.clientEmail) { + window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail) } } } diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 530aca6e12..d0b558c49a 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -18,7 +18,7 @@ import { StreamTextParams } from '@cherrystudio/ai-core' import { isDedicatedImageGenerationModel } from '@renderer/config/models' -import { getDefaultModel } from '@renderer/services/AssistantService' +import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import type { GenerateImageParams, Model, Provider } from '@renderer/types' import AiSdkToChunkAdapter from './AiSdkToChunkAdapter' @@ -26,7 +26,6 @@ import LegacyAiProvider from './index' import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder' import { CompletionsResult } from './middleware/schemas' import { getAiSdkProviderId } from './provider/factory' -import { getTimeout } from './transformParameters' /** * 将 Provider 配置转换为新 AI SDK 格式 @@ -35,21 +34,27 @@ function providerToAiSdkConfig(provider: Provider): { providerId: ProviderId | 'openai-compatible' options: ProviderSettingsMap[keyof ProviderSettingsMap] } { - const aiSdkProviderId = getAiSdkProviderId(provider) + // 如果是 vertexai 类型且没有 googleCredentials,转换为 VertexProvider + let actualProvider = provider + if (provider.type === 'vertexai' && !isVertexProvider(provider)) { + if (!isVertexAIConfigured()) { + throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') + } + actualProvider = createVertexProvider(provider) + } + + const aiSdkProviderId = getAiSdkProviderId(actualProvider) if (aiSdkProviderId !== 'openai-compatible') { - const defaultModel = getDefaultModel() - const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, provider, { - timeout: getTimeout(defaultModel) - }) + const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, actualProvider) return { providerId: aiSdkProviderId as ProviderId, options } } else { - console.log(`Using openai-compatible fallback for provider: ${provider.type}`) - const options = ProviderConfigFactory.createOpenAICompatible(provider.apiHost, provider.apiKey) + console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`) + const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey) return { providerId: 'openai-compatible', @@ -70,6 +75,11 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean { return false } + // 对于 vertexai,检查配置是否完整 + if (provider.type === 'vertexai' && !isVertexAIConfigured()) { + return false + } + // 检查是否为图像生成模型(暂时不支持) if (model && isDedicatedImageGenerationModel(model)) { return false diff --git a/src/renderer/src/hooks/useVertexAI.ts b/src/renderer/src/hooks/useVertexAI.ts index 89769c0d54..6b443f395b 100644 --- a/src/renderer/src/hooks/useVertexAI.ts +++ b/src/renderer/src/hooks/useVertexAI.ts @@ -5,6 +5,7 @@ import { setVertexAIServiceAccountClientEmail, setVertexAIServiceAccountPrivateKey } from '@renderer/store/llm' +import { Provider, VertexProvider } from '@renderer/types' import { useDispatch } from 'react-redux' export function useVertexAISettings() { @@ -35,3 +36,43 @@ export function getVertexAIProjectId() { export function getVertexAIServiceAccount() { return store.getState().llm.settings.vertexai.serviceAccount } + +/** + * 类型守卫:检查 Provider 是否为 VertexProvider + */ +export function isVertexProvider(provider: Provider): provider is VertexProvider { + return provider.type === 'vertexai' && 'googleCredentials' in provider +} + +/** + * 创建 VertexProvider 对象,整合单独的配置 + * @param baseProvider 基础的 provider 配置 + * @returns VertexProvider 对象 + */ +export function createVertexProvider(baseProvider: Provider): VertexProvider { + const settings = getVertexAISettings() + + return { + ...baseProvider, + type: 'vertexai' as const, + googleCredentials: { + clientEmail: settings.serviceAccount.clientEmail, + privateKey: settings.serviceAccount.privateKey + }, + project: settings.projectId, + location: settings.location + } +} + +/** + * 检查 VertexAI 配置是否完整 + */ +export function isVertexAIConfigured(): boolean { + const settings = getVertexAISettings() + return !!( + settings.serviceAccount.clientEmail && + settings.serviceAccount.privateKey && + settings.projectId && + settings.location + ) +} diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index e60f374ba7..32aa7b558c 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -146,7 +146,7 @@ export type User = { email: string } -export type Provider = { +export interface BaseProvider { id: string type: ProviderType name: string @@ -163,6 +163,18 @@ export type Provider = { notes?: string } +export type Provider = BaseProvider + +export interface VertexProvider extends BaseProvider { + type: 'vertexai' + googleCredentials: { + clientEmail: string + privateKey: string + } + project: string + location: string +} + export type ProviderType = | 'openai' | 'openai-response'