cherry-studio/src/renderer/src/providers/AiProvider/GeminiProvider.ts
kangfenmao a64c8ded73 fix: message and rerank errors
commit 1c90e23d76a3e1008408bf29add122ccab7dbe6d
Merge: 4e792033 1fde0999
Author: kangfenmao <kangfenmao@qq.com>
Date:   Sun May 11 18:43:56 2025 +0800

    Merge branch 'main' into fix/next-release-bugs

commit 4e7920336d838501b900a18f1c254f0d17fb54b8
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 18:09:58 2025 +0800

    refactor(GeminiProvider): implement image generation handling in chat responses

commit cd1ce4c0c65bc2a111ddb3112722cb36e66b2515
Merge: 968de188 235122c8
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 16:45:29 2025 +0800

    Merge branch 'main' into fix/next-release-bugs

commit 968de18866abacc55fd9bd74c8d618871e64ade6
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 16:41:38 2025 +0800

    fix: add new image generation models to the configuration

commit 1eaf5801b4c0e2c3fa1aa2ed829b20d97ea57d3f
Merge: cb76588d e6655fff
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 13:14:17 2025 +0800

    Merge branch 'main' into fix/next-release-bugs

commit cb76588dc692f35da1f3d4fdbc9217c6a1a36501
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 13:13:20 2025 +0800

    fix: enhance error handling and metrics tracking across AI providers and message processing

commit c2d6bdabc00b48419773d08c7d6630803d6310c8
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 05:36:52 2025 +0800

    fix: update regex for function calling models and improve time tracking logic

commit 95340b87d0bba3cdcd173a181953afa42b26da9b
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 05:05:05 2025 +0800

    fix: adjust thinking millisecond handling in message thunk

commit f4d4d3901603f14df616582fa537f3d9c1a694eb
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 03:52:26 2025 +0800

    fix: remove 'auto' option from qwen model supported options

commit e26f603dfe1a9146b0575142363d5ceab30e32df
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 03:50:05 2025 +0800

    fix: add support for inline base64 image data in image block

commit bb0093c656b2b72158db1bf7bfef6aae46b8096c
Merge: f9d1339b d39584fc
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 03:00:57 2025 +0800

    Merge branch 'main' into fix/next-release-bugs

commit f9d1339bd3367a4f80da23aac1fdc73b4cd2a075
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 03:00:11 2025 +0800

    fix:  set default zoomFactor in settings

commit 7cf6fd685662a012e2460e722edcbe5ed12f1a1c
Merge: ba9c4482 3bebfe27
Author: suyao <sy20010504@gmail.com>
Date:   Sun May 11 01:31:05 2025 +0800

    Merge branch 'main' into fix/next-release-bugs

commit ba9c44828578a3b5cc9fd1aaba80158615921785
Merge: 97dffe71 3bf0b6b3
Author: kangfenmao <kangfenmao@qq.com>
Date:   Sat May 10 20:18:44 2025 +0800

    Merge branch 'main' into fix/next-release-bugs

commit 97dffe719ee00ace2325c80022a48c44f03d6e26
Author: lizhixuan <daoquqiexing@gmail.com>
Date:   Sat May 10 11:38:01 2025 +0800

    feat(MessageMenubar): add edit option to dropdown for single message editing

commit 70157439a1d9778cacc87b1781a84d159a8d3f33
Author: lizhixuan <daoquqiexing@gmail.com>
Date:   Sat May 10 10:53:10 2025 +0800

    refactor(StreamProcessingService): comment out console.log for cleaner code

commit fa33ba77a9306ad316f34da4149858192079f7a2
Author: lizhixuan <daoquqiexing@gmail.com>
Date:   Sat May 10 10:52:08 2025 +0800

    refactor(messageThunk): remove console.log statements for cleaner code

commit 6544c5d2990adf1943195e1d4d11383859a05488
Author: kangfenmao <kangfenmao@qq.com>
Date:   Sat May 10 10:17:44 2025 +0800

    feat(i18n): add download success and failure messages in multiple languages

commit e23bb6744a4a99b2062012691340f78fad4e1952
Merge: 55c5c553 60cc1dee
Author: kangfenmao <kangfenmao@qq.com>
Date:   Sat May 10 09:54:38 2025 +0800

    Merge branch 'main' into fix/next-release-bugs

commit 55c5c5533eab46a1de9c5511433ed1b1b9e90512
Author: suyao <sy20010504@gmail.com>
Date:   Fri May 9 22:19:35 2025 +0800

    fix: update styled component props to use dollar sign prefix for consistency

commit 7a5839e0efdf3eb648b3d26cdf7ce131a2821f6a
Author: suyao <sy20010504@gmail.com>
Date:   Fri May 9 22:02:06 2025 +0800

    fix: prevent default action in handleLinkClick for better link handling

commit ecb075fddfc2bc5796a804ccde29aaa762d85da6
Merge: df149608 963f04f7
Author: suyao <sy20010504@gmail.com>
Date:   Fri May 9 21:55:54 2025 +0800

    Merge branch 'main' into fix/next-release-bugs

commit df149608904039903d74d7a72a722b6c1a567ee3
Merge: 93bd4eb9 e29a9303
Author: suyao <sy20010504@gmail.com>
Date:   Fri May 9 21:48:13 2025 +0800

    Merge branch 'main' into fix/next-release-bugs

commit 93bd4eb907816414af5c27efbaa2d3d014c707fc
Merge: c4d1deb6 38ff9b90
Author: suyao <sy20010504@gmail.com>
Date:   Fri May 9 21:27:08 2025 +0800

    Merge branch 'fix/next-release-bugs' of github.com:CherryHQ/cherry-studio into fix/next-release-bugs

commit c4d1deb6911977a23cf731db6bed80b8352557ff
Author: suyao <sy20010504@gmail.com>
Date:   Fri May 9 21:24:48 2025 +0800

    feat: enhance citation handling and add metadata support in citation blocks

commit 38ff9b90b8fee91ed1fba7b83c9470bc40bd3429
Author: MyPrototypeWhat <daoquqiexing@gmail.com>
Date:   Fri May 9 19:47:24 2025 +0800

    fix: enhance logging and update async handling in StreamProcessingService and messageThunk

    - Enabled logging in `createStreamProcessor` for better debugging.
    - Added logging for updated messages in `updateExistingMessageAndBlocksInDB` and `saveUpdatesToDB`.
    - Updated `onTextComplete` and `onLLMWebSearchComplete` to handle asynchronous operations correctly.
    - Commented out unused `saveUpdatedBlockToDB` calls to prevent unnecessary database updates.

commit cda0215c9c4e007c2c7240c3c9c8521fb7111774
Author: MyPrototypeWhat <daoquqiexing@gmail.com>
Date:   Fri May 9 18:47:55 2025 +0800

    refactor: optimize block update logic and remove unused code

    - Updated `throttledBlockUpdate` to handle asynchronous updates directly.
    - Removed the unused `throttledBlockDbUpdate` function and its related logic.
    - Added cancellation for throttled updates on error and completion to improve performance and reliability.
    - Cleaned up commented-out code for better readability.

commit de2f5b09c8384eabd4df7253047b838a2759671a
Author: MyPrototypeWhat <daoquqiexing@gmail.com>
Date:   Fri May 9 18:42:00 2025 +0800

    refactor: update message handling and state management

    - Simplified message editing logic by removing unnecessary success/error logging.
    - Added `updatedAt` timestamp to message updates for better tracking.
    - Refactored `editMessageBlocks` to accept message ID and updates directly.
    - Removed unused `getTopicLimit` function from `TopicManager`.
    - Updated message rendering to use `updatedAt` when available.
    - Enhanced type definitions to include `updatedAt` in message structure.

commit 700fa13971cafb04314817d2d8732c8fbf33c9d7
Author: suyao <sy20010504@gmail.com>
Date:   Fri May 9 16:19:55 2025 +0800

    Remove Zhipu mode and text-only link handling

commit 06bd1338cd671b255e477cec76b12663ea759f4c
Author: kangfenmao <kangfenmao@qq.com>
Date:   Fri May 9 15:49:02 2025 +0800

    fix: update citation rendering logic in MainTextBlock component

    - Added a check to determine if the citation URL is a valid link.
    - Updated citation tag formatting to conditionally include the link based on the URL validity.

commit e96c9a569f7708816f57505975d30667929eeb19
Author: kangfenmao <kangfenmao@qq.com>
Date:   Thu May 8 18:31:14 2025 +0800

    style: update ChatNavigation and CitationsList components for improved UI consistency

    - Added header style to remove borders in ChatNavigation.
    - Enhanced CitationsList with new Skeleton loading state and improved layout for citation cards.
    - Refactored CitationLink to a div for better styling control and adjusted padding in OpenButton for a more polished appearance.

import {
Content,
File,
FunctionCall,
GenerateContentConfig,
GenerateContentResponse,
GoogleGenAI,
HarmBlockThreshold,
HarmCategory,
Modality,
Part,
PartUnion,
SafetySetting,
ThinkingConfig,
Tool
} from '@google/genai'
import { nanoid } from '@reduxjs/toolkit'
import {
findTokenLimit,
isGeminiReasoningModel,
isGemmaModel,
isGenerateImageModel,
isVisionModel,
isWebSearchModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import {
filterContextMessages,
filterEmptyMessages,
filterUserRoleStartMessages
} from '@renderer/services/MessagesService'
import {
Assistant,
EFFORT_RATIO,
FileType,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Metrics,
Model,
Provider,
Suggestion,
ToolCallResponse,
Usage,
WebSearchSource
} from '@renderer/types'
import { BlockCompleteChunk, Chunk, ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk'
import type { Message, Response } from '@renderer/types/newMessage'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import {
geminiFunctionCallToMcpTool,
mcpToolCallResponseToGeminiMessage,
mcpToolsToGeminiTools,
parseAndCallTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { MB } from '@shared/config/constant'
import axios from 'axios'
import { flatten, isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai'
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
export default class GeminiProvider extends BaseProvider {
private sdk: GoogleGenAI
constructor(provider: Provider) {
super(provider)
this.sdk = new GoogleGenAI({ vertexai: false, apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
}
public getBaseURL(): string {
return this.provider.apiHost
}
/**
* Convert a PDF file into a Gemini content Part
* @param file - The PDF file to convert
* @returns The Part, either inline base64 data or an uploaded-file reference
*/
private async handlePdfFile(file: FileType): Promise<Part> {
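// Files under 20 MB are sent inline as base64; larger files go through the
// Gemini Files API, reusing a previously uploaded file when one exists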
const smallFileSize = 20 * MB
const isSmallFile = file.size < smallFileSize
if (isSmallFile) {
const { data, mimeType } = await window.api.gemini.base64File(file)
return {
inlineData: {
data,
mimeType
} as Part['inlineData']
}
}
// Retrieve file from Gemini uploaded files
const fileMetadata: File | undefined = await window.api.gemini.retrieveFile(file, this.apiKey)
if (fileMetadata) {
return {
fileData: {
fileUri: fileMetadata.uri,
mimeType: fileMetadata.mimeType
} as Part['fileData']
}
}
// If file is not found, upload it to Gemini
const result = await window.api.gemini.uploadFile(file, this.apiKey)
return {
fileData: {
fileUri: result.uri,
mimeType: result.mimeType
} as Part['fileData']
}
}
/**
* Get the message contents
* @param message - The message
* @returns The message contents
*/
private async getMessageContents(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
// Add any generated images from previous responses
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (
imageBlock.metadata?.generateImageResponse?.images &&
imageBlock.metadata.generateImageResponse.images.length > 0
) {
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
if (imageUrl && imageUrl.startsWith('data:')) {
// Extract base64 data and mime type from the data URL
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
const mimeType = matches[1]
const base64Data = matches[2]
parts.push({
inlineData: {
data: base64Data,
mimeType: mimeType
} as Part['inlineData']
})
}
}
}
}
const file = imageBlock.file
if (file) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
}
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
if (file.ext === '.pdf') {
parts.push(await this.handlePdfFile(file))
continue
}
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role,
parts: parts
}
}
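/**
* Get the message contents for image generation chats
* Only the main text and image parts are included; file and PDF handling is skipped
* @param message - The message
* @returns The message contents
*/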
private async getImageFileContents(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const content = getMainTextContent(message)
const parts: Part[] = [{ text: content }]
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (
imageBlock.metadata?.generateImageResponse?.images &&
imageBlock.metadata.generateImageResponse.images.length > 0
) {
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
if (imageUrl && imageUrl.startsWith('data:')) {
// Extract base64 data and mime type from the data URL
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
const mimeType = matches[1]
const base64Data = matches[2]
parts.push({
inlineData: {
data: base64Data,
mimeType: mimeType
} as Part['inlineData']
})
}
}
}
}
const file = imageBlock.file
if (file) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
}
return {
role,
parts: parts
}
}
/**
* Get the safety settings
* @returns The safety settings
*/
private getSafetySettings(): SafetySetting[] {
const safetyThreshold = 'OFF' as HarmBlockThreshold
return [
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
threshold: HarmBlockThreshold.BLOCK_NONE
}
]
}
/**
* Derive the Gemini thinking budget from the assistant's reasoning effort setting
* @param assistant - The assistant
* @param model - The model
* @returns A partial config containing thinkingConfig, or an empty object when the default should be used
*/
private getBudgetToken(assistant: Assistant, model: Model) {
if (isGeminiReasoningModel(model)) {
const reasoningEffort = assistant?.settings?.reasoning_effort
// If reasoning_effort is undefined, thinking is disabled
if (reasoningEffort === undefined) {
return {
thinkingConfig: {
includeThoughts: false
} as ThinkingConfig
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
if (effortRatio > 1) {
return {}
}
const { max } = findTokenLimit(model.id) || { max: 0 }
// Otherwise scale the model's thinking token limit by the effort ratio
return {
thinkingConfig: {
thinkingBudget: Math.floor(max * effortRatio),
includeThoughts: true
} as ThinkingConfig
}
}
return {}
}
/**
* Generate completions
* @param messages - The messages
* @param assistant - The assistant
* @param mcpTools - The MCP tools
* @param onChunk - The onChunk callback
* @param onFilterMessages - The onFilterMessages callback
*/
public async completions({
messages,
assistant,
mcpTools,
onChunk,
onFilterMessages
}: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
let canGenerateImage = false
if (isGenerateImageModel(model)) {
if (model.id === 'gemini-2.0-flash-exp') {
canGenerateImage = assistant.enableGenerateImage!
} else {
canGenerateImage = true
}
}
if (canGenerateImage) {
await this.generateImageByChat({ messages, assistant, onChunk })
return
}
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
const userMessages = filterUserRoleStartMessages(
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
)
onFilterMessages(userMessages)
const userLastMessage = userMessages.pop()
const history: Content[] = []
for (const message of userMessages) {
history.push(await this.getMessageContents(message))
}
let systemInstruction = assistant.prompt
const { tools } = this.setupToolsConfig<Tool>({
mcpTools,
model,
enableToolUse
})
if (this.useSystemPromptForTools) {
systemInstruction = buildSystemPrompt(assistant.prompt || '', mcpTools)
}
const toolResponses: MCPToolResponse[] = []
if (assistant.enableWebSearch && isWebSearchModel(model)) {
tools.push({
// @ts-ignore googleSearch is not a valid tool for Gemini
googleSearch: {}
})
}
const generateContentConfig: GenerateContentConfig = {
safetySettings: this.getSafetySettings(),
// Gemma models don't accept a system instruction; the prompt is injected into the first user turn instead
systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
temperature: assistant?.settings?.temperature,
topP: assistant?.settings?.topP,
maxOutputTokens: maxTokens,
tools: tools,
...this.getBudgetToken(assistant, model),
...this.getCustomParameters(assistant)
}
const messageContents: Content = await this.getMessageContents(userLastMessage!)
const chat = this.sdk.chats.create({
model: model.id,
config: generateContentConfig,
history: history
})
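// Gemma models have no separate system role, so the system prompt is folded into the first user turn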
if (isGemmaModel(model) && assistant.prompt) {
const isFirstMessage = history.length === 0
if (isFirstMessage && messageContents) {
const systemMessage = [
{
text:
'<start_of_turn>user\n' +
systemInstruction +
'<end_of_turn>\n' +
'<start_of_turn>user\n' +
(messageContents?.parts?.[0] as Part).text +
'<end_of_turn>'
}
] as Part[]
if (messageContents && messageContents.parts) {
messageContents.parts[0] = systemMessage[0]
}
}
}
const finalUsage: Usage = {
completion_tokens: 0,
prompt_tokens: 0,
total_tokens: 0
}
const finalMetrics: Metrics = {
completion_tokens: 0,
time_completion_millsec: 0,
time_first_token_millsec: 0
}
const { cleanup, abortController } = this.createAbortController(userLastMessage?.id, true)
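// Send tool results back to the model in a new chat sharing the same history, then recurse into processStream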
const processToolResults = async (toolResults: Awaited<ReturnType<typeof parseAndCallTools>>, idx: number) => {
if (toolResults.length === 0) return
const newChat = this.sdk.chats.create({
model: model.id,
config: generateContentConfig,
history: history as Content[]
})
const newStream = await newChat.sendMessageStream({
message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion,
config: {
...generateContentConfig,
abortSignal: abortController.signal
}
})
await processStream(newStream, idx + 1)
}
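// Map native Gemini function calls onto MCP tool calls and execute them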
const processToolCalls = async (toolCalls: FunctionCall[]) => {
const mcpToolResponses: ToolCallResponse[] = toolCalls
.map((toolCall) => {
const mcpTool = geminiFunctionCallToMcpTool(mcpTools, toolCall)
if (!mcpTool) return undefined
const parsedArgs = (() => {
try {
return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
} catch {
return toolCall.args
}
})()
return {
id: toolCall.id || nanoid(),
toolCallId: toolCall.id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
} as ToolCallResponse
})
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
return await parseAndCallTools(
mcpToolResponses,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
)
}
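// Parse tool calls embedded in plain response text (used when tools are passed via the system prompt) and execute them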
const processToolUses = async (content: string) => {
return await parseAndCallTools(
content,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
)
}
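// Handles both a complete GenerateContentResponse and a streaming async generator:
// emits text, web search and usage chunks, collects function calls, and recurses
// through processToolResults until no further tool calls are produced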
const processStream = async (
stream: AsyncGenerator<GenerateContentResponse> | GenerateContentResponse,
idx: number
) => {
history.push(messageContents)
let functionCalls: FunctionCall[] = []
let time_first_token_millsec = 0
const start_time_millsec = new Date().getTime()
if (stream instanceof GenerateContentResponse) {
let content = ''
const time_completion_millsec = new Date().getTime() - start_time_millsec
const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
stream.candidates?.forEach((candidate) => {
if (candidate.content) {
history.push(candidate.content)
candidate.content.parts?.forEach((part) => {
if (part.functionCall) {
functionCalls.push(part.functionCall)
}
if (part.text) {
content += part.text
onChunk({ type: ChunkType.TEXT_DELTA, text: part.text })
}
})
}
})
if (content.length) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
}
if (functionCalls.length) {
toolResults.push(...(await processToolCalls(functionCalls)))
}
if (stream.text?.length) {
toolResults.push(...(await processToolUses(stream.text)))
}
if (toolResults.length) {
await processToolResults(toolResults, idx)
}
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
text: stream.text,
usage: {
prompt_tokens: stream.usageMetadata?.promptTokenCount || 0,
thoughts_tokens: stream.usageMetadata?.thoughtsTokenCount || 0,
completion_tokens: stream.usageMetadata?.candidatesTokenCount || 0,
total_tokens: stream.usageMetadata?.totalTokenCount || 0
},
metrics: {
completion_tokens: stream.usageMetadata?.candidatesTokenCount,
time_completion_millsec,
time_first_token_millsec: 0
},
webSearch: {
results: stream.candidates?.[0]?.groundingMetadata,
source: 'gemini'
}
} as Response
} as BlockCompleteChunk)
} else {
let content = ''
for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
if (time_first_token_millsec == 0) {
time_first_token_millsec = new Date().getTime()
}
if (chunk.text !== undefined) {
content += chunk.text
onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
}
if (chunk.candidates?.[0]?.finishReason) {
if (chunk.text) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
}
if (chunk.usageMetadata) {
finalUsage.prompt_tokens += chunk.usageMetadata.promptTokenCount || 0
finalUsage.completion_tokens += chunk.usageMetadata.candidatesTokenCount || 0
finalUsage.total_tokens += chunk.usageMetadata.totalTokenCount || 0
}
if (chunk.candidates?.[0]?.groundingMetadata) {
const groundingMetadata = chunk.candidates?.[0]?.groundingMetadata
onChunk({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: groundingMetadata,
source: WebSearchSource.GEMINI
}
} as LLMWebSearchCompleteChunk)
}
if (chunk.functionCalls) {
chunk.candidates?.forEach((candidate) => {
if (candidate.content) {
history.push(candidate.content)
}
})
functionCalls = functionCalls.concat(chunk.functionCalls)
}
finalMetrics.completion_tokens = finalUsage.completion_tokens
finalMetrics.time_completion_millsec += new Date().getTime() - start_time_millsec
finalMetrics.time_first_token_millsec =
(finalMetrics.time_first_token_millsec || 0) + (time_first_token_millsec - start_time_millsec)
}
}
// --- End Incremental onChunk calls ---
// Call processToolUses AFTER potentially processing text content in this chunk
// This assumes tools might be specified within the text stream
// Note: parseAndCallTools inside should handle its own onChunk for tool responses
let toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
if (functionCalls.length) {
toolResults = await processToolCalls(functionCalls)
}
if (content.length) {
toolResults = toolResults.concat(await processToolUses(content))
}
if (toolResults.length) {
await processToolResults(toolResults, idx)
}
// FIXME: because of the recursive tool-call handling, this BLOCK_COMPLETE chunk is emitted n times
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage: finalUsage,
metrics: finalMetrics
}
})
}
}
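// Non-streaming mode: send the message once and process the complete response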
if (!streamOutput) {
const response = await chat.sendMessage({
message: messageContents as PartUnion,
config: {
...generateContentConfig,
abortSignal: abortController.signal
}
})
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
return await processStream(response, 0).finally(cleanup)
}
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
const userMessagesStream = await chat.sendMessageStream({
message: messageContents as PartUnion,
config: {
...generateContentConfig,
abortSignal: abortController.signal
}
})
await processStream(userMessagesStream, 0).finally(cleanup)
}
/**
* Translate a message
* @param content
* @param assistant - The assistant
* @param onResponse - The onResponse callback
* @returns The translated message
*/
public async translate(
content: string,
assistant: Assistant,
onResponse?: (text: string, isComplete: boolean) => void
) {
const defaultModel = getDefaultModel()
const { maxTokens } = getAssistantSettings(assistant)
const model = assistant.model || defaultModel
const _content =
isGemmaModel(model) && assistant.prompt
? `<start_of_turn>user\n${assistant.prompt}<end_of_turn>\n<start_of_turn>user\n${content}<end_of_turn>`
: content
if (!onResponse) {
const response = await this.sdk.models.generateContent({
model: model.id,
config: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature,
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
},
contents: [
{
role: 'user',
parts: [{ text: _content }]
}
]
})
return response.text || ''
}
const response = await this.sdk.models.generateContentStream({
model: model.id,
config: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature,
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
let text = ''
for await (const chunk of response) {
text += chunk.text
onResponse?.(text, false)
}
onResponse?.(text, true)
return text
}
/**
* Summarize a message
* @param messages - The messages
* @param assistant - The assistant
* @returns The summary
*/
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
const model = getTopNamingModel() || assistant.model || getDefaultModel()
const userMessages = takeRight(messages, 5)
.filter((message) => !message.isPreset)
.map((message) => ({
role: message.role,
// Get content using helper
content: getMainTextContent(message)
}))
const userMessageContent = userMessages.reduce((prev, curr) => {
const content = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}`
return prev + (prev ? '\n' : '') + content
}, '')
const systemMessage = {
role: 'system',
content: (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
}
const userMessage = {
role: 'user',
content: userMessageContent
}
const content = isGemmaModel(model)
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
: userMessage.content
const response = await this.sdk.models.generateContent({
model: model.id,
config: {
systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
return removeSpecialCharactersForTopicName(response.text || '')
}
/**
* Generate text
* @param prompt - The prompt
* @param content - The content
* @returns The generated text
*/
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
const model = getDefaultModel()
const MessageContent = isGemmaModel(model)
? `<start_of_turn>user\n${prompt}<end_of_turn>\n<start_of_turn>user\n${content}<end_of_turn>`
: content
const response = await this.sdk.models.generateContent({
model: model.id,
config: {
systemInstruction: isGemmaModel(model) ? undefined : prompt
},
contents: [
{
role: 'user',
parts: [{ text: MessageContent }]
}
]
})
return response.text || ''
}
/**
* Generate suggestions
* @returns The suggestions
*/
public async suggestions(): Promise<Suggestion[]> {
return []
}
/**
* Summarize a message for search
* @param messages - The messages
* @param assistant - The assistant
* @returns The summary
*/
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string> {
const model = assistant.model || getDefaultModel()
const systemMessage = {
role: 'system',
content: assistant.prompt
}
// Get content using helper
const userMessageContent = messages.map(getMainTextContent).join('\n')
const content = isGemmaModel(model)
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessageContent}<end_of_turn>`
: userMessageContent
const lastUserMessage = messages[messages.length - 1]
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
const { signal } = abortController
const response = await this.sdk.models
.generateContent({
model: model.id,
config: {
systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content,
temperature: assistant?.settings?.temperature,
httpOptions: {
timeout: 20 * 1000
},
abortSignal: signal
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
.finally(cleanup)
return response.text || ''
}
/**
* Generate an image
* @returns The generated image
*/
public async generateImage(): Promise<string[]> {
return []
}
/**
* Process a Gemini image generation response
* @param chunk
* @param onChunk - Callback for handling generated chunks
*/
private processGeminiImageResponse(
chunk: GenerateContentResponse,
onChunk: (chunk: Chunk) => void
): { type: 'base64'; images: string[] } | undefined {
const parts = chunk.candidates?.[0]?.content?.parts
if (!parts) {
return
}
// Extract the image data
const images = parts
.filter((part: Part) => part.inlineData)
.map((part: Part) => {
if (!part.inlineData) {
return null
}
// TODO: the placement of this onChunk call needs to be reconsidered
onChunk({
type: ChunkType.IMAGE_CREATED
})
const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,`
return part.inlineData.data?.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data
})
return {
type: 'base64',
images: images.filter((image) => image !== null)
}
}
/**
* Check if the model is valid
* @param model - The model
* @param stream - Whether to use streaming interface
* @returns The validity of the model
*/
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
if (!model) {
return { valid: false, error: new Error('No model found') }
}
try {
if (!stream) {
const result = await this.sdk.models.generateContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
config: {
maxOutputTokens: 100
}
})
if (isEmpty(result.text)) {
throw new Error('Empty response')
}
} else {
const response = await this.sdk.models.generateContentStream({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
config: {
maxOutputTokens: 100
}
})
// Consume the stream until any text content is received
let hasContent = false
for await (const chunk of response) {
if (chunk.text && chunk.text.length > 0) {
hasContent = true
break
}
}
if (!hasContent) {
throw new Error('Empty streaming response')
}
}
return { valid: true, error: null }
} catch (error: any) {
return {
valid: false,
error
}
}
}
/**
* Get the models
* @returns The models
*/
public async models(): Promise<OpenAI.Models.Model[]> {
try {
const api = this.provider.apiHost + '/v1beta/models'
const { data } = await axios.get(api, { params: { key: this.apiKey } })
return data.models.map(
(m) =>
({
id: m.name.replace('models/', ''),
name: m.displayName,
description: m.description,
object: 'model',
created: Date.now(),
owned_by: 'gemini'
}) as OpenAI.Models.Model
)
} catch (error) {
return []
}
}
/**
* Get the embedding dimensions
* @param model - The model
* @returns The embedding dimensions
*/
public async getEmbeddingDimensions(model: Model): Promise<number> {
const data = await this.sdk.models.embedContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
}
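/**
* Generate images within a chat conversation using an image-capable Gemini model
* @param messages - The messages
* @param assistant - The assistant
* @param onChunk - The onChunk callback
*/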
public async generateImageByChat({ messages, assistant, onChunk }): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const userMessages = filterUserRoleStartMessages(
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
)
const userLastMessage = userMessages.pop()
const { abortController } = this.createAbortController(userLastMessage?.id, true)
const { signal } = abortController
const generateContentConfig: GenerateContentConfig = {
responseModalities: [Modality.TEXT, Modality.IMAGE],
responseMimeType: 'text/plain',
safetySettings: this.getSafetySettings(),
temperature: assistant?.settings?.temperature,
topP: assistant?.settings?.topP,
maxOutputTokens: maxTokens,
abortSignal: signal,
...this.getCustomParameters(assistant)
}
const history: Content[] = []
try {
for (const message of userMessages) {
history.push(await this.getImageFileContents(message))
}
let time_first_token_millsec = 0
const start_time_millsec = new Date().getTime()
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
const chat = this.sdk.chats.create({
model: model.id,
config: generateContentConfig,
history: history
})
let content = ''
const finalUsage: Usage = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
}
const userMessage: Content = await this.getImageFileContents(userLastMessage!)
const response = await chat.sendMessageStream({
message: userMessage.parts!,
config: {
...generateContentConfig,
abortSignal: signal
}
})
for await (const chunk of response as AsyncGenerator<GenerateContentResponse>) {
if (time_first_token_millsec == 0) {
time_first_token_millsec = new Date().getTime()
}
if (chunk.text !== undefined) {
content += chunk.text
onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
}
const generateImage = this.processGeminiImageResponse(chunk, onChunk)
if (generateImage?.images?.length) {
onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
}
if (chunk.candidates?.[0]?.finishReason) {
if (chunk.text) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
}
if (chunk.usageMetadata) {
finalUsage.prompt_tokens = chunk.usageMetadata.promptTokenCount || 0
finalUsage.completion_tokens = chunk.usageMetadata.candidatesTokenCount || 0
finalUsage.total_tokens = chunk.usageMetadata.totalTokenCount || 0
}
}
}
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage: finalUsage,
metrics: {
completion_tokens: finalUsage.completion_tokens,
time_completion_millsec: new Date().getTime() - start_time_millsec,
time_first_token_millsec: time_first_token_millsec - start_time_millsec
}
}
})
} catch (error) {
console.error('[generateImageByChat] error', error)
onChunk({
type: ChunkType.ERROR,
error
})
}
}
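/**
* Convert MCP tools into Gemini tool declarations
* @param mcpTools - The MCP tools
* @returns The Gemini tools
*/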
public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
return mcpToolsToGeminiTools(mcpTools) as T[]
}
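/**
* Convert an MCP tool call response into a Gemini message
* Tool-use responses go through the shared helper; tool-call responses become functionResponse parts
* @param mcpToolResponse - The MCP tool response
* @param resp - The tool call result
* @param model - The model
*/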
public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
parts: [
{
functionResponse: {
id: mcpToolResponse.toolCallId,
name: mcpToolResponse.tool.id,
response: {
output: !resp.isError ? resp.content : undefined,
error: resp.isError ? resp.content : undefined
}
}
}
]
} satisfies Content
}
return
}
}