mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-23 18:10:26 +08:00
refactor(GeminiProvider): implement image generation handling in chat responses
This commit is contained in:
parent
66939a5302
commit
476855c9c7
@ -127,7 +127,6 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
* @returns The message contents
|
* @returns The message contents
|
||||||
*/
|
*/
|
||||||
private async getMessageContents(message: Message): Promise<Content> {
|
private async getMessageContents(message: Message): Promise<Content> {
|
||||||
console.log('getMessageContents', message)
|
|
||||||
const role = message.role === 'user' ? 'user' : 'model'
|
const role = message.role === 'user' ? 'user' : 'model'
|
||||||
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
|
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
|
||||||
// Add any generated images from previous responses
|
// Add any generated images from previous responses
|
||||||
@ -197,6 +196,50 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async getImageFileContents(message: Message): Promise<Content> {
|
||||||
|
const role = message.role === 'user' ? 'user' : 'model'
|
||||||
|
const content = getMainTextContent(message)
|
||||||
|
const parts: Part[] = [{ text: content }]
|
||||||
|
const imageBlocks = findImageBlocks(message)
|
||||||
|
for (const imageBlock of imageBlocks) {
|
||||||
|
if (
|
||||||
|
imageBlock.metadata?.generateImageResponse?.images &&
|
||||||
|
imageBlock.metadata.generateImageResponse.images.length > 0
|
||||||
|
) {
|
||||||
|
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
|
||||||
|
if (imageUrl && imageUrl.startsWith('data:')) {
|
||||||
|
// Extract base64 data and mime type from the data URL
|
||||||
|
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
|
||||||
|
if (matches && matches.length === 3) {
|
||||||
|
const mimeType = matches[1]
|
||||||
|
const base64Data = matches[2]
|
||||||
|
parts.push({
|
||||||
|
inlineData: {
|
||||||
|
data: base64Data,
|
||||||
|
mimeType: mimeType
|
||||||
|
} as Part['inlineData']
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const file = imageBlock.file
|
||||||
|
if (file) {
|
||||||
|
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||||
|
parts.push({
|
||||||
|
inlineData: {
|
||||||
|
data: base64Data.base64,
|
||||||
|
mimeType: base64Data.mime
|
||||||
|
} as Part['inlineData']
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
role,
|
||||||
|
parts: parts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the safety settings
|
* Get the safety settings
|
||||||
* @returns The safety settings
|
* @returns The safety settings
|
||||||
@ -284,6 +327,18 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
}: CompletionsParams): Promise<void> {
|
}: CompletionsParams): Promise<void> {
|
||||||
const defaultModel = getDefaultModel()
|
const defaultModel = getDefaultModel()
|
||||||
const model = assistant.model || defaultModel
|
const model = assistant.model || defaultModel
|
||||||
|
let canGenerateImage = false
|
||||||
|
if (isGenerateImageModel(model)) {
|
||||||
|
if (model.id === 'gemini-2.0-flash-exp') {
|
||||||
|
canGenerateImage = assistant.enableGenerateImage!
|
||||||
|
} else {
|
||||||
|
canGenerateImage = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (canGenerateImage) {
|
||||||
|
await this.generateImageByChat({ messages, assistant, onChunk })
|
||||||
|
return
|
||||||
|
}
|
||||||
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
|
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
|
||||||
|
|
||||||
const userMessages = filterUserRoleStartMessages(
|
const userMessages = filterUserRoleStartMessages(
|
||||||
@ -320,21 +375,10 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
let canGenerateImage = false
|
|
||||||
if (isGenerateImageModel(model)) {
|
|
||||||
if (model.id === 'gemini-2.0-flash-exp') {
|
|
||||||
canGenerateImage = assistant.enableGenerateImage!
|
|
||||||
} else {
|
|
||||||
canGenerateImage = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const generateContentConfig: GenerateContentConfig = {
|
const generateContentConfig: GenerateContentConfig = {
|
||||||
responseModalities: canGenerateImage ? [Modality.TEXT, Modality.IMAGE] : undefined,
|
|
||||||
responseMimeType: canGenerateImage ? 'text/plain' : undefined,
|
|
||||||
safetySettings: this.getSafetySettings(),
|
safetySettings: this.getSafetySettings(),
|
||||||
// generate image don't need system instruction
|
// generate image don't need system instruction
|
||||||
systemInstruction: isGemmaModel(model) || canGenerateImage ? undefined : systemInstruction,
|
systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
|
||||||
temperature: assistant?.settings?.temperature,
|
temperature: assistant?.settings?.temperature,
|
||||||
topP: assistant?.settings?.topP,
|
topP: assistant?.settings?.topP,
|
||||||
maxOutputTokens: maxTokens,
|
maxOutputTokens: maxTokens,
|
||||||
@ -528,11 +572,6 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
|
onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
|
||||||
}
|
}
|
||||||
|
|
||||||
const generateImage = this.processGeminiImageResponse(chunk, onChunk)
|
|
||||||
if (generateImage?.images?.length) {
|
|
||||||
onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
|
|
||||||
}
|
|
||||||
|
|
||||||
if (chunk.candidates?.[0]?.finishReason) {
|
if (chunk.candidates?.[0]?.finishReason) {
|
||||||
if (chunk.text) {
|
if (chunk.text) {
|
||||||
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
|
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
|
||||||
@ -950,8 +989,97 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
return data.embeddings?.[0]?.values?.length || 0
|
return data.embeddings?.[0]?.values?.length || 0
|
||||||
}
|
}
|
||||||
|
|
||||||
public generateImageByChat(): Promise<void> {
|
public async generateImageByChat({ messages, assistant, onChunk }): Promise<void> {
|
||||||
throw new Error('Method not implemented.')
|
const defaultModel = getDefaultModel()
|
||||||
|
const model = assistant.model || defaultModel
|
||||||
|
const { contextCount, maxTokens } = getAssistantSettings(assistant)
|
||||||
|
const userMessages = filterUserRoleStartMessages(
|
||||||
|
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
||||||
|
)
|
||||||
|
|
||||||
|
const userLastMessage = userMessages.pop()
|
||||||
|
const { abortController } = this.createAbortController(userLastMessage?.id, true)
|
||||||
|
const { signal } = abortController
|
||||||
|
const generateContentConfig: GenerateContentConfig = {
|
||||||
|
responseModalities: [Modality.TEXT, Modality.IMAGE],
|
||||||
|
responseMimeType: 'text/plain',
|
||||||
|
safetySettings: this.getSafetySettings(),
|
||||||
|
temperature: assistant?.settings?.temperature,
|
||||||
|
topP: assistant?.settings?.top_p,
|
||||||
|
maxOutputTokens: maxTokens,
|
||||||
|
abortSignal: signal,
|
||||||
|
...this.getCustomParameters(assistant)
|
||||||
|
}
|
||||||
|
const history: Content[] = []
|
||||||
|
try {
|
||||||
|
for (const message of userMessages) {
|
||||||
|
history.push(await this.getImageFileContents(message))
|
||||||
|
}
|
||||||
|
|
||||||
|
let time_first_token_millsec = 0
|
||||||
|
const start_time_millsec = new Date().getTime()
|
||||||
|
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||||
|
const chat = this.sdk.chats.create({
|
||||||
|
model: model.id,
|
||||||
|
config: generateContentConfig,
|
||||||
|
history: history
|
||||||
|
})
|
||||||
|
let content = ''
|
||||||
|
const finalUsage: Usage = {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0
|
||||||
|
}
|
||||||
|
const userMessage: Content = await this.getImageFileContents(userLastMessage!)
|
||||||
|
const response = await chat.sendMessageStream({
|
||||||
|
message: userMessage.parts!,
|
||||||
|
config: {
|
||||||
|
...generateContentConfig,
|
||||||
|
abortSignal: signal
|
||||||
|
}
|
||||||
|
})
|
||||||
|
for await (const chunk of response as AsyncGenerator<GenerateContentResponse>) {
|
||||||
|
if (time_first_token_millsec == 0) {
|
||||||
|
time_first_token_millsec = new Date().getTime()
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chunk.text !== undefined) {
|
||||||
|
content += chunk.text
|
||||||
|
onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text })
|
||||||
|
}
|
||||||
|
const generateImage = this.processGeminiImageResponse(chunk, onChunk)
|
||||||
|
if (generateImage?.images?.length) {
|
||||||
|
onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage })
|
||||||
|
}
|
||||||
|
if (chunk.candidates?.[0]?.finishReason) {
|
||||||
|
if (chunk.text) {
|
||||||
|
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
|
||||||
|
}
|
||||||
|
if (chunk.usageMetadata) {
|
||||||
|
finalUsage.prompt_tokens = chunk.usageMetadata.promptTokenCount || 0
|
||||||
|
finalUsage.completion_tokens = chunk.usageMetadata.candidatesTokenCount || 0
|
||||||
|
finalUsage.total_tokens = chunk.usageMetadata.totalTokenCount || 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.BLOCK_COMPLETE,
|
||||||
|
response: {
|
||||||
|
usage: finalUsage,
|
||||||
|
metrics: {
|
||||||
|
completion_tokens: finalUsage.completion_tokens,
|
||||||
|
time_completion_millsec: new Date().getTime() - start_time_millsec,
|
||||||
|
time_first_token_millsec: time_first_token_millsec - start_time_millsec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('[generateImageByChat] error', error)
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.ERROR,
|
||||||
|
error
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
|
public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
|
||||||
|
|||||||
@ -613,7 +613,6 @@ const fetchAndProcessAssistantResponseImpl = async (
|
|||||||
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
|
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
|
||||||
response.usage = usage
|
response.usage = usage
|
||||||
}
|
}
|
||||||
console.log('response', response)
|
|
||||||
}
|
}
|
||||||
if (response && response.metrics) {
|
if (response && response.metrics) {
|
||||||
if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {
|
if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user