refactor(GeminiProvider): implement image generation handling in chat responses

suyao 2025-05-11 18:09:58 +08:00
parent 66939a5302
commit 476855c9c7
2 changed files with 148 additions and 21 deletions
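
In outline, the change gates completions() on the model's image capability: image-capable models are routed into a new generateImageByChat() path before the normal text flow runs, with gemini-2.0-flash-exp additionally requiring an explicit assistant opt-in since it serves both text and image modes. A minimal standalone sketch of that gate, assuming the isGenerateImageModel helper and the model/assistant shapes visible in the diff below:

// Assumed app helper, declared here only so the sketch type-checks.
declare function isGenerateImageModel(model: { id: string }): boolean

function shouldRouteToImageChat(model: { id: string }, assistant: { enableGenerateImage?: boolean }): boolean {
  if (!isGenerateImageModel(model)) return false
  // Hybrid text/image model: only take the image path on explicit opt-in.
  if (model.id === 'gemini-2.0-flash-exp') return Boolean(assistant.enableGenerateImage)
  // Dedicated image-generation models always take the image path.
  return true
}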


@@ -127,7 +127,6 @@ export default class GeminiProvider extends BaseProvider {
* @returns The message contents
*/
private async getMessageContents(message: Message): Promise<Content> {
- console.log('getMessageContents', message)
const role = message.role === 'user' ? 'user' : 'model'
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
// Add any generated images from previous responses
@@ -197,6 +196,50 @@ export default class GeminiProvider extends BaseProvider {
}
}
private async getImageFileContents(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const content = getMainTextContent(message)
const parts: Part[] = [{ text: content }]
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (
imageBlock.metadata?.generateImageResponse?.images &&
imageBlock.metadata.generateImageResponse.images.length > 0
) {
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
if (imageUrl && imageUrl.startsWith('data:')) {
// Extract base64 data and mime type from the data URL
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
const mimeType = matches[1]
const base64Data = matches[2]
parts.push({
inlineData: {
data: base64Data,
mimeType: mimeType
} as Part['inlineData']
})
}
}
}
}
const file = imageBlock.file
if (file) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
}
return {
role,
parts: parts
}
}
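
The data-URL handling above is what lets previously generated images round-trip through the conversation: each data: URL is split into its MIME type and base64 payload so it can be resent to Gemini as an inlineData part. A self-contained sketch of just that conversion, with the return shape mirroring Part['inlineData'] from @google/genai:

// Split a base64 data: URL into the { mimeType, data } pair used by
// Gemini inlineData parts; returns null for anything else.
function dataUrlToInlineData(url: string): { mimeType: string; data: string } | null {
  const matches = url.match(/^data:(.+);base64,(.*)$/)
  if (!matches || matches.length !== 3) return null
  return { mimeType: matches[1], data: matches[2] }
}

// dataUrlToInlineData('data:image/png;base64,iVBORw0KGgo...')
//   -> { mimeType: 'image/png', data: 'iVBORw0KGgo...' }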
/**
* Get the safety settings
* @returns The safety settings
@@ -284,6 +327,18 @@ export default class GeminiProvider extends BaseProvider {
}: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
let canGenerateImage = false
if (isGenerateImageModel(model)) {
if (model.id === 'gemini-2.0-flash-exp') {
canGenerateImage = assistant.enableGenerateImage!
} else {
canGenerateImage = true
}
}
if (canGenerateImage) {
await this.generateImageByChat({ messages, assistant, onChunk })
return
}
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
const userMessages = filterUserRoleStartMessages(
@@ -320,21 +375,10 @@ export default class GeminiProvider extends BaseProvider {
})
}
- let canGenerateImage = false
- if (isGenerateImageModel(model)) {
- if (model.id === 'gemini-2.0-flash-exp') {
- canGenerateImage = assistant.enableGenerateImage!
- } else {
- canGenerateImage = true
- }
- }
const generateContentConfig: GenerateContentConfig = {
responseModalities: canGenerateImage ? [Modality.TEXT, Modality.IMAGE] : undefined,
responseMimeType: canGenerateImage ? 'text/plain' : undefined,
safetySettings: this.getSafetySettings(),
- // generate image don't need system instruction
- systemInstruction: isGemmaModel(model) || canGenerateImage ? undefined : systemInstruction,
+ systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
temperature: assistant?.settings?.temperature,
topP: assistant?.settings?.topP,
maxOutputTokens: maxTokens,
@@ -528,11 +572,6 @@ export default class GeminiProvider extends BaseProvider {
onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
}
- const generateImage = this.processGeminiImageResponse(chunk, onChunk)
- if (generateImage?.images?.length) {
- onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
- }
if (chunk.candidates?.[0]?.finishReason) {
if (chunk.text) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
@@ -950,8 +989,97 @@ export default class GeminiProvider extends BaseProvider {
return data.embeddings?.[0]?.values?.length || 0
}
- public generateImageByChat(): Promise<void> {
- throw new Error('Method not implemented.')
- }
public async generateImageByChat({ messages, assistant, onChunk }: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const userMessages = filterUserRoleStartMessages(
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
)
const userLastMessage = userMessages.pop()
const { abortController } = this.createAbortController(userLastMessage?.id, true)
const { signal } = abortController
const generateContentConfig: GenerateContentConfig = {
responseModalities: [Modality.TEXT, Modality.IMAGE],
responseMimeType: 'text/plain',
safetySettings: this.getSafetySettings(),
temperature: assistant?.settings?.temperature,
topP: assistant?.settings?.topP,
maxOutputTokens: maxTokens,
abortSignal: signal,
...this.getCustomParameters(assistant)
}
const history: Content[] = []
try {
for (const message of userMessages) {
history.push(await this.getImageFileContents(message))
}
let time_first_token_millsec = 0
const start_time_millsec = new Date().getTime()
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
const chat = this.sdk.chats.create({
model: model.id,
config: generateContentConfig,
history: history
})
let content = ''
const finalUsage: Usage = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
}
const userMessage: Content = await this.getImageFileContents(userLastMessage!)
const response = await chat.sendMessageStream({
message: userMessage.parts!,
config: {
...generateContentConfig,
abortSignal: signal
}
})
for await (const chunk of response as AsyncGenerator<GenerateContentResponse>) {
if (time_first_token_millsec === 0) {
time_first_token_millsec = new Date().getTime()
}
if (chunk.text !== undefined) {
content += chunk.text
onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
}
const generateImage = this.processGeminiImageResponse(chunk, onChunk)
if (generateImage?.images?.length) {
onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
}
if (chunk.candidates?.[0]?.finishReason) {
if (chunk.text) {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
}
if (chunk.usageMetadata) {
finalUsage.prompt_tokens = chunk.usageMetadata.promptTokenCount || 0
finalUsage.completion_tokens = chunk.usageMetadata.candidatesTokenCount || 0
finalUsage.total_tokens = chunk.usageMetadata.totalTokenCount || 0
}
}
}
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage: finalUsage,
metrics: {
completion_tokens: finalUsage.completion_tokens,
time_completion_millsec: new Date().getTime() - start_time_millsec,
time_first_token_millsec: time_first_token_millsec - start_time_millsec
}
}
})
} catch (error) {
console.error('[generateImageByChat] error', error)
onChunk({
type: ChunkType.ERROR,
error
})
}
}
public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {


@@ -613,7 +613,6 @@ const fetchAndProcessAssistantResponseImpl = async (
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
response.usage = usage
}
- console.log('response', response)
}
if (response && response.metrics) {
if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {