diff --git a/package.json b/package.json index 65f569ecdb..e12267d8f2 100644 --- a/package.json +++ b/package.json @@ -118,6 +118,7 @@ "@eslint/js": "^9.22.0", "@google/genai": "^0.10.0", "@hello-pangea/dnd": "^16.6.0", + "@iconify-json/svg-spinners": "^1.2.2", "@kangfenmao/keyv-storage": "^0.1.0", "@modelcontextprotocol/sdk": "^1.10.2", "@mozilla/readability": "^0.6.0", diff --git a/packages/shared/IpcChannel.ts b/packages/shared/IpcChannel.ts index 63c340e7a7..c1725968f4 100644 --- a/packages/shared/IpcChannel.ts +++ b/packages/shared/IpcChannel.ts @@ -109,7 +109,7 @@ export enum IpcChannel { File_Base64Image = 'file:base64Image', File_Download = 'file:download', File_Copy = 'file:copy', - File_BinaryFile = 'file:binaryFile', + File_BinaryImage = 'file:binaryImage', Fs_Read = 'fs:read', diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 0ae09d8b21..e4c9bf3e40 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -215,7 +215,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.File_Base64Image, fileManager.base64Image) ipcMain.handle(IpcChannel.File_Download, fileManager.downloadFile) ipcMain.handle(IpcChannel.File_Copy, fileManager.copyFile) - ipcMain.handle(IpcChannel.File_BinaryFile, fileManager.binaryFile) + ipcMain.handle(IpcChannel.File_BinaryImage, fileManager.binaryImage) // fs ipcMain.handle(IpcChannel.Fs_Read, FileService.readFile) diff --git a/src/main/services/FileStorage.ts b/src/main/services/FileStorage.ts index accad80959..c374c8f0db 100644 --- a/src/main/services/FileStorage.ts +++ b/src/main/services/FileStorage.ts @@ -263,7 +263,7 @@ class FileStorage { } } - public binaryFile = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => { + public binaryImage = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => { const filePath = path.join(this.storageDir, id) const data = await fs.promises.readFile(filePath) const mime = `image/${path.extname(filePath).slice(1)}` diff --git a/src/preload/index.ts b/src/preload/index.ts index 0e3f432f65..e28dc316e7 100644 --- a/src/preload/index.ts +++ b/src/preload/index.ts @@ -66,7 +66,7 @@ const api = { base64Image: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId), download: (url: string) => ipcRenderer.invoke(IpcChannel.File_Download, url), copy: (fileId: string, destPath: string) => ipcRenderer.invoke(IpcChannel.File_Copy, fileId, destPath), - binaryFile: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryFile, fileId) + binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId) }, fs: { read: (path: string) => ipcRenderer.invoke(IpcChannel.Fs_Read, path) diff --git a/src/renderer/src/components/Icons/SvgSpinners180Ring.tsx b/src/renderer/src/components/Icons/SvgSpinners180Ring.tsx new file mode 100644 index 0000000000..9e39d6a770 --- /dev/null +++ b/src/renderer/src/components/Icons/SvgSpinners180Ring.tsx @@ -0,0 +1,20 @@ +import { SVGProps } from 'react' + +export function SvgSpinners180Ring(props: SVGProps) { + return ( + + {/* Icon from SVG Spinners by Utkarsh Verma - https://github.com/n3r4zzurr0/svg-spinners/blob/main/LICENSE */} + + + + + ) +} +export default SvgSpinners180Ring diff --git a/src/renderer/src/components/MinApp/MinappPopupContainer.tsx b/src/renderer/src/components/MinApp/MinappPopupContainer.tsx index 9ee53bab13..2ace2a18ba 100644 --- a/src/renderer/src/components/MinApp/MinappPopupContainer.tsx +++ b/src/renderer/src/components/MinApp/MinappPopupContainer.tsx @@ -24,9 +24,9 @@ import { Avatar, Drawer, Tooltip } from 'antd' import { WebviewTag } from 'electron' import { useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import BeatLoader from 'react-spinners/BeatLoader' import styled from 'styled-components' +import SvgSpinners180Ring from '../Icons/SvgSpinners180Ring' import WebviewContainer from './WebviewContainer' interface AppExtraInfo { @@ -375,7 +375,7 @@ const MinappPopupContainer: React.FC = () => { size={80} style={{ border: '1px solid var(--color-border)', marginTop: -150 }} /> - + )} {WebviewContainerGroup} diff --git a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx index 051f5a5784..44b8772d75 100644 --- a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx @@ -1,4 +1,4 @@ -import { isVisionModel } from '@renderer/config/models' +import { isGenerateImageModel, isVisionModel } from '@renderer/config/models' import { FileType, Model } from '@renderer/types' import { documentExts, imageExts, textExts } from '@shared/config/constant' import { Tooltip } from 'antd' @@ -22,10 +22,19 @@ interface Props { const AttachmentButton: FC = ({ ref, model, files, setFiles, ToolbarButton, disabled }) => { const { t } = useTranslation() - const extensions = useMemo( - () => (isVisionModel(model) ? [...imageExts, ...documentExts, ...textExts] : [...documentExts, ...textExts]), - [model] - ) + // const extensions = useMemo( + // () => (isVisionModel(model) ? [...imageExts, ...documentExts, ...textExts] : [...documentExts, ...textExts]), + // [model] + // ) + const extensions = useMemo(() => { + if (isVisionModel(model)) { + return [...imageExts, ...documentExts, ...textExts] + } else if (isGenerateImageModel(model)) { + return [...imageExts] + } else { + return [...documentExts, ...textExts] + } + }, [model]) const onSelectFile = useCallback(async () => { const _files = await window.api.file.select({ @@ -54,7 +63,9 @@ const AttachmentButton: FC = ({ ref, model, files, setFiles, ToolbarButto return ( diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index 6cb8559af8..71128bf620 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -576,7 +576,8 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = event.preventDefault() if (file.path === '') { - if (file.type.startsWith('image/') && isVisionModel(model)) { + // 图像生成也支持图像编辑 + if (file.type.startsWith('image/') && (isVisionModel(model) || isGenerateImageModel(model))) { const tempFilePath = await window.api.file.create(file.name) const arrayBuffer = await file.arrayBuffer() const uint8Array = new Uint8Array(arrayBuffer) diff --git a/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx index 2b75f24fde..5ede5ec773 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/ImageBlock.tsx @@ -1,3 +1,4 @@ +import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring' import type { ImageMessageBlock } from '@renderer/types/newMessage' import React from 'react' @@ -8,7 +9,7 @@ interface Props { } const ImageBlock: React.FC = ({ block }) => { - return + return block.status === 'success' ? : } export default React.memo(ImageBlock) diff --git a/src/renderer/src/pages/home/Messages/Blocks/PlaceholderBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/PlaceholderBlock.tsx index 36d8f5ff94..261ed7f140 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/PlaceholderBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/PlaceholderBlock.tsx @@ -1,6 +1,6 @@ +import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring' import { MessageBlockStatus, MessageBlockType, type PlaceholderMessageBlock } from '@renderer/types/newMessage' import React from 'react' -import { BeatLoader } from 'react-spinners' import styled from 'styled-components' interface PlaceholderBlockProps { @@ -10,7 +10,7 @@ const PlaceholderBlock: React.FC = ({ block }) => { if (block.status === MessageBlockStatus.PROCESSING && block.type === MessageBlockType.UNKNOWN) { return ( - + ) } diff --git a/src/renderer/src/pages/home/Messages/MessageImage.tsx b/src/renderer/src/pages/home/Messages/MessageImage.tsx index 0d1d5c3e0b..cdeff1b443 100644 --- a/src/renderer/src/pages/home/Messages/MessageImage.tsx +++ b/src/renderer/src/pages/home/Messages/MessageImage.tsx @@ -90,7 +90,9 @@ const MessageImage: FC = ({ block }) => { const images = block.metadata?.generateImageResponse?.images?.length ? block.metadata?.generateImageResponse?.images : // TODO 加file是否合适? - [`file://${block?.file?.path}`] + block?.file?.path + ? [`file://${block?.file?.path}`] + : [] return ( {images.map((image, index) => ( diff --git a/src/renderer/src/pages/home/Messages/MessageTranslate.tsx b/src/renderer/src/pages/home/Messages/MessageTranslate.tsx index 305c1fbbd9..83bbd77c21 100644 --- a/src/renderer/src/pages/home/Messages/MessageTranslate.tsx +++ b/src/renderer/src/pages/home/Messages/MessageTranslate.tsx @@ -1,9 +1,9 @@ import { TranslationOutlined } from '@ant-design/icons' +import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring' import type { TranslationMessageBlock } from '@renderer/types/newMessage' import { Divider } from 'antd' import { FC, Fragment } from 'react' import { useTranslation } from 'react-i18next' -import BeatLoader from 'react-spinners/BeatLoader' import Markdown from '../Markdown/Markdown' @@ -24,7 +24,7 @@ const MessageTranslate: FC = ({ block }) => { {block.content === t('translate.processing') ? ( - + ) : ( )} diff --git a/src/renderer/src/pages/home/Messages/Messages.tsx b/src/renderer/src/pages/home/Messages/Messages.tsx index ed1f7e5239..566aa145c9 100644 --- a/src/renderer/src/pages/home/Messages/Messages.tsx +++ b/src/renderer/src/pages/home/Messages/Messages.tsx @@ -1,3 +1,4 @@ +import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring' import Scrollbar from '@renderer/components/Scrollbar' import { LOAD_MORE_COUNT } from '@renderer/config/constant' import { useAssistant } from '@renderer/hooks/useAssistant' @@ -24,7 +25,6 @@ import { last } from 'lodash' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import InfiniteScroll from 'react-infinite-scroll-component' -import BeatLoader from 'react-spinners/BeatLoader' import styled from 'styled-components' import ChatNavigation from './ChatNavigation' @@ -238,7 +238,7 @@ const Messages: React.FC = ({ assistant, topic, setActiveTopic }) style={{ overflow: 'visible' }}> - + {groupedMessages.map(([key, groupMessages]) => ( = ({ assistant, messages }) => { if (loadingSuggestions) { return ( - + ) } diff --git a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts index 4e1328e151..8dfd3a9c92 100644 --- a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts +++ b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts @@ -15,6 +15,7 @@ import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' +import FileManager from '@renderer/services/FileManager' import { filterContextMessages, filterEmptyMessages, @@ -47,12 +48,13 @@ import { mcpToolCallResponseToOpenAIMessage, parseAndCallTools } from '@renderer import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { buildSystemPrompt } from '@renderer/utils/prompt' import { isEmpty, takeRight } from 'lodash' -import OpenAI, { AzureOpenAI } from 'openai' +import OpenAI, { AzureOpenAI, toFile } from 'openai' import { ChatCompletionContentPart, ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources' +import { FileLike } from 'openai/uploads' import { CompletionsParams } from '.' import BaseProvider from './BaseProvider' @@ -1118,50 +1120,119 @@ export default class OpenAIProvider extends BaseProvider { public async generateImageByChat({ messages, assistant, onChunk }: CompletionsParams): Promise { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel + // save image data from the last assistant message + messages = addImageFileToContents(messages) const lastUserMessage = messages.findLast((m) => m.role === 'user') + const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant') + if (!lastUserMessage) { + return + } + const { abortController } = this.createAbortController(lastUserMessage?.id, true) const { signal } = abortController + const content = getMainTextContent(lastUserMessage!) + let response: OpenAI.Images.ImagesResponse | null = null + let images: FileLike[] = [] - onChunk({ - type: ChunkType.IMAGE_CREATED - }) - const start_time_millsec = new Date().getTime() - const response = await this.sdk.images.generate( - { - model: model.id, - prompt: getMainTextContent(lastUserMessage!) || '', - response_format: model.id.includes('gpt-image-1') ? undefined : 'b64_json' - }, - { - signal + try { + if (lastUserMessage) { + const UserFiles = findImageBlocks(lastUserMessage) + const validUserFiles = UserFiles.filter((f) => f.file) // Filter out files that are undefined first + const userImages = await Promise.all( + validUserFiles.map(async (f) => { + // f.file is guaranteed to exist here due to the filter above + const fileInfo = f.file! + const binaryData = await FileManager.readFile(fileInfo) + console.log('binaryData', binaryData) + const file = await toFile(binaryData, fileInfo.origin_name || 'image.png', { + type: 'image/png' + }) + return file + }) + ) + images = images.concat(userImages) } - ) - onChunk({ - type: ChunkType.IMAGE_COMPLETE, - image: { - type: 'base64', - images: response.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || [] + if (lastAssistantMessage) { + const assistantFiles = findImageBlocks(lastAssistantMessage) + const assistantImages = await Promise.all( + assistantFiles.filter(Boolean).map(async (f) => { + const base64Data = f?.url?.replace(/^data:image\/\w+;base64,/, '') + if (!base64Data) return null + const binary = atob(base64Data) + const bytes = new Uint8Array(binary.length) + for (let i = 0; i < binary.length; i++) { + bytes[i] = binary.charCodeAt(i) + } + const file = await toFile(bytes, 'assistant_image.png', { + type: 'image/png' + }) + return file + }) + ) + images = images.concat(assistantImages.filter(Boolean) as FileLike[]) } - }) + onChunk({ + type: ChunkType.IMAGE_CREATED + }) - // Create synthetic usage and metrics data for image generation - const time_completion_millsec = new Date().getTime() - start_time_millsec - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - usage: { - completion_tokens: response.usage?.output_tokens || 0, - prompt_tokens: response.usage?.input_tokens || 0, - total_tokens: response.usage?.total_tokens || 0 - }, - metrics: { - completion_tokens: response.usage?.output_tokens || 0, - time_first_token_millsec: 0, // Non-streaming, first token time is not relevant - time_completion_millsec + const start_time_millsec = new Date().getTime() + + if (images.length > 0) { + response = await this.sdk.images.edit( + { + model: model.id, + image: images, + prompt: content || '' + }, + { + signal, + timeout: 300_000 + } + ) + } else { + response = await this.sdk.images.generate( + { + model: model.id, + prompt: content || '', + response_format: model.id.includes('gpt-image-1') ? undefined : 'b64_json' + }, + { + signal, + timeout: 300_000 + } + ) + } + + onChunk({ + type: ChunkType.IMAGE_COMPLETE, + image: { + type: 'base64', + images: response?.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || [] } - } - }) - return + }) + + onChunk({ + type: ChunkType.BLOCK_COMPLETE, + response: { + usage: { + completion_tokens: response.usage?.output_tokens || 0, + prompt_tokens: response.usage?.input_tokens || 0, + total_tokens: response.usage?.total_tokens || 0 + }, + metrics: { + completion_tokens: response.usage?.output_tokens || 0, + time_first_token_millsec: 0, // Non-streaming, first token time is not relevant + time_completion_millsec: new Date().getTime() - start_time_millsec + } + } + }) + } catch (error: any) { + console.error('[generateImageByChat] error', error) + onChunk({ + type: ChunkType.ERROR, + error + }) + } } } diff --git a/src/renderer/src/services/FileManager.ts b/src/renderer/src/services/FileManager.ts index 59452c8746..a4bf7625bc 100644 --- a/src/renderer/src/services/FileManager.ts +++ b/src/renderer/src/services/FileManager.ts @@ -29,7 +29,8 @@ class FileManager { } static async readFile(file: FileType): Promise { - return (await window.api.file.binaryFile(file.id + file.ext)).data + const fileData = await window.api.file.binaryImage(file.id + file.ext) + return fileData.data } static async uploadFile(file: FileType): Promise { diff --git a/yarn.lock b/yarn.lock index ec7c80d1b7..0e84f2eeb7 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1557,6 +1557,22 @@ __metadata: languageName: node linkType: hard +"@iconify-json/svg-spinners@npm:^1.2.2": + version: 1.2.2 + resolution: "@iconify-json/svg-spinners@npm:1.2.2" + dependencies: + "@iconify/types": "npm:*" + checksum: 10c0/61869963c21bc03052d64cd19155f9d596ffc71b3934ccdb468f5b5a1d3f003089c87744bf76145d8b1e946c45a88e1439e1f9470fcfc3a847b88262f1d71c76 + languageName: node + linkType: hard + +"@iconify/types@npm:*": + version: 2.0.0 + resolution: "@iconify/types@npm:2.0.0" + checksum: 10c0/65a3be43500c7ccacf360e136d00e1717f050b7b91da644e94370256ac66f582d59212bdb30d00788aab4fc078262e91c95b805d1808d654b72f6d2072a7e4b2 + languageName: node + linkType: hard + "@isaacs/cliui@npm:^8.0.2": version: 8.0.2 resolution: "@isaacs/cliui@npm:8.0.2" @@ -4324,6 +4340,7 @@ __metadata: "@eslint/js": "npm:^9.22.0" "@google/genai": "npm:^0.10.0" "@hello-pangea/dnd": "npm:^16.6.0" + "@iconify-json/svg-spinners": "npm:^1.2.2" "@kangfenmao/keyv-storage": "npm:^0.1.0" "@langchain/community": "npm:^0.3.36" "@modelcontextprotocol/sdk": "npm:^1.10.2"