From 4d1e3963bc57e412d8327b42a9a381246a4a9ab1 Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Tue, 26 Nov 2024 13:15:25 +0800 Subject: [PATCH] feat: add configurable request options to gemini provider --- src/renderer/src/providers/GeminiProvider.ts | 80 ++++++++++++-------- 1 file changed, 47 insertions(+), 33 deletions(-) diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 8a5f57b3cb..b502b8db15 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -5,6 +5,7 @@ import { HarmCategory, InlineDataPart, Part, + RequestOptions, TextPart } from '@google/generative-ai' import { SUMMARIZE_PROMPT } from '@renderer/config/prompts' @@ -20,10 +21,14 @@ import BaseProvider from './BaseProvider' export default class GeminiProvider extends BaseProvider { private sdk: GoogleGenerativeAI + private requestOptions: RequestOptions constructor(provider: Provider) { super(provider) this.sdk = new GoogleGenerativeAI(this.apiKey) + this.requestOptions = { + baseUrl: this.provider.apiHost + } } private async getMessageContents(message: Message): Promise { @@ -75,23 +80,26 @@ export default class GeminiProvider extends BaseProvider { history.push(await this.getMessageContents(message)) } - const geminiModel = this.sdk.getGenerativeModel({ - model: model.id, - systemInstruction: assistant.prompt, - generationConfig: { - maxOutputTokens: maxTokens, - temperature: assistant?.settings?.temperature - }, - safetySettings: [ - { category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold: HarmBlockThreshold.BLOCK_NONE }, - { - category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - threshold: HarmBlockThreshold.BLOCK_NONE + const geminiModel = this.sdk.getGenerativeModel( + { + model: model.id, + systemInstruction: assistant.prompt, + generationConfig: { + maxOutputTokens: maxTokens, + temperature: assistant?.settings?.temperature }, - { category: HarmCategory.HARM_CATEGORY_HARASSMENT, threshold: HarmBlockThreshold.BLOCK_NONE }, - { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: HarmBlockThreshold.BLOCK_NONE } - ] - }) + safetySettings: [ + { category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold: HarmBlockThreshold.BLOCK_NONE }, + { + category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold: HarmBlockThreshold.BLOCK_NONE + }, + { category: HarmCategory.HARM_CATEGORY_HARASSMENT, threshold: HarmBlockThreshold.BLOCK_NONE }, + { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: HarmBlockThreshold.BLOCK_NONE } + ] + }, + this.requestOptions + ) const chat = geminiModel.startChat({ history }) const messageContents = await this.getMessageContents(userLastMessage!) @@ -129,14 +137,17 @@ export default class GeminiProvider extends BaseProvider { const { maxTokens } = getAssistantSettings(assistant) const model = assistant.model || defaultModel - const geminiModel = this.sdk.getGenerativeModel({ - model: model.id, - systemInstruction: assistant.prompt, - generationConfig: { - maxOutputTokens: maxTokens, - temperature: assistant?.settings?.temperature - } - }) + const geminiModel = this.sdk.getGenerativeModel( + { + model: model.id, + systemInstruction: assistant.prompt, + generationConfig: { + maxOutputTokens: maxTokens, + temperature: assistant?.settings?.temperature + } + }, + this.requestOptions + ) const { response } = await geminiModel.generateContent(message.content) @@ -168,13 +179,16 @@ export default class GeminiProvider extends BaseProvider { content: userMessageContent } - const geminiModel = this.sdk.getGenerativeModel({ - model: model.id, - systemInstruction: systemMessage.content, - generationConfig: { - temperature: assistant?.settings?.temperature - } - }) + const geminiModel = this.sdk.getGenerativeModel( + { + model: model.id, + systemInstruction: systemMessage.content, + generationConfig: { + temperature: assistant?.settings?.temperature + } + }, + this.requestOptions + ) const chat = await geminiModel.startChat() @@ -187,7 +201,7 @@ export default class GeminiProvider extends BaseProvider { const model = getDefaultModel() const systemMessage = { role: 'system', content: prompt } - const geminiModel = this.sdk.getGenerativeModel({ model: model.id }) + const geminiModel = this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions) const chat = await geminiModel.startChat({ systemInstruction: systemMessage.content }) const { response } = await chat.sendMessage(content) @@ -214,7 +228,7 @@ export default class GeminiProvider extends BaseProvider { } try { - const geminiModel = this.sdk.getGenerativeModel({ model: body.model }) + const geminiModel = this.sdk.getGenerativeModel({ model: body.model }, this.requestOptions) const result = await geminiModel.generateContent(body.messages[0].content) return { valid: !isEmpty(result.response.text()),