diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 8558a96cd5..7e00749a66 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -41,6 +41,12 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle('file:clear', async () => await fileManager.clear()) ipcMain.handle('file:read', async (_, id: string) => await fileManager.readFile(id)) ipcMain.handle('file:delete', async (_, id: string) => await fileManager.deleteFile(id)) + ipcMain.handle('file:get', async (_, filePath: string) => await fileManager.getFile(filePath)) + ipcMain.handle('file:create', async (_, fileName: string) => await fileManager.createTempFile(fileName)) + ipcMain.handle( + 'file:write', + async (_, filePath: string, data: Uint8Array | string) => await fileManager.writeFile(filePath, data) + ) ipcMain.handle('minapp', (_, args) => { createMinappWindow({ diff --git a/src/main/services/File.ts b/src/main/services/File.ts index 8d5113537f..aa660e0a3c 100644 --- a/src/main/services/File.ts +++ b/src/main/services/File.ts @@ -131,6 +131,30 @@ class File { return fileMetadata } + async getFile(filePath: string): Promise { + if (!fs.existsSync(filePath)) { + return null + } + + const stats = fs.statSync(filePath) + const ext = path.extname(filePath) + const fileType = getFileType(ext) + + const fileInfo: FileType = { + id: uuidv4(), + origin_name: path.basename(filePath), + name: path.basename(filePath), + path: filePath, + created_at: stats.birthtime, + size: stats.size, + ext: ext, + type: fileType, + count: 1 + } + + return fileInfo + } + async deleteFile(id: string): Promise { await fs.promises.unlink(path.join(this.storageDir, id)) } @@ -140,6 +164,19 @@ class File { return fs.readFileSync(filePath, 'utf8') } + async createTempFile(fileName: string): Promise { + const tempDir = path.join(app.getPath('temp'), 'CherryStudio') + if (!fs.existsSync(tempDir)) { + fs.mkdirSync(tempDir, { recursive: true }) + } + const tempFilePath = path.join(tempDir, `${uuidv4()}_${fileName}`) + return tempFilePath + } + + async writeFile(filePath: string, data: Uint8Array | string): Promise { + await fs.promises.writeFile(filePath, data) + } + async base64Image(id: string): Promise<{ mime: string; base64: string; data: string }> { const filePath = path.join(this.storageDir, id) const data = await fs.promises.readFile(filePath) diff --git a/src/preload/index.d.ts b/src/preload/index.d.ts index ac379e1d8d..ecbf28024e 100644 --- a/src/preload/index.d.ts +++ b/src/preload/index.d.ts @@ -28,6 +28,9 @@ declare global { read: (fileId: string) => Promise base64Image: (fileId: string) => Promise<{ mime: string; base64: string; data: string }> clear: () => Promise + get: (filePath: string) => Promise + create: (fileName: string) => Promise + write: (filePath: string, data: Uint8Array | string) => Promise } } } diff --git a/src/preload/index.ts b/src/preload/index.ts index d1e85c7c8d..b894397fe8 100644 --- a/src/preload/index.ts +++ b/src/preload/index.ts @@ -22,7 +22,10 @@ const api = { delete: (fileId: string) => ipcRenderer.invoke('file:delete', fileId), read: (fileId: string) => ipcRenderer.invoke('file:read', fileId), base64Image: (fileId: string) => ipcRenderer.invoke('file:base64Image', fileId), - clear: () => ipcRenderer.invoke('file:clear') + clear: () => ipcRenderer.invoke('file:clear'), + get: (filePath: string) => ipcRenderer.invoke('file:get', filePath), + create: (fileName: string) => ipcRenderer.invoke('file:create', fileName), + write: (filePath: string, data: Uint8Array | string) => ipcRenderer.invoke('file:write', filePath, data) } } diff --git a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx index 474f7af659..5fb0ebf230 100644 --- a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx @@ -18,6 +18,9 @@ const AttachmentButton: FC = ({ model, files, setFiles, ToolbarButton }) const extensions = isVisionModel(model) ? [...imageExts, ...textExts] : [...textExts] const onSelectFile = async () => { + if (files.length > 0) { + return setFiles([]) + } const _files = await window.api.file.select({ filters: [{ name: 'Files', extensions }] }) _files && setFiles(_files) } diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index 5f8d5bbc2d..75b0f6b6af 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -8,6 +8,7 @@ import { PauseCircleOutlined, QuestionCircleOutlined } from '@ant-design/icons' +import { textExts } from '@renderer/config/constant' import db from '@renderer/databases' import { useAssistant } from '@renderer/hooks/useAssistant' import { useSettings } from '@renderer/hooks/useSettings' @@ -19,7 +20,7 @@ import { estimateTextTokens } from '@renderer/services/tokens' import store, { useAppDispatch, useAppSelector } from '@renderer/store' import { setGenerating, setSearching } from '@renderer/store/runtime' import { Assistant, FileType, Message, Topic } from '@renderer/types' -import { delay, uuid } from '@renderer/utils' +import { delay, getFileExtension, uuid } from '@renderer/utils' import { Button, Popconfirm, Tooltip } from 'antd' import TextArea, { TextAreaRef } from 'antd/es/input/TextArea' import dayjs from 'dayjs' @@ -171,6 +172,44 @@ const Inputbar: FC = ({ assistant, setActiveTopic }) => { const onInput = () => !expended && resizeTextArea() + const onPaste = useCallback(async (event: ClipboardEvent) => { + for (const file of event.clipboardData?.files || []) { + event.preventDefault() + const ext = getFileExtension(file.path) + if (textExts.includes(ext)) { + const selectedFile = await window.api.file.get(file.path) + selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile]) + } + } + + if (event.clipboardData?.items) { + const item = event.clipboardData.items[0] + const file = item.getAsFile() + if (file && file.type.startsWith('image/')) { + const tempFilePath = await window.api.file.create(file.name) + const arrayBuffer = await file.arrayBuffer() + const uint8Array = new Uint8Array(arrayBuffer) + await window.api.file.write(tempFilePath, uint8Array) + const selectedFile = await window.api.file.get(tempFilePath) + selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile]) + } + // if (item.kind === 'string' && item.type === 'text/plain') { + // // 处理文本内容 + // await new Promise((resolve) => { + // item.getAsString(async (text) => { + // const tempFilePath = await window.api.file.create('pasted_text.txt') + // await window.api.file.write(tempFilePath, text) + // const selectedFile = await window.api.file.get(tempFilePath) + // if (selectedFile) { + // newFiles.push(selectedFile) + // } + // resolve() + // }) + // }) + // } + } + }, []) + // Command or Ctrl + N create new topic useEffect(() => { const onKeydown = (e) => { @@ -206,6 +245,11 @@ const Inputbar: FC = ({ assistant, setActiveTopic }) => { textareaRef.current?.focus() }, [assistant]) + useEffect(() => { + document.addEventListener('paste', onPaste) + return () => document.removeEventListener('paste', onPaste) + }, [onPaste]) + return ( diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index d4e33e4c41..0723091812 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -18,49 +18,33 @@ export default class AnthropicProvider extends BaseProvider { this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() }) } - private async getMessageParam(message: Message): Promise { - const file = first(message.files) + private async getMessageParam(message: Message): Promise { + const parts: MessageParam['content'] = [{ type: 'text', text: message.content }] - if (file) { + for (const file of message.files || []) { if (file.type === FileTypes.IMAGE) { const base64Data = await window.api.file.base64Image(file.id + file.ext) - return [ - { - role: message.role, - content: [ - { type: 'text', text: message.content }, - { - type: 'image', - source: { - data: base64Data.base64, - media_type: base64Data.mime.replace('jpg', 'jpeg') as any, - type: 'base64' - } - } - ] - } as MessageParam - ] + parts.push({ + type: 'image', + source: { + data: base64Data.base64, + media_type: base64Data.mime.replace('jpg', 'jpeg') as any, + type: 'base64' + } + }) } if (file.type === FileTypes.TEXT) { - return [ - { - role: message.role, - content: message.content - } as MessageParam, - { - role: 'assistant', - content: (await window.api.file.read(file.id + file.ext)).trimEnd() - } as MessageParam - ] + parts.push({ + type: 'text', + text: (await window.api.file.read(file.id + file.ext)).trimEnd() + }) } } - return [ - { - role: message.role, - content: message.content - } as MessageParam - ] + return { + role: message.role, + content: parts + } } public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) { diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index c450820fb2..f1ccfeccca 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -1,10 +1,10 @@ -import { Content, GoogleGenerativeAI, InlineDataPart, TextPart } from '@google/generative-ai' +import { Content, GoogleGenerativeAI, InlineDataPart, Part, TextPart } from '@google/generative-ai' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' import { EVENT_NAMES } from '@renderer/services/event' import { filterContextMessages, filterMessages } from '@renderer/services/messages' import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types' import axios from 'axios' -import { first, flatten, isEmpty, takeRight } from 'lodash' +import { flatten, isEmpty, takeRight } from 'lodash' import OpenAI from 'openai' import BaseProvider from './BaseProvider' @@ -17,48 +17,37 @@ export default class GeminiProvider extends BaseProvider { this.sdk = new GoogleGenerativeAI(provider.apiKey) } - private async getMessageContents(message: Message): Promise { - const file = first(message.files) + private async getMessageContents(message: Message): Promise { const role = message.role === 'user' ? 'user' : 'model' - if (file) { + const parts: Part[] = [ + { + type: 'text', + text: message.content + } as TextPart + ] + + for (const file of message.files || []) { if (file.type === FileTypes.IMAGE) { const base64Data = await window.api.file.base64Image(file.id + file.ext) - return [ - { - role: message.role, - parts: [ - { text: message.content } as TextPart, - { - inlineData: { - data: base64Data.base64, - mimeType: base64Data.mime - } - } as InlineDataPart - ] + parts.push({ + inlineData: { + data: base64Data.base64, + mimeType: base64Data.mime } - ] + } as InlineDataPart) } if (file.type === FileTypes.TEXT) { - return [ - { - role: 'model', - parts: [{ text: await window.api.file.read(file.id + file.ext) } as TextPart] - }, - { - role, - parts: [{ text: message.content } as TextPart] - } - ] + parts.push({ + text: await window.api.file.read(file.id + file.ext) + } as TextPart) } } - return [ - { - role, - parts: [{ text: message.content } as TextPart] - } - ] + return { + role, + parts: parts + } } public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) { diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 58ea55e9a0..af2d39d4b1 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -33,49 +33,34 @@ export default class OpenAIProvider extends BaseProvider { return true } - private async getMessageParam(message: Message): Promise { - const file = first(message.files) + private async getMessageParam(message: Message): Promise { + const parts: ChatCompletionContentPart[] = [ + { + type: 'text', + text: message.content + } + ] - const content: string | ChatCompletionContentPart[] = message.content - - if (file) { + for (const file of message.files || []) { if (file.type === FileTypes.IMAGE) { const image = await window.api.file.base64Image(file.id + file.ext) - return [ - { - role: message.role, - content: [ - { type: 'text', text: message.content }, - { - type: 'image_url', - image_url: { - url: image.data - } - } - ] - } as ChatCompletionMessageParam - ] + parts.push({ + type: 'image_url', + image_url: { url: image.data } + }) } if (file.type === FileTypes.TEXT) { - return [ - { - role: 'assistant', - content: await window.api.file.read(file.id + file.ext) - } as ChatCompletionMessageParam, - { - role: message.role, - content - } as ChatCompletionMessageParam - ] + parts.push({ + type: 'text', + text: await window.api.file.read(file.id + file.ext) + }) } } - return [ - { - role: message.role, - content - } as ChatCompletionMessageParam - ] + return { + role: message.role, + content: parts + } as ChatCompletionMessageParam } async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise { @@ -84,13 +69,13 @@ export default class OpenAIProvider extends BaseProvider { const { contextCount, maxTokens } = getAssistantSettings(assistant) const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined - let userMessages: ChatCompletionMessageParam[] = [] + const userMessages: ChatCompletionMessageParam[] = [] const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))) onFilterMessages(_messages) for (const message of _messages) { - userMessages = userMessages.concat(await this.getMessageParam(message)) + userMessages.push(await this.getMessageParam(message)) } // @ts-ignore key is not typed diff --git a/src/renderer/src/utils/index.ts b/src/renderer/src/utils/index.ts index 9dab51029c..c97cc32401 100644 --- a/src/renderer/src/utils/index.ts +++ b/src/renderer/src/utils/index.ts @@ -235,3 +235,9 @@ export function getFileDirectory(filePath: string) { const directory = parts.slice(0, -1).join('/') return directory } + +export function getFileExtension(filePath: string) { + const parts = filePath.split('.') + const extension = parts.slice(-1)[0] + return '.' + extension +}