Mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2025-12-23 18:10:26 +08:00)

refactor(GeminiProvider): implement image generation handling in chat responses

parent 66939a5302
commit 476855c9c7
@@ -127,7 +127,6 @@ export default class GeminiProvider extends BaseProvider {
    * @returns The message contents
    */
   private async getMessageContents(message: Message): Promise<Content> {
-    console.log('getMessageContents', message)
     const role = message.role === 'user' ? 'user' : 'model'
     const parts: Part[] = [{ text: await this.getMessageContent(message) }]
     // Add any generated images from previous responses
@@ -197,6 +196,50 @@ export default class GeminiProvider extends BaseProvider {
     }
   }

+  private async getImageFileContents(message: Message): Promise<Content> {
+    const role = message.role === 'user' ? 'user' : 'model'
+    const content = getMainTextContent(message)
+    const parts: Part[] = [{ text: content }]
+    const imageBlocks = findImageBlocks(message)
+    for (const imageBlock of imageBlocks) {
+      if (
+        imageBlock.metadata?.generateImageResponse?.images &&
+        imageBlock.metadata.generateImageResponse.images.length > 0
+      ) {
+        for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
+          if (imageUrl && imageUrl.startsWith('data:')) {
+            // Extract base64 data and mime type from the data URL
+            const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
+            if (matches && matches.length === 3) {
+              const mimeType = matches[1]
+              const base64Data = matches[2]
+              parts.push({
+                inlineData: {
+                  data: base64Data,
+                  mimeType: mimeType
+                } as Part['inlineData']
+              })
+            }
+          }
+        }
+      }
+      const file = imageBlock.file
+      if (file) {
+        const base64Data = await window.api.file.base64Image(file.id + file.ext)
+        parts.push({
+          inlineData: {
+            data: base64Data.base64,
+            mimeType: base64Data.mime
+          } as Part['inlineData']
+        })
+      }
+    }
+    return {
+      role,
+      parts: parts
+    }
+  }
+
   /**
    * Get the safety settings
    * @returns The safety settings
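For reference, a self-contained sketch of the data-URL branch above: splitting a base64 data URL into the mime type and payload that feed a Gemini inlineData part. The regex is the one from the diff; the helper name and sample URL are illustrative only.

  // Hypothetical helper mirroring the data-URL branch of getImageFileContents.
  type InlineDataPart = { inlineData: { data: string; mimeType: string } }

  function dataUrlToInlinePart(dataUrl: string): InlineDataPart | null {
    // "data:<mime>;base64,<payload>" -> [full, mime, payload]
    const matches = dataUrl.match(/^data:(.+);base64,(.*)$/)
    if (!matches || matches.length !== 3) return null
    return { inlineData: { data: matches[2], mimeType: matches[1] } }
  }

  // 1x1 transparent PNG, base64-encoded (illustrative input).
  const sample =
    'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=='
  console.log(dataUrlToInlinePart(sample)?.inlineData.mimeType) // -> 'image/png'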
@@ -284,6 +327,18 @@ export default class GeminiProvider extends BaseProvider {
   }: CompletionsParams): Promise<void> {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
+    let canGenerateImage = false
+    if (isGenerateImageModel(model)) {
+      if (model.id === 'gemini-2.0-flash-exp') {
+        canGenerateImage = assistant.enableGenerateImage!
+      } else {
+        canGenerateImage = true
+      }
+    }
+    if (canGenerateImage) {
+      await this.generateImageByChat({ messages, assistant, onChunk })
+      return
+    }
     const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)

     const userMessages = filterUserRoleStartMessages(
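The effect of the new early return is that image-capable requests never reach the text-completion path: completions delegates to generateImageByChat and stops. A condensed sketch of the gate, where the model/assistant shapes and the isGenerateImageModel signature are assumptions for illustration:

  interface ModelLike { id: string }
  interface AssistantLike { enableGenerateImage?: boolean }

  function shouldGenerateImage(
    model: ModelLike,
    assistant: AssistantLike,
    isGenerateImageModel: (m: ModelLike) => boolean
  ): boolean {
    if (!isGenerateImageModel(model)) return false
    // gemini-2.0-flash-exp serves both chat and image generation, so it
    // additionally requires the per-assistant toggle; dedicated image
    // models are always routed to image generation.
    return model.id === 'gemini-2.0-flash-exp' ? assistant.enableGenerateImage === true : true
  }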
@@ -320,21 +375,10 @@ export default class GeminiProvider extends BaseProvider {
       })
     }

-    let canGenerateImage = false
-    if (isGenerateImageModel(model)) {
-      if (model.id === 'gemini-2.0-flash-exp') {
-        canGenerateImage = assistant.enableGenerateImage!
-      } else {
-        canGenerateImage = true
-      }
-    }
-
     const generateContentConfig: GenerateContentConfig = {
-      responseModalities: canGenerateImage ? [Modality.TEXT, Modality.IMAGE] : undefined,
-      responseMimeType: canGenerateImage ? 'text/plain' : undefined,
       safetySettings: this.getSafetySettings(),
-      // generate image don't need system instruction
-      systemInstruction: isGemmaModel(model) || canGenerateImage ? undefined : systemInstruction,
+      systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
       temperature: assistant?.settings?.temperature,
       topP: assistant?.settings?.topP,
       maxOutputTokens: maxTokens,
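With this branch removed, the text path's GenerateContentConfig no longer carries any image fields; those now appear only in generateImageByChat (added below), which sets them unconditionally. A minimal sketch of the two resulting config shapes, using the field names from the diff with illustrative values:

  import { GenerateContentConfig, Modality } from '@google/genai'

  // Illustrative values; in the provider these come from assistant settings.
  const temperature = 0.7
  const maxOutputTokens = 1024

  // Text completions: no response modalities requested.
  const textConfig: GenerateContentConfig = {
    temperature,
    maxOutputTokens
  }

  // Image generation: text + image modalities, plain-text mime type,
  // and deliberately no systemInstruction.
  const imageConfig: GenerateContentConfig = {
    responseModalities: [Modality.TEXT, Modality.IMAGE],
    responseMimeType: 'text/plain',
    temperature,
    maxOutputTokens
  }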
@@ -528,11 +572,6 @@ export default class GeminiProvider extends BaseProvider {
         onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
       }

-      const generateImage = this.processGeminiImageResponse(chunk, onChunk)
-      if (generateImage?.images?.length) {
-        onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
-      }
-
       if (chunk.candidates?.[0]?.finishReason) {
         if (chunk.text) {
           onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
@@ -950,8 +989,97 @@ export default class GeminiProvider extends BaseProvider {
     return data.embeddings?.[0]?.values?.length || 0
   }

-  public generateImageByChat(): Promise<void> {
-    throw new Error('Method not implemented.')
+  public async generateImageByChat({ messages, assistant, onChunk }): Promise<void> {
+    const defaultModel = getDefaultModel()
+    const model = assistant.model || defaultModel
+    const { contextCount, maxTokens } = getAssistantSettings(assistant)
+    const userMessages = filterUserRoleStartMessages(
+      filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
+    )
+
+    const userLastMessage = userMessages.pop()
+    const { abortController } = this.createAbortController(userLastMessage?.id, true)
+    const { signal } = abortController
+    const generateContentConfig: GenerateContentConfig = {
+      responseModalities: [Modality.TEXT, Modality.IMAGE],
+      responseMimeType: 'text/plain',
+      safetySettings: this.getSafetySettings(),
+      temperature: assistant?.settings?.temperature,
+      topP: assistant?.settings?.top_p,
+      maxOutputTokens: maxTokens,
+      abortSignal: signal,
+      ...this.getCustomParameters(assistant)
+    }
+    const history: Content[] = []
+    try {
+      for (const message of userMessages) {
+        history.push(await this.getImageFileContents(message))
+      }
+
+      let time_first_token_millsec = 0
+      const start_time_millsec = new Date().getTime()
+      onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
+      const chat = this.sdk.chats.create({
+        model: model.id,
+        config: generateContentConfig,
+        history: history
+      })
+      let content = ''
+      const finalUsage: Usage = {
+        prompt_tokens: 0,
+        completion_tokens: 0,
+        total_tokens: 0
+      }
+      const userMessage: Content = await this.getImageFileContents(userLastMessage!)
+      const response = await chat.sendMessageStream({
+        message: userMessage.parts!,
+        config: {
+          ...generateContentConfig,
+          abortSignal: signal
+        }
+      })
+      for await (const chunk of response as AsyncGenerator<GenerateContentResponse>) {
+        if (time_first_token_millsec == 0) {
+          time_first_token_millsec = new Date().getTime()
+        }
+
+        if (chunk.text !== undefined) {
+          content += chunk.text
+          onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
+        }
+        const generateImage = this.processGeminiImageResponse(chunk, onChunk)
+        if (generateImage?.images?.length) {
+          onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
+        }
+        if (chunk.candidates?.[0]?.finishReason) {
+          if (chunk.text) {
+            onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
+          }
+          if (chunk.usageMetadata) {
+            finalUsage.prompt_tokens = chunk.usageMetadata.promptTokenCount || 0
+            finalUsage.completion_tokens = chunk.usageMetadata.candidatesTokenCount || 0
+            finalUsage.total_tokens = chunk.usageMetadata.totalTokenCount || 0
+          }
+        }
+      }
+      onChunk({
+        type: ChunkType.BLOCK_COMPLETE,
+        response: {
+          usage: finalUsage,
+          metrics: {
+            completion_tokens: finalUsage.completion_tokens,
+            time_completion_millsec: new Date().getTime() - start_time_millsec,
+            time_first_token_millsec: time_first_token_millsec - start_time_millsec
+          }
+        }
+      })
+    } catch (error) {
+      console.error('[generateImageByChat] error', error)
+      onChunk({
+        type: ChunkType.ERROR,
+        error
+      })
+    }
+  }

   public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
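A hypothetical caller's view of the new method, for orientation: the fixtures below are placeholders, but the chunk kinds and payload fields match what generateImageByChat emits above.

  // Placeholder fixtures; in the app these come from the renderer's state.
  declare const provider: GeminiProvider
  declare const messages: Message[]
  declare const assistant: Assistant // must satisfy the image gate above

  const images: string[] = []
  let text = ''

  await provider.generateImageByChat({
    messages,
    assistant,
    onChunk: (chunk: any) => {
      switch (chunk.type) {
        case ChunkType.TEXT_DELTA:
          text += chunk.text // interleaved text, streamed as it arrives
          break
        case ChunkType.IMAGE_COMPLETE:
          images.push(...(chunk.image?.images ?? [])) // base64 data URLs
          break
        case ChunkType.BLOCK_COMPLETE:
          console.log('usage', chunk.response?.usage) // token counts + timing metrics
          break
        case ChunkType.ERROR:
          console.error(chunk.error)
          break
      }
    }
  })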
@@ -613,7 +613,6 @@ const fetchAndProcessAssistantResponseImpl = async (
       const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
       response.usage = usage
     }
-    console.log('response', response)
   }
   if (response && response.metrics) {
     if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {