mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-21 16:01:35 +08:00
feat(OpenAIProvider): support image edit (#5469)
* feat(OpenAIProvider): support image edit
* fix: can edit image
* feat: add upload situation
* chore: optimize abort
* fix: image cannot read
* chore(SvgSpinners180Ring): remove unused React import
* refactor(FileManager): simplify file reading logic by storing intermediate data

---------

Co-authored-by: eeee0717 <chentao020717Work@outlook.com>
This commit is contained in:
parent
23dd66d4a1
commit
b43702f21d
@ -118,6 +118,7 @@
|
||||
"@eslint/js": "^9.22.0",
|
||||
"@google/genai": "^0.10.0",
|
||||
"@hello-pangea/dnd": "^16.6.0",
|
||||
"@iconify-json/svg-spinners": "^1.2.2",
|
||||
"@kangfenmao/keyv-storage": "^0.1.0",
|
||||
"@modelcontextprotocol/sdk": "^1.10.2",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
|
||||
@ -109,7 +109,7 @@ export enum IpcChannel {
|
||||
File_Base64Image = 'file:base64Image',
|
||||
File_Download = 'file:download',
|
||||
File_Copy = 'file:copy',
|
||||
File_BinaryFile = 'file:binaryFile',
|
||||
File_BinaryImage = 'file:binaryImage',
|
||||
|
||||
Fs_Read = 'fs:read',
|
||||
|
||||
|
||||
@ -215,7 +215,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.File_Base64Image, fileManager.base64Image)
|
||||
ipcMain.handle(IpcChannel.File_Download, fileManager.downloadFile)
|
||||
ipcMain.handle(IpcChannel.File_Copy, fileManager.copyFile)
|
||||
ipcMain.handle(IpcChannel.File_BinaryFile, fileManager.binaryFile)
|
||||
ipcMain.handle(IpcChannel.File_BinaryImage, fileManager.binaryImage)
|
||||
|
||||
// fs
|
||||
ipcMain.handle(IpcChannel.Fs_Read, FileService.readFile)
|
||||
|
||||
@ -263,7 +263,7 @@ class FileStorage {
|
||||
}
|
||||
}
|
||||
|
||||
public binaryFile = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => {
|
||||
public binaryImage = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => {
|
||||
const filePath = path.join(this.storageDir, id)
|
||||
const data = await fs.promises.readFile(filePath)
|
||||
const mime = `image/${path.extname(filePath).slice(1)}`
|
||||
|
||||
@ -66,7 +66,7 @@ const api = {
|
||||
base64Image: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId),
|
||||
download: (url: string) => ipcRenderer.invoke(IpcChannel.File_Download, url),
|
||||
copy: (fileId: string, destPath: string) => ipcRenderer.invoke(IpcChannel.File_Copy, fileId, destPath),
|
||||
binaryFile: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryFile, fileId)
|
||||
binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId)
|
||||
},
|
||||
fs: {
|
||||
read: (path: string) => ipcRenderer.invoke(IpcChannel.Fs_Read, path)
|
||||
|
||||
20
src/renderer/src/components/Icons/SvgSpinners180Ring.tsx
Normal file
20
src/renderer/src/components/Icons/SvgSpinners180Ring.tsx
Normal file
@ -0,0 +1,20 @@
|
||||
import { SVGProps } from 'react'
|
||||
|
||||
/**
 * Inline SVG loading spinner: a half-ring arc that rotates continuously.
 *
 * The icon renders at 1em x 1em by default and is filled with `currentColor`.
 * Caller-supplied props are spread after the defaults, so they can override
 * the default size, viewBox, or any other `<svg>` attribute.
 *
 * @param props Standard SVG element props forwarded to the root `<svg>`.
 * @returns The spinner element.
 */
export function SvgSpinners180Ring(props: SVGProps<SVGSVGElement>) {
  // Icon from SVG Spinners by Utkarsh Verma - https://github.com/n3r4zzurr0/svg-spinners/blob/main/LICENSE
  const arcOutline =
    'M12,4a8,8,0,0,1,7.89,6.7A1.53,1.53,0,0,0,21.38,12h0a1.5,1.5,0,0,0,1.48-1.75,11,11,0,0,0-21.72,0A1.5,1.5,0,0,0,2.62,12h0a1.53,1.53,0,0,0,1.49-1.3A8,8,0,0,1,12,4Z'

  // One full turn around the centre of the 24x24 viewBox every 0.75s, repeated indefinitely.
  const rotation = (
    <animateTransform
      attributeName="transform"
      dur="0.75s"
      repeatCount="indefinite"
      type="rotate"
      values="0 12 12;360 12 12"
    />
  )

  return (
    <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" {...props}>
      <path fill="currentColor" d={arcOutline}>
        {rotation}
      </path>
    </svg>
  )
}
|
||||
export default SvgSpinners180Ring
|
||||
@ -24,9 +24,9 @@ import { Avatar, Drawer, Tooltip } from 'antd'
|
||||
import { WebviewTag } from 'electron'
|
||||
import { useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import BeatLoader from 'react-spinners/BeatLoader'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import SvgSpinners180Ring from '../Icons/SvgSpinners180Ring'
|
||||
import WebviewContainer from './WebviewContainer'
|
||||
|
||||
interface AppExtraInfo {
|
||||
@ -375,7 +375,7 @@ const MinappPopupContainer: React.FC = () => {
|
||||
size={80}
|
||||
style={{ border: '1px solid var(--color-border)', marginTop: -150 }}
|
||||
/>
|
||||
<BeatLoader color="var(--color-text-2)" size="10px" style={{ marginTop: 15 }} />
|
||||
<SvgSpinners180Ring color="var(--color-text-2)" style={{ marginTop: 15 }} />
|
||||
</EmptyView>
|
||||
)}
|
||||
{WebviewContainerGroup}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { isVisionModel } from '@renderer/config/models'
|
||||
import { isGenerateImageModel, isVisionModel } from '@renderer/config/models'
|
||||
import { FileType, Model } from '@renderer/types'
|
||||
import { documentExts, imageExts, textExts } from '@shared/config/constant'
|
||||
import { Tooltip } from 'antd'
|
||||
@ -22,10 +22,19 @@ interface Props {
|
||||
const AttachmentButton: FC<Props> = ({ ref, model, files, setFiles, ToolbarButton, disabled }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const extensions = useMemo(
|
||||
() => (isVisionModel(model) ? [...imageExts, ...documentExts, ...textExts] : [...documentExts, ...textExts]),
|
||||
[model]
|
||||
)
|
||||
// const extensions = useMemo(
|
||||
// () => (isVisionModel(model) ? [...imageExts, ...documentExts, ...textExts] : [...documentExts, ...textExts]),
|
||||
// [model]
|
||||
// )
|
||||
const extensions = useMemo(() => {
|
||||
if (isVisionModel(model)) {
|
||||
return [...imageExts, ...documentExts, ...textExts]
|
||||
} else if (isGenerateImageModel(model)) {
|
||||
return [...imageExts]
|
||||
} else {
|
||||
return [...documentExts, ...textExts]
|
||||
}
|
||||
}, [model])
|
||||
|
||||
const onSelectFile = useCallback(async () => {
|
||||
const _files = await window.api.file.select({
|
||||
@ -54,7 +63,9 @@ const AttachmentButton: FC<Props> = ({ ref, model, files, setFiles, ToolbarButto
|
||||
return (
|
||||
<Tooltip
|
||||
placement="top"
|
||||
title={isVisionModel(model) ? t('chat.input.upload') : t('chat.input.upload.document')}
|
||||
title={
|
||||
isVisionModel(model) || isGenerateImageModel(model) ? t('chat.input.upload') : t('chat.input.upload.document')
|
||||
}
|
||||
arrow>
|
||||
<ToolbarButton type="text" onClick={onSelectFile} disabled={disabled}>
|
||||
<Paperclip size={18} style={{ color: files.length ? 'var(--color-primary)' : 'var(--color-icon)' }} />
|
||||
|
||||
@ -576,7 +576,8 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
||||
event.preventDefault()
|
||||
|
||||
if (file.path === '') {
|
||||
if (file.type.startsWith('image/') && isVisionModel(model)) {
|
||||
// 图像生成也支持图像编辑
|
||||
if (file.type.startsWith('image/') && (isVisionModel(model) || isGenerateImageModel(model))) {
|
||||
const tempFilePath = await window.api.file.create(file.name)
|
||||
const arrayBuffer = await file.arrayBuffer()
|
||||
const uint8Array = new Uint8Array(arrayBuffer)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
|
||||
import type { ImageMessageBlock } from '@renderer/types/newMessage'
|
||||
import React from 'react'
|
||||
|
||||
@ -8,7 +9,7 @@ interface Props {
|
||||
}
|
||||
|
||||
const ImageBlock: React.FC<Props> = ({ block }) => {
|
||||
return <MessageImage block={block} />
|
||||
return block.status === 'success' ? <MessageImage block={block} /> : <SvgSpinners180Ring />
|
||||
}
|
||||
|
||||
export default React.memo(ImageBlock)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
|
||||
import { MessageBlockStatus, MessageBlockType, type PlaceholderMessageBlock } from '@renderer/types/newMessage'
|
||||
import React from 'react'
|
||||
import { BeatLoader } from 'react-spinners'
|
||||
import styled from 'styled-components'
|
||||
|
||||
interface PlaceholderBlockProps {
|
||||
@ -10,7 +10,7 @@ const PlaceholderBlock: React.FC<PlaceholderBlockProps> = ({ block }) => {
|
||||
if (block.status === MessageBlockStatus.PROCESSING && block.type === MessageBlockType.UNKNOWN) {
|
||||
return (
|
||||
<MessageContentLoading>
|
||||
<BeatLoader size={8} />
|
||||
<SvgSpinners180Ring />
|
||||
</MessageContentLoading>
|
||||
)
|
||||
}
|
||||
|
||||
@ -90,7 +90,9 @@ const MessageImage: FC<Props> = ({ block }) => {
|
||||
const images = block.metadata?.generateImageResponse?.images?.length
|
||||
? block.metadata?.generateImageResponse?.images
|
||||
: // TODO 加file是否合适?
|
||||
[`file://${block?.file?.path}`]
|
||||
block?.file?.path
|
||||
? [`file://${block?.file?.path}`]
|
||||
: []
|
||||
return (
|
||||
<Container style={{ marginBottom: 8 }}>
|
||||
{images.map((image, index) => (
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import { TranslationOutlined } from '@ant-design/icons'
|
||||
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
|
||||
import type { TranslationMessageBlock } from '@renderer/types/newMessage'
|
||||
import { Divider } from 'antd'
|
||||
import { FC, Fragment } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import BeatLoader from 'react-spinners/BeatLoader'
|
||||
|
||||
import Markdown from '../Markdown/Markdown'
|
||||
|
||||
@ -24,7 +24,7 @@ const MessageTranslate: FC<Props> = ({ block }) => {
|
||||
<TranslationOutlined />
|
||||
</Divider>
|
||||
{block.content === t('translate.processing') ? (
|
||||
<BeatLoader color="var(--color-text-2)" size="10" style={{ marginBottom: 15 }} />
|
||||
<SvgSpinners180Ring color="var(--color-text-2)" style={{ marginBottom: 15 }} />
|
||||
) : (
|
||||
<Markdown block={block} />
|
||||
)}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
|
||||
import Scrollbar from '@renderer/components/Scrollbar'
|
||||
import { LOAD_MORE_COUNT } from '@renderer/config/constant'
|
||||
import { useAssistant } from '@renderer/hooks/useAssistant'
|
||||
@ -24,7 +25,6 @@ import { last } from 'lodash'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import InfiniteScroll from 'react-infinite-scroll-component'
|
||||
import BeatLoader from 'react-spinners/BeatLoader'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import ChatNavigation from './ChatNavigation'
|
||||
@ -238,7 +238,7 @@ const Messages: React.FC<MessagesProps> = ({ assistant, topic, setActiveTopic })
|
||||
style={{ overflow: 'visible' }}>
|
||||
<ScrollContainer>
|
||||
<LoaderContainer $loading={isLoadingMore}>
|
||||
<BeatLoader size={8} color="var(--color-text-2)" />
|
||||
<SvgSpinners180Ring color="var(--color-text-2)" />
|
||||
</LoaderContainer>
|
||||
{groupedMessages.map(([key, groupMessages]) => (
|
||||
<MessageGroup
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
|
||||
import { fetchSuggestions } from '@renderer/services/ApiService'
|
||||
import { getUserMessage } from '@renderer/services/MessagesService'
|
||||
import { useAppDispatch } from '@renderer/store'
|
||||
@ -6,7 +7,6 @@ import { Assistant, Suggestion } from '@renderer/types'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import { last } from 'lodash'
|
||||
import { FC, memo, useEffect, useState } from 'react'
|
||||
import BeatLoader from 'react-spinners/BeatLoader'
|
||||
import styled from 'styled-components'
|
||||
|
||||
interface Props {
|
||||
@ -66,7 +66,7 @@ const Suggestions: FC<Props> = ({ assistant, messages }) => {
|
||||
if (loadingSuggestions) {
|
||||
return (
|
||||
<Container>
|
||||
<BeatLoader color="var(--color-text-2)" size="10" />
|
||||
<SvgSpinners180Ring color="var(--color-text-2)" />
|
||||
</Container>
|
||||
)
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import {
|
||||
filterContextMessages,
|
||||
filterEmptyMessages,
|
||||
@ -47,12 +48,13 @@ import { mcpToolCallResponseToOpenAIMessage, parseAndCallTools } from '@renderer
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
import OpenAI, { AzureOpenAI, toFile } from 'openai'
|
||||
import {
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionCreateParamsNonStreaming,
|
||||
ChatCompletionMessageParam
|
||||
} from 'openai/resources'
|
||||
import { FileLike } from 'openai/uploads'
|
||||
|
||||
import { CompletionsParams } from '.'
|
||||
import BaseProvider from './BaseProvider'
|
||||
@ -1118,35 +1120,98 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
public async generateImageByChat({ messages, assistant, onChunk }: CompletionsParams): Promise<void> {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
// save image data from the last assistant message
|
||||
messages = addImageFileToContents(messages)
|
||||
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
|
||||
if (!lastUserMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
const { abortController } = this.createAbortController(lastUserMessage?.id, true)
|
||||
const { signal } = abortController
|
||||
const content = getMainTextContent(lastUserMessage!)
|
||||
let response: OpenAI.Images.ImagesResponse | null = null
|
||||
let images: FileLike[] = []
|
||||
|
||||
try {
|
||||
if (lastUserMessage) {
|
||||
const UserFiles = findImageBlocks(lastUserMessage)
|
||||
const validUserFiles = UserFiles.filter((f) => f.file) // Filter out files that are undefined first
|
||||
const userImages = await Promise.all(
|
||||
validUserFiles.map(async (f) => {
|
||||
// f.file is guaranteed to exist here due to the filter above
|
||||
const fileInfo = f.file!
|
||||
const binaryData = await FileManager.readFile(fileInfo)
|
||||
console.log('binaryData', binaryData)
|
||||
const file = await toFile(binaryData, fileInfo.origin_name || 'image.png', {
|
||||
type: 'image/png'
|
||||
})
|
||||
return file
|
||||
})
|
||||
)
|
||||
images = images.concat(userImages)
|
||||
}
|
||||
|
||||
if (lastAssistantMessage) {
|
||||
const assistantFiles = findImageBlocks(lastAssistantMessage)
|
||||
const assistantImages = await Promise.all(
|
||||
assistantFiles.filter(Boolean).map(async (f) => {
|
||||
const base64Data = f?.url?.replace(/^data:image\/\w+;base64,/, '')
|
||||
if (!base64Data) return null
|
||||
const binary = atob(base64Data)
|
||||
const bytes = new Uint8Array(binary.length)
|
||||
for (let i = 0; i < binary.length; i++) {
|
||||
bytes[i] = binary.charCodeAt(i)
|
||||
}
|
||||
const file = await toFile(bytes, 'assistant_image.png', {
|
||||
type: 'image/png'
|
||||
})
|
||||
return file
|
||||
})
|
||||
)
|
||||
images = images.concat(assistantImages.filter(Boolean) as FileLike[])
|
||||
}
|
||||
onChunk({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
|
||||
const start_time_millsec = new Date().getTime()
|
||||
const response = await this.sdk.images.generate(
|
||||
|
||||
if (images.length > 0) {
|
||||
response = await this.sdk.images.edit(
|
||||
{
|
||||
model: model.id,
|
||||
prompt: getMainTextContent(lastUserMessage!) || '',
|
||||
image: images,
|
||||
prompt: content || ''
|
||||
},
|
||||
{
|
||||
signal,
|
||||
timeout: 300_000
|
||||
}
|
||||
)
|
||||
} else {
|
||||
response = await this.sdk.images.generate(
|
||||
{
|
||||
model: model.id,
|
||||
prompt: content || '',
|
||||
response_format: model.id.includes('gpt-image-1') ? undefined : 'b64_json'
|
||||
},
|
||||
{
|
||||
signal
|
||||
signal,
|
||||
timeout: 300_000
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: response.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || []
|
||||
images: response?.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || []
|
||||
}
|
||||
})
|
||||
|
||||
// Create synthetic usage and metrics data for image generation
|
||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||
onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
@ -1158,10 +1223,16 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
metrics: {
|
||||
completion_tokens: response.usage?.output_tokens || 0,
|
||||
time_first_token_millsec: 0, // Non-streaming, first token time is not relevant
|
||||
time_completion_millsec
|
||||
time_completion_millsec: new Date().getTime() - start_time_millsec
|
||||
}
|
||||
}
|
||||
})
|
||||
return
|
||||
} catch (error: any) {
|
||||
console.error('[generateImageByChat] error', error)
|
||||
onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,7 +29,8 @@ class FileManager {
|
||||
}
|
||||
|
||||
static async readFile(file: FileType): Promise<Buffer> {
|
||||
return (await window.api.file.binaryFile(file.id + file.ext)).data
|
||||
const fileData = await window.api.file.binaryImage(file.id + file.ext)
|
||||
return fileData.data
|
||||
}
|
||||
|
||||
static async uploadFile(file: FileType): Promise<FileType> {
|
||||
|
||||
17
yarn.lock
17
yarn.lock
@ -1557,6 +1557,22 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@iconify-json/svg-spinners@npm:^1.2.2":
|
||||
version: 1.2.2
|
||||
resolution: "@iconify-json/svg-spinners@npm:1.2.2"
|
||||
dependencies:
|
||||
"@iconify/types": "npm:*"
|
||||
checksum: 10c0/61869963c21bc03052d64cd19155f9d596ffc71b3934ccdb468f5b5a1d3f003089c87744bf76145d8b1e946c45a88e1439e1f9470fcfc3a847b88262f1d71c76
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@iconify/types@npm:*":
|
||||
version: 2.0.0
|
||||
resolution: "@iconify/types@npm:2.0.0"
|
||||
checksum: 10c0/65a3be43500c7ccacf360e136d00e1717f050b7b91da644e94370256ac66f582d59212bdb30d00788aab4fc078262e91c95b805d1808d654b72f6d2072a7e4b2
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@isaacs/cliui@npm:^8.0.2":
|
||||
version: 8.0.2
|
||||
resolution: "@isaacs/cliui@npm:8.0.2"
|
||||
@ -4324,6 +4340,7 @@ __metadata:
|
||||
"@eslint/js": "npm:^9.22.0"
|
||||
"@google/genai": "npm:^0.10.0"
|
||||
"@hello-pangea/dnd": "npm:^16.6.0"
|
||||
"@iconify-json/svg-spinners": "npm:^1.2.2"
|
||||
"@kangfenmao/keyv-storage": "npm:^0.1.0"
|
||||
"@langchain/community": "npm:^0.3.36"
|
||||
"@modelcontextprotocol/sdk": "npm:^1.10.2"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user