mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-25 03:10:08 +08:00
feat(GeminiProvider): Add isGemmaModel function and update model handling
Introduce isGemmaModel function to identify Gemma models and adjust system instruction handling in GeminiProvider based on model type. Ensure proper message formatting for Gemma models during chat initialization.
This commit is contained in:
parent
5f35fceb25
commit
65d89cf759
@ -1992,3 +1992,11 @@ export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Re
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
export function isGemmaModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
return model.id.includes('gemma-') || model.group === 'Gemma'
|
||||
}
|
||||
|
||||
@ -13,7 +13,7 @@ import {
|
||||
SafetySetting,
|
||||
TextPart
|
||||
} from '@google/generative-ai'
|
||||
import { isWebSearchModel } from '@renderer/config/models'
|
||||
import { isGemmaModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||
@ -205,7 +205,7 @@ export default class GeminiProvider extends BaseProvider {
|
||||
const geminiModel = this.sdk.getGenerativeModel(
|
||||
{
|
||||
model: model.id,
|
||||
systemInstruction: assistant.prompt,
|
||||
...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }),
|
||||
safetySettings: this.getSafetySettings(model.id),
|
||||
tools: tools,
|
||||
generationConfig: {
|
||||
@ -221,6 +221,27 @@ export default class GeminiProvider extends BaseProvider {
|
||||
const chat = geminiModel.startChat({ history })
|
||||
const messageContents = await this.getMessageContents(userLastMessage!)
|
||||
|
||||
if (isGemmaModel(model) && assistant.prompt) {
|
||||
const isFirstMessage = history.length === 0
|
||||
if (isFirstMessage) {
|
||||
const systemMessage = {
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text:
|
||||
'<start_of_turn>user\n' +
|
||||
assistant.prompt +
|
||||
'<end_of_turn>\n' +
|
||||
'<start_of_turn>user\n' +
|
||||
messageContents.parts[0].text +
|
||||
'<end_of_turn>'
|
||||
}
|
||||
]
|
||||
}
|
||||
messageContents.parts = systemMessage.parts
|
||||
}
|
||||
}
|
||||
|
||||
const start_time_millsec = new Date().getTime()
|
||||
const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
|
||||
const { signal } = abortController
|
||||
|
||||
Loading…
Reference in New Issue
Block a user