feat(OpenAIProvider): support image edit (#5469)

* feat(OpenAIProvider): support image edit

* fix: can edit image

* feat: add upload situation

* chore: optimize  abort

* fix: image cannot read

* chore(SvgSpinners180Ring): remove unused React import

* refactor(FileManager): simplify file reading logic by storing intermediate data

---------

Co-authored-by: eeee0717 <chentao020717Work@outlook.com>
This commit is contained in:
SuYao 2025-05-01 12:38:33 +08:00 committed by GitHub
parent 23dd66d4a1
commit b43702f21d
18 changed files with 186 additions and 61 deletions

View File

@ -118,6 +118,7 @@
"@eslint/js": "^9.22.0",
"@google/genai": "^0.10.0",
"@hello-pangea/dnd": "^16.6.0",
"@iconify-json/svg-spinners": "^1.2.2",
"@kangfenmao/keyv-storage": "^0.1.0",
"@modelcontextprotocol/sdk": "^1.10.2",
"@mozilla/readability": "^0.6.0",

View File

@ -109,7 +109,7 @@ export enum IpcChannel {
File_Base64Image = 'file:base64Image',
File_Download = 'file:download',
File_Copy = 'file:copy',
File_BinaryFile = 'file:binaryFile',
File_BinaryImage = 'file:binaryImage',
Fs_Read = 'fs:read',

View File

@ -215,7 +215,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.File_Base64Image, fileManager.base64Image)
ipcMain.handle(IpcChannel.File_Download, fileManager.downloadFile)
ipcMain.handle(IpcChannel.File_Copy, fileManager.copyFile)
ipcMain.handle(IpcChannel.File_BinaryFile, fileManager.binaryFile)
ipcMain.handle(IpcChannel.File_BinaryImage, fileManager.binaryImage)
// fs
ipcMain.handle(IpcChannel.Fs_Read, FileService.readFile)

View File

@ -263,7 +263,7 @@ class FileStorage {
}
}
public binaryFile = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => {
public binaryImage = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => {
const filePath = path.join(this.storageDir, id)
const data = await fs.promises.readFile(filePath)
const mime = `image/${path.extname(filePath).slice(1)}`

View File

@ -66,7 +66,7 @@ const api = {
base64Image: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId),
download: (url: string) => ipcRenderer.invoke(IpcChannel.File_Download, url),
copy: (fileId: string, destPath: string) => ipcRenderer.invoke(IpcChannel.File_Copy, fileId, destPath),
binaryFile: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryFile, fileId)
binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId)
},
fs: {
read: (path: string) => ipcRenderer.invoke(IpcChannel.Fs_Read, path)

View File

@ -0,0 +1,20 @@
import { SVGProps } from 'react'
export function SvgSpinners180Ring(props: SVGProps<SVGSVGElement>) {
return (
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" {...props}>
{/* Icon from SVG Spinners by Utkarsh Verma - https://github.com/n3r4zzurr0/svg-spinners/blob/main/LICENSE */}
<path
fill="currentColor"
d="M12,4a8,8,0,0,1,7.89,6.7A1.53,1.53,0,0,0,21.38,12h0a1.5,1.5,0,0,0,1.48-1.75,11,11,0,0,0-21.72,0A1.5,1.5,0,0,0,2.62,12h0a1.53,1.53,0,0,0,1.49-1.3A8,8,0,0,1,12,4Z">
<animateTransform
attributeName="transform"
dur="0.75s"
repeatCount="indefinite"
type="rotate"
values="0 12 12;360 12 12"></animateTransform>
</path>
</svg>
)
}
export default SvgSpinners180Ring

View File

@ -24,9 +24,9 @@ import { Avatar, Drawer, Tooltip } from 'antd'
import { WebviewTag } from 'electron'
import { useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import BeatLoader from 'react-spinners/BeatLoader'
import styled from 'styled-components'
import SvgSpinners180Ring from '../Icons/SvgSpinners180Ring'
import WebviewContainer from './WebviewContainer'
interface AppExtraInfo {
@ -375,7 +375,7 @@ const MinappPopupContainer: React.FC = () => {
size={80}
style={{ border: '1px solid var(--color-border)', marginTop: -150 }}
/>
<BeatLoader color="var(--color-text-2)" size="10px" style={{ marginTop: 15 }} />
<SvgSpinners180Ring color="var(--color-text-2)" style={{ marginTop: 15 }} />
</EmptyView>
)}
{WebviewContainerGroup}

View File

@ -1,4 +1,4 @@
import { isVisionModel } from '@renderer/config/models'
import { isGenerateImageModel, isVisionModel } from '@renderer/config/models'
import { FileType, Model } from '@renderer/types'
import { documentExts, imageExts, textExts } from '@shared/config/constant'
import { Tooltip } from 'antd'
@ -22,10 +22,19 @@ interface Props {
const AttachmentButton: FC<Props> = ({ ref, model, files, setFiles, ToolbarButton, disabled }) => {
const { t } = useTranslation()
const extensions = useMemo(
() => (isVisionModel(model) ? [...imageExts, ...documentExts, ...textExts] : [...documentExts, ...textExts]),
[model]
)
// const extensions = useMemo(
// () => (isVisionModel(model) ? [...imageExts, ...documentExts, ...textExts] : [...documentExts, ...textExts]),
// [model]
// )
const extensions = useMemo(() => {
if (isVisionModel(model)) {
return [...imageExts, ...documentExts, ...textExts]
} else if (isGenerateImageModel(model)) {
return [...imageExts]
} else {
return [...documentExts, ...textExts]
}
}, [model])
const onSelectFile = useCallback(async () => {
const _files = await window.api.file.select({
@ -54,7 +63,9 @@ const AttachmentButton: FC<Props> = ({ ref, model, files, setFiles, ToolbarButto
return (
<Tooltip
placement="top"
title={isVisionModel(model) ? t('chat.input.upload') : t('chat.input.upload.document')}
title={
isVisionModel(model) || isGenerateImageModel(model) ? t('chat.input.upload') : t('chat.input.upload.document')
}
arrow>
<ToolbarButton type="text" onClick={onSelectFile} disabled={disabled}>
<Paperclip size={18} style={{ color: files.length ? 'var(--color-primary)' : 'var(--color-icon)' }} />

View File

@ -576,7 +576,8 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
event.preventDefault()
if (file.path === '') {
if (file.type.startsWith('image/') && isVisionModel(model)) {
// 图像生成也支持图像编辑
if (file.type.startsWith('image/') && (isVisionModel(model) || isGenerateImageModel(model))) {
const tempFilePath = await window.api.file.create(file.name)
const arrayBuffer = await file.arrayBuffer()
const uint8Array = new Uint8Array(arrayBuffer)

View File

@ -1,3 +1,4 @@
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
import type { ImageMessageBlock } from '@renderer/types/newMessage'
import React from 'react'
@ -8,7 +9,7 @@ interface Props {
}
const ImageBlock: React.FC<Props> = ({ block }) => {
return <MessageImage block={block} />
return block.status === 'success' ? <MessageImage block={block} /> : <SvgSpinners180Ring />
}
export default React.memo(ImageBlock)

View File

@ -1,6 +1,6 @@
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
import { MessageBlockStatus, MessageBlockType, type PlaceholderMessageBlock } from '@renderer/types/newMessage'
import React from 'react'
import { BeatLoader } from 'react-spinners'
import styled from 'styled-components'
interface PlaceholderBlockProps {
@ -10,7 +10,7 @@ const PlaceholderBlock: React.FC<PlaceholderBlockProps> = ({ block }) => {
if (block.status === MessageBlockStatus.PROCESSING && block.type === MessageBlockType.UNKNOWN) {
return (
<MessageContentLoading>
<BeatLoader size={8} />
<SvgSpinners180Ring />
</MessageContentLoading>
)
}

View File

@ -90,7 +90,9 @@ const MessageImage: FC<Props> = ({ block }) => {
const images = block.metadata?.generateImageResponse?.images?.length
? block.metadata?.generateImageResponse?.images
: // TODO 加file是否合适
[`file://${block?.file?.path}`]
block?.file?.path
? [`file://${block?.file?.path}`]
: []
return (
<Container style={{ marginBottom: 8 }}>
{images.map((image, index) => (

View File

@ -1,9 +1,9 @@
import { TranslationOutlined } from '@ant-design/icons'
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
import type { TranslationMessageBlock } from '@renderer/types/newMessage'
import { Divider } from 'antd'
import { FC, Fragment } from 'react'
import { useTranslation } from 'react-i18next'
import BeatLoader from 'react-spinners/BeatLoader'
import Markdown from '../Markdown/Markdown'
@ -24,7 +24,7 @@ const MessageTranslate: FC<Props> = ({ block }) => {
<TranslationOutlined />
</Divider>
{block.content === t('translate.processing') ? (
<BeatLoader color="var(--color-text-2)" size="10" style={{ marginBottom: 15 }} />
<SvgSpinners180Ring color="var(--color-text-2)" style={{ marginBottom: 15 }} />
) : (
<Markdown block={block} />
)}

View File

@ -1,3 +1,4 @@
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
import Scrollbar from '@renderer/components/Scrollbar'
import { LOAD_MORE_COUNT } from '@renderer/config/constant'
import { useAssistant } from '@renderer/hooks/useAssistant'
@ -24,7 +25,6 @@ import { last } from 'lodash'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import InfiniteScroll from 'react-infinite-scroll-component'
import BeatLoader from 'react-spinners/BeatLoader'
import styled from 'styled-components'
import ChatNavigation from './ChatNavigation'
@ -238,7 +238,7 @@ const Messages: React.FC<MessagesProps> = ({ assistant, topic, setActiveTopic })
style={{ overflow: 'visible' }}>
<ScrollContainer>
<LoaderContainer $loading={isLoadingMore}>
<BeatLoader size={8} color="var(--color-text-2)" />
<SvgSpinners180Ring color="var(--color-text-2)" />
</LoaderContainer>
{groupedMessages.map(([key, groupMessages]) => (
<MessageGroup

View File

@ -1,3 +1,4 @@
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
import { fetchSuggestions } from '@renderer/services/ApiService'
import { getUserMessage } from '@renderer/services/MessagesService'
import { useAppDispatch } from '@renderer/store'
@ -6,7 +7,6 @@ import { Assistant, Suggestion } from '@renderer/types'
import type { Message } from '@renderer/types/newMessage'
import { last } from 'lodash'
import { FC, memo, useEffect, useState } from 'react'
import BeatLoader from 'react-spinners/BeatLoader'
import styled from 'styled-components'
interface Props {
@ -66,7 +66,7 @@ const Suggestions: FC<Props> = ({ assistant, messages }) => {
if (loadingSuggestions) {
return (
<Container>
<BeatLoader color="var(--color-text-2)" size="10" />
<SvgSpinners180Ring color="var(--color-text-2)" />
</Container>
)
}

View File

@ -15,6 +15,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import FileManager from '@renderer/services/FileManager'
import {
filterContextMessages,
filterEmptyMessages,
@ -47,12 +48,13 @@ import { mcpToolCallResponseToOpenAIMessage, parseAndCallTools } from '@renderer
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { isEmpty, takeRight } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai'
import OpenAI, { AzureOpenAI, toFile } from 'openai'
import {
ChatCompletionContentPart,
ChatCompletionCreateParamsNonStreaming,
ChatCompletionMessageParam
} from 'openai/resources'
import { FileLike } from 'openai/uploads'
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
@ -1118,50 +1120,119 @@ export default class OpenAIProvider extends BaseProvider {
public async generateImageByChat({ messages, assistant, onChunk }: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
// save image data from the last assistant message
messages = addImageFileToContents(messages)
const lastUserMessage = messages.findLast((m) => m.role === 'user')
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
if (!lastUserMessage) {
return
}
const { abortController } = this.createAbortController(lastUserMessage?.id, true)
const { signal } = abortController
const content = getMainTextContent(lastUserMessage!)
let response: OpenAI.Images.ImagesResponse | null = null
let images: FileLike[] = []
onChunk({
type: ChunkType.IMAGE_CREATED
})
const start_time_millsec = new Date().getTime()
const response = await this.sdk.images.generate(
{
model: model.id,
prompt: getMainTextContent(lastUserMessage!) || '',
response_format: model.id.includes('gpt-image-1') ? undefined : 'b64_json'
},
{
signal
try {
if (lastUserMessage) {
const UserFiles = findImageBlocks(lastUserMessage)
const validUserFiles = UserFiles.filter((f) => f.file) // Filter out files that are undefined first
const userImages = await Promise.all(
validUserFiles.map(async (f) => {
// f.file is guaranteed to exist here due to the filter above
const fileInfo = f.file!
const binaryData = await FileManager.readFile(fileInfo)
console.log('binaryData', binaryData)
const file = await toFile(binaryData, fileInfo.origin_name || 'image.png', {
type: 'image/png'
})
return file
})
)
images = images.concat(userImages)
}
)
onChunk({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: response.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || []
if (lastAssistantMessage) {
const assistantFiles = findImageBlocks(lastAssistantMessage)
const assistantImages = await Promise.all(
assistantFiles.filter(Boolean).map(async (f) => {
const base64Data = f?.url?.replace(/^data:image\/\w+;base64,/, '')
if (!base64Data) return null
const binary = atob(base64Data)
const bytes = new Uint8Array(binary.length)
for (let i = 0; i < binary.length; i++) {
bytes[i] = binary.charCodeAt(i)
}
const file = await toFile(bytes, 'assistant_image.png', {
type: 'image/png'
})
return file
})
)
images = images.concat(assistantImages.filter(Boolean) as FileLike[])
}
})
onChunk({
type: ChunkType.IMAGE_CREATED
})
// Create synthetic usage and metrics data for image generation
const time_completion_millsec = new Date().getTime() - start_time_millsec
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage: {
completion_tokens: response.usage?.output_tokens || 0,
prompt_tokens: response.usage?.input_tokens || 0,
total_tokens: response.usage?.total_tokens || 0
},
metrics: {
completion_tokens: response.usage?.output_tokens || 0,
time_first_token_millsec: 0, // Non-streaming, first token time is not relevant
time_completion_millsec
const start_time_millsec = new Date().getTime()
if (images.length > 0) {
response = await this.sdk.images.edit(
{
model: model.id,
image: images,
prompt: content || ''
},
{
signal,
timeout: 300_000
}
)
} else {
response = await this.sdk.images.generate(
{
model: model.id,
prompt: content || '',
response_format: model.id.includes('gpt-image-1') ? undefined : 'b64_json'
},
{
signal,
timeout: 300_000
}
)
}
onChunk({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: response?.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || []
}
}
})
return
})
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage: {
completion_tokens: response.usage?.output_tokens || 0,
prompt_tokens: response.usage?.input_tokens || 0,
total_tokens: response.usage?.total_tokens || 0
},
metrics: {
completion_tokens: response.usage?.output_tokens || 0,
time_first_token_millsec: 0, // Non-streaming, first token time is not relevant
time_completion_millsec: new Date().getTime() - start_time_millsec
}
}
})
} catch (error: any) {
console.error('[generateImageByChat] error', error)
onChunk({
type: ChunkType.ERROR,
error
})
}
}
}

View File

@ -29,7 +29,8 @@ class FileManager {
}
static async readFile(file: FileType): Promise<Buffer> {
return (await window.api.file.binaryFile(file.id + file.ext)).data
const fileData = await window.api.file.binaryImage(file.id + file.ext)
return fileData.data
}
static async uploadFile(file: FileType): Promise<FileType> {

View File

@ -1557,6 +1557,22 @@ __metadata:
languageName: node
linkType: hard
"@iconify-json/svg-spinners@npm:^1.2.2":
version: 1.2.2
resolution: "@iconify-json/svg-spinners@npm:1.2.2"
dependencies:
"@iconify/types": "npm:*"
checksum: 10c0/61869963c21bc03052d64cd19155f9d596ffc71b3934ccdb468f5b5a1d3f003089c87744bf76145d8b1e946c45a88e1439e1f9470fcfc3a847b88262f1d71c76
languageName: node
linkType: hard
"@iconify/types@npm:*":
version: 2.0.0
resolution: "@iconify/types@npm:2.0.0"
checksum: 10c0/65a3be43500c7ccacf360e136d00e1717f050b7b91da644e94370256ac66f582d59212bdb30d00788aab4fc078262e91c95b805d1808d654b72f6d2072a7e4b2
languageName: node
linkType: hard
"@isaacs/cliui@npm:^8.0.2":
version: 8.0.2
resolution: "@isaacs/cliui@npm:8.0.2"
@ -4324,6 +4340,7 @@ __metadata:
"@eslint/js": "npm:^9.22.0"
"@google/genai": "npm:^0.10.0"
"@hello-pangea/dnd": "npm:^16.6.0"
"@iconify-json/svg-spinners": "npm:^1.2.2"
"@kangfenmao/keyv-storage": "npm:^0.1.0"
"@langchain/community": "npm:^0.3.36"
"@modelcontextprotocol/sdk": "npm:^1.10.2"