Mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2025-12-28 05:11:24 +08:00)
* refactor(main): manage IPC channels with an enum
  - Add an IpcChannel enum to centralize all IPC channel names
  - Replace hard-coded channel-name strings with the IpcChannel enum in the related code
  - This improves maintainability and readability and avoids errors when channel names change
* refactor(ipc): replace string channel names with the IpcChannel enum
  - Replace hard-coded string channel names with IpcChannel enum values across multiple files
  - Update the affected files' imports to reference IpcChannel
  - Managing IPC channel names through an enum improves maintainability and readability
* refactor(ipc): adjust the IPC channel enum and the preload script
  - Remove unused comments from the IpcChannel enum
  - Update the IpcChannel import path in the preload script
* refactor(ipc): update the IpcChannel import path
  - Change the IpcChannel import path from @main/enum/IpcChannel to @shared/IpcChannel
  - This touches multiple files, including AppUpdater, BackupManager, and EditMcpJsonPopup
  - Also remove the src/main/**/* reference from tsconfig.web.json
* refactor(ipc): add a ReduxStoreReady event and update its listener
  - Add a ReduxStoreReady event to the IpcChannel enum
  - Update the listener in ReduxService to use the new enum value
* refactor(main): rework status-change event handling in ReduxService
  - Define the status-change event name as the constant STATUS_CHANGE_EVENT
  - Use the new constant when listening for and emitting the event
  - This tightens the code structure and improves maintainability
* refactor(i18n): streamline the i18n configuration and language-selection logic
  - Introduce a defaultLanguage constant across files to unify the default-language setting
  - Adjust the i18n initialization and language-change logic to use the new configuration
  - Update the language-selection logic in the related components and hooks
* refactor(ConfigManager): rework the config manager
  - Add a ConfigKeys enum to centralize configuration key names
  - Introduce defaultLanguage as the default-language setting
  - Rework the get and set methods to take ConfigKeys values as key names
  - Tighten type definitions and method signatures for readability and maintainability
* refactor(ConfigManager): rename the zoomFactor config key to ZoomFactor to match the naming convention, and update the related methods and properties accordingly
* refactor(shared): consolidate constants and improve the file-size formatting logic
  - Add KB, MB, and GB constants to constant.ts
  - Move defaultLanguage into constant.ts
  - Update the import paths in ConfigManager, useAppInit, i18n, GeneralSettings, and other files
  - Rework formatFileSize to use the new constants
* refactor(FileSize): use the GB/MB/KB constants for file-size calculations
* refactor(ipc): replace string channel names with the IpcChannel enum
  - Replace hard-coded string channel names with IpcChannel enum values across multiple files
  - Update the affected files' imports to reference IpcChannel
  - Managing IPC channel names through an enum improves maintainability and readability
* refactor(ipc): update the IpcChannel import path
  - Change the IpcChannel import path from @main/enum/IpcChannel to @shared/IpcChannel
  - This touches multiple files, including AppUpdater, BackupManager, and EditMcpJsonPopup
  - Also remove the src/main/**/* reference from tsconfig.web.json
* refactor(i18n): streamline the i18n configuration and language-selection logic
  - Introduce a defaultLanguage constant across files to unify the default-language setting
  - Adjust the i18n initialization and language-change logic to use the new configuration
  - Update the language-selection logic in the related components and hooks
* refactor(shared): consolidate constants and improve the file-size formatting logic
  - Add KB, MB, and GB constants to constant.ts
  - Move defaultLanguage into constant.ts
  - Update the import paths in ConfigManager, useAppInit, i18n, GeneralSettings, and other files
  - Rework formatFileSize to use the new constants
* refactor: remove duplicate import statements
  - Remove the duplicated defaultLanguage import from HomeWindow.tsx and useAppInit.ts
  - This simplifies the code and improves readability and maintainability
816 lines
24 KiB
TypeScript
import {
  ContentListUnion,
  createPartFromBase64,
  FinishReason,
  GenerateContentResponse,
  GoogleGenAI
} from '@google/genai'
import {
  Content,
  FileDataPart,
  FunctionCallPart,
  FunctionResponsePart,
  GenerateContentStreamResult,
  GoogleGenerativeAI,
  HarmBlockThreshold,
  HarmCategory,
  InlineDataPart,
  Part,
  RequestOptions,
  SafetySetting,
  TextPart
} from '@google/generative-ai'
import { isGemmaModel, isWebSearchModel } from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import {
  filterContextMessages,
  filterEmptyMessages,
  filterUserRoleStartMessages
} from '@renderer/services/MessagesService'
import { Assistant, FileType, FileTypes, MCPToolResponse, Message, Model, Provider, Suggestion } from '@renderer/types'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import {
  callMCPTool,
  geminiFunctionCallToMcpTool,
  mcpToolsToGeminiTools,
  upsertMCPToolResponse
} from '@renderer/utils/mcp-tools'
import { MB } from '@shared/config/constant'
import axios from 'axios'
import { isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai'

import { ChunkCallbackData, CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
export default class GeminiProvider extends BaseProvider {
  private sdk: GoogleGenerativeAI
  private requestOptions: RequestOptions
  private imageSdk: GoogleGenAI

  constructor(provider: Provider) {
    super(provider)
    this.sdk = new GoogleGenerativeAI(this.apiKey)
    // The @google/genai SDK is experimental; it is only used for image generation
    this.imageSdk = new GoogleGenAI({ apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
    this.requestOptions = {
      baseUrl: this.getBaseURL()
    }
  }

  public getBaseURL(): string {
    return this.provider.apiHost
  }
  /**
   * Handle a PDF file
   * @param file - The file
   * @returns An inline-data part for small files, or a file-data part for large ones
   */
  private async handlePdfFile(file: FileType): Promise<Part> {
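    // Small payloads can be inlined as base64, but Gemini caps the total request
    // size at roughly 20 MB, so larger PDFs go through the Files API instead:
    // reuse an earlier upload when one exists, otherwise upload the file first.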
    const smallFileSize = 20 * MB
    const isSmallFile = file.size < smallFileSize

    if (isSmallFile) {
      const { data, mimeType } = await window.api.gemini.base64File(file)
      return {
        inlineData: {
          data,
          mimeType
        }
      } as InlineDataPart
    }

    // Retrieve the file from Gemini's uploaded files
    const fileMetadata = await window.api.gemini.retrieveFile(file, this.apiKey)

    if (fileMetadata) {
      return {
        fileData: {
          fileUri: fileMetadata.uri,
          mimeType: fileMetadata.mimeType
        }
      } as FileDataPart
    }

    // If the file is not found, upload it to Gemini
    const uploadResult = await window.api.gemini.uploadFile(file, this.apiKey)

    return {
      fileData: {
        fileUri: uploadResult.file.uri,
        mimeType: uploadResult.file.mimeType
      }
    } as FileDataPart
  }
  /**
   * Get the message contents
   * @param message - The message
   * @returns The message contents
   */
  private async getMessageContents(message: Message): Promise<Content> {
    const role = message.role === 'user' ? 'user' : 'model'

    const parts: Part[] = [{ text: await this.getMessageContent(message) }]

    // Re-attach any images generated in previous responses as inline data
    if (message.metadata?.generateImage?.images && message.metadata.generateImage.images.length > 0) {
      for (const imageUrl of message.metadata.generateImage.images) {
        if (imageUrl && imageUrl.startsWith('data:')) {
          // Extract the base64 data and mime type from the data URL
          const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
          if (matches && matches.length === 3) {
            const mimeType = matches[1]
            const base64Data = matches[2]
            parts.push({
              inlineData: {
                data: base64Data,
                mimeType: mimeType
              }
            } as InlineDataPart)
          }
        }
      }
    }

    for (const file of message.files || []) {
      if (file.type === FileTypes.IMAGE) {
        const base64Data = await window.api.file.base64Image(file.id + file.ext)
        parts.push({
          inlineData: {
            data: base64Data.base64,
            mimeType: base64Data.mime
          }
        } as InlineDataPart)
      }

      if (file.ext === '.pdf') {
        parts.push(await this.handlePdfFile(file))
        continue
      }

      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
        parts.push({
          text: file.origin_name + '\n' + fileContent
        } as TextPart)
      }
    }

    return {
      role,
      parts
    }
  }
  /**
   * Get the safety settings
   * @param modelId - The model ID
   * @returns The safety settings
   */
  private getSafetySettings(modelId: string): SafetySetting[] {
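    // Every category is set to the most permissive threshold so the app does not
    // filter content itself. gemini-2.0-flash-exp accepts the string 'OFF' (no
    // filtering at all), which is not yet in the HarmBlockThreshold enum, hence
    // the cast; other models use BLOCK_NONE.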
    const safetyThreshold = modelId.includes('gemini-2.0-flash-exp')
      ? ('OFF' as HarmBlockThreshold)
      : HarmBlockThreshold.BLOCK_NONE

    return [
      {
        category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold: safetyThreshold
      },
      {
        category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold: safetyThreshold
      },
      {
        category: 'HARM_CATEGORY_CIVIC_INTEGRITY' as HarmCategory,
        threshold: safetyThreshold
      }
    ]
  }
  /**
   * Generate completions
   * @param messages - The messages
   * @param assistant - The assistant
   * @param mcpTools - The MCP tools
   * @param onChunk - The onChunk callback
   * @param onFilterMessages - The onFilterMessages callback
   */
  public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
    if (assistant.enableGenerateImage) {
      await this.generateImageExp({ messages, assistant, onFilterMessages, onChunk })
    } else {
      const defaultModel = getDefaultModel()
      const model = assistant.model || defaultModel
      const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)

      const userMessages = filterUserRoleStartMessages(
        filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
      )
      onFilterMessages(userMessages)

      const userLastMessage = userMessages.pop()

      const history: Content[] = []

      for (const message of userMessages) {
        history.push(await this.getMessageContents(message))
      }

      const tools = mcpToolsToGeminiTools(mcpTools)
      const toolResponses: MCPToolResponse[] = []

      if (assistant.enableWebSearch && isWebSearchModel(model)) {
        tools.push({
          // @ts-ignore googleSearch is not in the SDK's Tool typings, but the API accepts it
          googleSearch: {}
        })
      }

      const geminiModel = this.sdk.getGenerativeModel(
        {
          model: model.id,
          ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }),
          safetySettings: this.getSafetySettings(model.id),
          tools: tools,
          generationConfig: {
            maxOutputTokens: maxTokens,
            temperature: assistant?.settings?.temperature,
            topP: assistant?.settings?.topP,
            ...this.getCustomParameters(assistant)
          }
        },
        this.requestOptions
      )

      const chat = geminiModel.startChat({ history })
      const messageContents = await this.getMessageContents(userLastMessage!)
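      // Gemma models have no separate system instruction, so on the first turn the
      // prompt is folded into the user message using Gemma's <start_of_turn> markup.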
      if (isGemmaModel(model) && assistant.prompt) {
        const isFirstMessage = history.length === 0
        if (isFirstMessage) {
          const systemMessage = {
            role: 'user',
            parts: [
              {
                text:
                  '<start_of_turn>user\n' +
                  assistant.prompt +
                  '<end_of_turn>\n' +
                  '<start_of_turn>user\n' +
                  messageContents.parts[0].text +
                  '<end_of_turn>'
              }
            ]
          }
          messageContents.parts = systemMessage.parts
        }
      }

      const start_time_millsec = new Date().getTime()
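      // createAbortController is inherited from BaseProvider; keying it by the
      // message id lets the rest of the app abort this request (e.g. a user stop),
      // and cleanup runs in the stream's finally block once processing ends.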
      const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
      const { signal } = abortController

      if (!streamOutput) {
        const { response } = await chat.sendMessage(messageContents.parts, { signal })
        const time_completion_millsec = new Date().getTime() - start_time_millsec
        onChunk({
          text: response.candidates?.[0].content.parts[0].text,
          usage: {
            prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
            completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
            total_tokens: response.usageMetadata?.totalTokenCount || 0
          },
          metrics: {
            completion_tokens: response.usageMetadata?.candidatesTokenCount,
            time_completion_millsec,
            time_first_token_millsec: 0
          },
          search: response.candidates?.[0]?.groundingMetadata
        })
        return
      }
      const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
      let time_first_token_millsec = 0

      const processStream = async (stream: GenerateContentStreamResult, idx: number) => {
        for await (const chunk of stream.stream) {
          if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break

          if (time_first_token_millsec === 0) {
            time_first_token_millsec = new Date().getTime() - start_time_millsec
          }

          const time_completion_millsec = new Date().getTime() - start_time_millsec

          const functionCalls = chunk.functionCalls()

          if (functionCalls) {
            const fcallParts: FunctionCallPart[] = []
            const fcRespParts: FunctionResponsePart[] = []
            for (const call of functionCalls) {
              console.log('Function call:', call)
              fcallParts.push({ functionCall: call } as FunctionCallPart)
              const mcpTool = geminiFunctionCallToMcpTool(mcpTools, call)
              if (mcpTool) {
                upsertMCPToolResponse(
                  toolResponses,
                  {
                    tool: mcpTool,
                    status: 'invoking',
                    id: `${call.name}-${idx}`
                  },
                  onChunk
                )
                const toolCallResponse = await callMCPTool(mcpTool)
                fcRespParts.push({
                  functionResponse: {
                    name: mcpTool.id,
                    response: toolCallResponse
                  }
                })
                upsertMCPToolResponse(
                  toolResponses,
                  {
                    tool: mcpTool,
                    status: 'done',
                    response: toolCallResponse,
                    id: `${call.name}-${idx}`
                  },
                  onChunk
                )
              }
            }
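            // Tool-call loop: send the collected function responses back to the
            // model in a new chat that includes the call turn, then recurse on the
            // new stream so chained tool calls keep being handled.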
            if (fcRespParts.length > 0) {
              history.push(messageContents)
              history.push({
                role: 'model',
                parts: fcallParts
              })
              const newChat = geminiModel.startChat({ history })
              const newStream = await newChat.sendMessageStream(fcRespParts, { signal })
              await processStream(newStream, idx + 1)
            }
          }

          onChunk({
            text: chunk.text(),
            usage: {
              prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
              completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
              total_tokens: chunk.usageMetadata?.totalTokenCount || 0
            },
            metrics: {
              completion_tokens: chunk.usageMetadata?.candidatesTokenCount,
              time_completion_millsec,
              time_first_token_millsec
            },
            search: chunk.candidates?.[0]?.groundingMetadata,
            mcpToolResponse: toolResponses
          })
        }
      }

      await processStream(userMessagesStream, 0).finally(cleanup)
    }
  }
  /**
   * Translate a message
   * @param message - The message
   * @param assistant - The assistant
   * @param onResponse - The onResponse callback
   * @returns The translated message
   */
  async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
    const defaultModel = getDefaultModel()
    const { maxTokens } = getAssistantSettings(assistant)
    const model = assistant.model || defaultModel

    const geminiModel = this.sdk.getGenerativeModel(
      {
        model: model.id,
        ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }),
        generationConfig: {
          maxOutputTokens: maxTokens,
          temperature: assistant?.settings?.temperature
        }
      },
      this.requestOptions
    )

    const content =
      isGemmaModel(model) && assistant.prompt
        ? `<start_of_turn>user\n${assistant.prompt}<end_of_turn>\n<start_of_turn>user\n${message.content}<end_of_turn>`
        : message.content

    if (!onResponse) {
      const { response } = await geminiModel.generateContent(content)
      return response.text()
    }

    const response = await geminiModel.generateContentStream(content)

    let text = ''

    for await (const chunk of response.stream) {
      text += chunk.text()
      onResponse(text)
    }

    return text
  }
  /**
   * Summarize messages into a topic name
   * @param messages - The messages
   * @param assistant - The assistant
   * @returns The summary
   */
  public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
    const model = getTopNamingModel() || assistant.model || getDefaultModel()

    const userMessages = takeRight(messages, 5)
      .filter((message) => !message.isPreset)
      .map((message) => ({
        role: message.role,
        content: message.content
      }))

    const userMessageContent = userMessages.reduce((prev, curr) => {
      const content = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}`
      return prev + (prev ? '\n' : '') + content
    }, '')

    const systemMessage = {
      role: 'system',
      content: (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
    }

    const userMessage = {
      role: 'user',
      content: userMessageContent
    }

    const geminiModel = this.sdk.getGenerativeModel(
      {
        model: model.id,
        ...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content }),
        generationConfig: {
          temperature: assistant?.settings?.temperature
        }
      },
      this.requestOptions
    )

    const chat = geminiModel.startChat()
    const content = isGemmaModel(model)
      ? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
      : userMessage.content

    const { response } = await chat.sendMessage(content)

    return removeSpecialCharactersForTopicName(response.text())
  }
  /**
   * Generate text
   * @param prompt - The prompt
   * @param content - The content
   * @returns The generated text
   */
  public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
    const model = getDefaultModel()
    const systemMessage = { role: 'system', content: prompt }

    const geminiModel = this.sdk.getGenerativeModel(
      {
        model: model.id,
        ...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content })
      },
      this.requestOptions
    )

    const chat = geminiModel.startChat()
    const messageContent = isGemmaModel(model)
      ? `<start_of_turn>user\n${prompt}<end_of_turn>\n<start_of_turn>user\n${content}<end_of_turn>`
      : content

    const { response } = await chat.sendMessage(messageContent)

    return response.text()
  }
  /**
   * Generate suggestions
   * @returns The suggestions
   */
  public async suggestions(): Promise<Suggestion[]> {
    return []
  }
  /**
   * Summarize a message for search
   * @param messages - The messages
   * @param assistant - The assistant
   * @returns The summary
   */
  public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string> {
    const model = assistant.model || getDefaultModel()

    const systemMessage = {
      role: 'system',
      content: assistant.prompt
    }

    const userMessage = {
      role: 'user',
      content: messages.map((m) => m.content).join('\n')
    }

    const geminiModel = this.sdk.getGenerativeModel(
      {
        model: model.id,
        systemInstruction: systemMessage.content,
        generationConfig: {
          temperature: assistant?.settings?.temperature
        }
      },
      {
        ...this.requestOptions,
        timeout: 20 * 1000
      }
    )

    const chat = geminiModel.startChat()
    const { response } = await chat.sendMessage(userMessage.content)

    return response.text()
  }
  /**
   * Generate an image
   * @returns The generated image
   */
  public async generateImage(): Promise<string[]> {
    return []
  }
  /**
   * Generate an image (experimental)
   * @param messages - The message list
   * @param assistant - The assistant configuration
   * @param onChunk - Callback for each generated chunk
   * @param onFilterMessages - Callback receiving the filtered messages
   * @returns Promise<void>
   */
  private async generateImageExp({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel
    const { contextCount, streamOutput, maxTokens } = getAssistantSettings(assistant)

    const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
    onFilterMessages(userMessages)

    const userLastMessage = userMessages.pop()
    if (!userLastMessage) {
      throw new Error('No user message found')
    }

    const history: Content[] = []

    for (const message of userMessages) {
      history.push(await this.getMessageContents(message))
    }

    const userLastMessageContent = await this.getMessageContents(userLastMessage)
    const allContents = [...history, userLastMessageContent]

    let contents: ContentListUnion = allContents.length > 0 ? (allContents as ContentListUnion) : []

    contents = await this.addImageFileToContents(userLastMessage, contents)

    if (!streamOutput) {
      const response = await this.callGeminiGenerateContent(model.id, contents, maxTokens)

      const { isValid, message } = this.isValidGeminiResponse(response)
      if (!isValid) {
        throw new Error(`Gemini API error: ${message}`)
      }

      this.processGeminiImageResponse(response, onChunk)
      return
    }

    const response = await this.callGeminiGenerateContentStream(model.id, contents, maxTokens)

    for await (const chunk of response) {
      this.processGeminiImageResponse(chunk, onChunk)
    }
  }
  /**
   * Add an image file to the content list
   * @param message - The user message
   * @param contents - The content list
   * @returns The updated content list
   */
  private async addImageFileToContents(message: Message, contents: ContentListUnion): Promise<ContentListUnion> {
    if (message.files && message.files.length > 0) {
      const file = message.files[0]
      const fileContent = await window.api.file.base64Image(file.id + file.ext)

      if (fileContent && fileContent.base64) {
        const contentsArray = Array.isArray(contents) ? contents : [contents]
        return [...contentsArray, createPartFromBase64(fileContent.base64, fileContent.mime)]
      }
    }
    return contents
  }
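  // Both generate-content helpers below request responseModalities ['Text', 'Image']
  // so the experimental image models can return text and image parts in one response.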
  /**
   * Call the Gemini API to generate content
   * @param modelId - The model ID
   * @param contents - The content list
   * @param maxTokens - The maximum number of output tokens
   * @returns The generation result
   */
  private async callGeminiGenerateContent(
    modelId: string,
    contents: ContentListUnion,
    maxTokens?: number
  ): Promise<GenerateContentResponse> {
    try {
      return await this.imageSdk.models.generateContent({
        model: modelId,
        contents: contents,
        config: {
          responseModalities: ['Text', 'Image'],
          responseMimeType: 'text/plain',
          maxOutputTokens: maxTokens
        }
      })
    } catch (error) {
      console.error('Gemini API error:', error)
      throw error
    }
  }
  private async callGeminiGenerateContentStream(
    modelId: string,
    contents: ContentListUnion,
    maxTokens?: number
  ): Promise<AsyncGenerator<GenerateContentResponse>> {
    try {
      return await this.imageSdk.models.generateContentStream({
        model: modelId,
        contents: contents,
        config: {
          responseModalities: ['Text', 'Image'],
          responseMimeType: 'text/plain',
          maxOutputTokens: maxTokens
        }
      })
    } catch (error) {
      console.error('Gemini API error:', error)
      throw error
    }
  }
  /**
   * Check whether a Gemini response is valid
   * @param response - The Gemini response
   * @returns Whether the response is valid, along with the finish reason
   */
  private isValidGeminiResponse(response: GenerateContentResponse): { isValid: boolean; message: string } {
    return {
      isValid: response?.candidates?.[0]?.finishReason === FinishReason.STOP,
      message: response?.candidates?.[0]?.finishReason || ''
    }
  }
  /**
   * Process a Gemini image response
   * @param response - The Gemini response
   * @param onChunk - Callback for each generated chunk
   */
  private processGeminiImageResponse(response: any, onChunk: (chunk: ChunkCallbackData) => void): void {
    const parts = response.candidates?.[0]?.content?.parts
    if (!parts) {
      return
    }

    // Extract the image data, normalizing raw base64 into data URLs
    const images = parts
      .filter((part: Part) => part.inlineData)
      .map((part: Part) => {
        if (!part.inlineData) {
          return null
        }
        const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,`
        return part.inlineData.data.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data
      })

    // Extract the text data
    const text = parts
      .filter((part: Part) => part.text !== undefined)
      .map((part: Part) => part.text)
      .join('')

    // Emit the result
    onChunk({
      text,
      generateImage: {
        type: 'base64',
        images
      },
      usage: {
        prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
        completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
        total_tokens: response.usageMetadata?.totalTokenCount || 0
      },
      metrics: {
        completion_tokens: response.usageMetadata?.candidatesTokenCount
      }
    })
  }
  /**
   * Check if the model is valid
   * @param model - The model
   * @returns The validity of the model
   */
  public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
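    // Probe the model with a minimal prompt; a non-empty reply means the key,
    // host, and model id are all usable.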
    if (!model) {
      return { valid: false, error: new Error('No model found') }
    }

    const body = {
      model: model.id,
      messages: [{ role: 'user', content: 'hi' }],
      max_tokens: 100,
      stream: false
    }

    try {
      const geminiModel = this.sdk.getGenerativeModel({ model: body.model }, this.requestOptions)
      const result = await geminiModel.generateContent(body.messages[0].content)
      return {
        valid: !isEmpty(result.response.text()),
        error: null
      }
    } catch (error: any) {
      return {
        valid: false,
        error
      }
    }
  }
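  // Model listing calls the Gemini REST endpoint (/v1beta/models) directly and maps
  // each entry into OpenAI's Model shape, the common format used across providers.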
  /**
   * Get the models
   * @returns The models
   */
  public async models(): Promise<OpenAI.Models.Model[]> {
    try {
      const api = this.provider.apiHost + '/v1beta/models'
      const { data } = await axios.get(api, { params: { key: this.apiKey } })

      return data.models.map(
        (m) =>
          ({
            id: m.name.replace('models/', ''),
            name: m.displayName,
            description: m.description,
            object: 'model',
            created: Date.now(),
            owned_by: 'gemini'
          }) as OpenAI.Models.Model
      )
    } catch (error) {
      return []
    }
  }
  /**
   * Get the embedding dimensions
   * @param model - The model
   * @returns The embedding dimensions
   */
  public async getEmbeddingDimensions(model: Model): Promise<number> {
    const data = await this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions).embedContent('hi')
    return data.embedding.values.length
  }
}