feat: copy and paste files or images

This commit is contained in:
kangfenmao 2024-09-18 21:00:15 +08:00
parent b9bb0c0f40
commit 540f0126d8
10 changed files with 168 additions and 108 deletions

View File

@ -41,6 +41,12 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle('file:clear', async () => await fileManager.clear()) ipcMain.handle('file:clear', async () => await fileManager.clear())
ipcMain.handle('file:read', async (_, id: string) => await fileManager.readFile(id)) ipcMain.handle('file:read', async (_, id: string) => await fileManager.readFile(id))
ipcMain.handle('file:delete', async (_, id: string) => await fileManager.deleteFile(id)) ipcMain.handle('file:delete', async (_, id: string) => await fileManager.deleteFile(id))
ipcMain.handle('file:get', async (_, filePath: string) => await fileManager.getFile(filePath))
ipcMain.handle('file:create', async (_, fileName: string) => await fileManager.createTempFile(fileName))
ipcMain.handle(
'file:write',
async (_, filePath: string, data: Uint8Array | string) => await fileManager.writeFile(filePath, data)
)
ipcMain.handle('minapp', (_, args) => { ipcMain.handle('minapp', (_, args) => {
createMinappWindow({ createMinappWindow({

View File

@ -131,6 +131,30 @@ class File {
return fileMetadata return fileMetadata
} }
async getFile(filePath: string): Promise<FileType | null> {
if (!fs.existsSync(filePath)) {
return null
}
const stats = fs.statSync(filePath)
const ext = path.extname(filePath)
const fileType = getFileType(ext)
const fileInfo: FileType = {
id: uuidv4(),
origin_name: path.basename(filePath),
name: path.basename(filePath),
path: filePath,
created_at: stats.birthtime,
size: stats.size,
ext: ext,
type: fileType,
count: 1
}
return fileInfo
}
async deleteFile(id: string): Promise<void> { async deleteFile(id: string): Promise<void> {
await fs.promises.unlink(path.join(this.storageDir, id)) await fs.promises.unlink(path.join(this.storageDir, id))
} }
@ -140,6 +164,19 @@ class File {
return fs.readFileSync(filePath, 'utf8') return fs.readFileSync(filePath, 'utf8')
} }
async createTempFile(fileName: string): Promise<string> {
const tempDir = path.join(app.getPath('temp'), 'CherryStudio')
if (!fs.existsSync(tempDir)) {
fs.mkdirSync(tempDir, { recursive: true })
}
const tempFilePath = path.join(tempDir, `${uuidv4()}_${fileName}`)
return tempFilePath
}
async writeFile(filePath: string, data: Uint8Array | string): Promise<void> {
await fs.promises.writeFile(filePath, data)
}
async base64Image(id: string): Promise<{ mime: string; base64: string; data: string }> { async base64Image(id: string): Promise<{ mime: string; base64: string; data: string }> {
const filePath = path.join(this.storageDir, id) const filePath = path.join(this.storageDir, id)
const data = await fs.promises.readFile(filePath) const data = await fs.promises.readFile(filePath)

View File

@ -28,6 +28,9 @@ declare global {
read: (fileId: string) => Promise<string> read: (fileId: string) => Promise<string>
base64Image: (fileId: string) => Promise<{ mime: string; base64: string; data: string }> base64Image: (fileId: string) => Promise<{ mime: string; base64: string; data: string }>
clear: () => Promise<void> clear: () => Promise<void>
get: (filePath: string) => Promise<FileType | null>
create: (fileName: string) => Promise<string>
write: (filePath: string, data: Uint8Array | string) => Promise<void>
} }
} }
} }

View File

@ -22,7 +22,10 @@ const api = {
delete: (fileId: string) => ipcRenderer.invoke('file:delete', fileId), delete: (fileId: string) => ipcRenderer.invoke('file:delete', fileId),
read: (fileId: string) => ipcRenderer.invoke('file:read', fileId), read: (fileId: string) => ipcRenderer.invoke('file:read', fileId),
base64Image: (fileId: string) => ipcRenderer.invoke('file:base64Image', fileId), base64Image: (fileId: string) => ipcRenderer.invoke('file:base64Image', fileId),
clear: () => ipcRenderer.invoke('file:clear') clear: () => ipcRenderer.invoke('file:clear'),
get: (filePath: string) => ipcRenderer.invoke('file:get', filePath),
create: (fileName: string) => ipcRenderer.invoke('file:create', fileName),
write: (filePath: string, data: Uint8Array | string) => ipcRenderer.invoke('file:write', filePath, data)
} }
} }

View File

@ -18,6 +18,9 @@ const AttachmentButton: FC<Props> = ({ model, files, setFiles, ToolbarButton })
const extensions = isVisionModel(model) ? [...imageExts, ...textExts] : [...textExts] const extensions = isVisionModel(model) ? [...imageExts, ...textExts] : [...textExts]
const onSelectFile = async () => { const onSelectFile = async () => {
if (files.length > 0) {
return setFiles([])
}
const _files = await window.api.file.select({ filters: [{ name: 'Files', extensions }] }) const _files = await window.api.file.select({ filters: [{ name: 'Files', extensions }] })
_files && setFiles(_files) _files && setFiles(_files)
} }

View File

@ -8,6 +8,7 @@ import {
PauseCircleOutlined, PauseCircleOutlined,
QuestionCircleOutlined QuestionCircleOutlined
} from '@ant-design/icons' } from '@ant-design/icons'
import { textExts } from '@renderer/config/constant'
import db from '@renderer/databases' import db from '@renderer/databases'
import { useAssistant } from '@renderer/hooks/useAssistant' import { useAssistant } from '@renderer/hooks/useAssistant'
import { useSettings } from '@renderer/hooks/useSettings' import { useSettings } from '@renderer/hooks/useSettings'
@ -19,7 +20,7 @@ import { estimateTextTokens } from '@renderer/services/tokens'
import store, { useAppDispatch, useAppSelector } from '@renderer/store' import store, { useAppDispatch, useAppSelector } from '@renderer/store'
import { setGenerating, setSearching } from '@renderer/store/runtime' import { setGenerating, setSearching } from '@renderer/store/runtime'
import { Assistant, FileType, Message, Topic } from '@renderer/types' import { Assistant, FileType, Message, Topic } from '@renderer/types'
import { delay, uuid } from '@renderer/utils' import { delay, getFileExtension, uuid } from '@renderer/utils'
import { Button, Popconfirm, Tooltip } from 'antd' import { Button, Popconfirm, Tooltip } from 'antd'
import TextArea, { TextAreaRef } from 'antd/es/input/TextArea' import TextArea, { TextAreaRef } from 'antd/es/input/TextArea'
import dayjs from 'dayjs' import dayjs from 'dayjs'
@ -171,6 +172,44 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic }) => {
const onInput = () => !expended && resizeTextArea() const onInput = () => !expended && resizeTextArea()
const onPaste = useCallback(async (event: ClipboardEvent) => {
for (const file of event.clipboardData?.files || []) {
event.preventDefault()
const ext = getFileExtension(file.path)
if (textExts.includes(ext)) {
const selectedFile = await window.api.file.get(file.path)
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
}
}
if (event.clipboardData?.items) {
const item = event.clipboardData.items[0]
const file = item.getAsFile()
if (file && file.type.startsWith('image/')) {
const tempFilePath = await window.api.file.create(file.name)
const arrayBuffer = await file.arrayBuffer()
const uint8Array = new Uint8Array(arrayBuffer)
await window.api.file.write(tempFilePath, uint8Array)
const selectedFile = await window.api.file.get(tempFilePath)
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
}
// if (item.kind === 'string' && item.type === 'text/plain') {
// // 处理文本内容
// await new Promise<void>((resolve) => {
// item.getAsString(async (text) => {
// const tempFilePath = await window.api.file.create('pasted_text.txt')
// await window.api.file.write(tempFilePath, text)
// const selectedFile = await window.api.file.get(tempFilePath)
// if (selectedFile) {
// newFiles.push(selectedFile)
// }
// resolve()
// })
// })
// }
}
}, [])
// Command or Ctrl + N create new topic // Command or Ctrl + N create new topic
useEffect(() => { useEffect(() => {
const onKeydown = (e) => { const onKeydown = (e) => {
@ -206,6 +245,11 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic }) => {
textareaRef.current?.focus() textareaRef.current?.focus()
}, [assistant]) }, [assistant])
useEffect(() => {
document.addEventListener('paste', onPaste)
return () => document.removeEventListener('paste', onPaste)
}, [onPaste])
return ( return (
<Container> <Container>
<AttachmentPreview files={files} setFiles={setFiles} /> <AttachmentPreview files={files} setFiles={setFiles} />

View File

@ -18,49 +18,33 @@ export default class AnthropicProvider extends BaseProvider {
this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() }) this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() })
} }
private async getMessageParam(message: Message): Promise<MessageParam[]> { private async getMessageParam(message: Message): Promise<MessageParam> {
const file = first(message.files) const parts: MessageParam['content'] = [{ type: 'text', text: message.content }]
if (file) { for (const file of message.files || []) {
if (file.type === FileTypes.IMAGE) { if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext) const base64Data = await window.api.file.base64Image(file.id + file.ext)
return [ parts.push({
{ type: 'image',
role: message.role, source: {
content: [ data: base64Data.base64,
{ type: 'text', text: message.content }, media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
{ type: 'base64'
type: 'image', }
source: { })
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
}
]
} as MessageParam
]
} }
if (file.type === FileTypes.TEXT) { if (file.type === FileTypes.TEXT) {
return [ parts.push({
{ type: 'text',
role: message.role, text: (await window.api.file.read(file.id + file.ext)).trimEnd()
content: message.content })
} as MessageParam,
{
role: 'assistant',
content: (await window.api.file.read(file.id + file.ext)).trimEnd()
} as MessageParam
]
} }
} }
return [ return {
{ role: message.role,
role: message.role, content: parts
content: message.content }
} as MessageParam
]
} }
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) { public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {

View File

@ -1,10 +1,10 @@
import { Content, GoogleGenerativeAI, InlineDataPart, TextPart } from '@google/generative-ai' import { Content, GoogleGenerativeAI, InlineDataPart, Part, TextPart } from '@google/generative-ai'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
import { EVENT_NAMES } from '@renderer/services/event' import { EVENT_NAMES } from '@renderer/services/event'
import { filterContextMessages, filterMessages } from '@renderer/services/messages' import { filterContextMessages, filterMessages } from '@renderer/services/messages'
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types' import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
import axios from 'axios' import axios from 'axios'
import { first, flatten, isEmpty, takeRight } from 'lodash' import { flatten, isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai' import OpenAI from 'openai'
import BaseProvider from './BaseProvider' import BaseProvider from './BaseProvider'
@ -17,48 +17,37 @@ export default class GeminiProvider extends BaseProvider {
this.sdk = new GoogleGenerativeAI(provider.apiKey) this.sdk = new GoogleGenerativeAI(provider.apiKey)
} }
private async getMessageContents(message: Message): Promise<Content[]> { private async getMessageContents(message: Message): Promise<Content> {
const file = first(message.files)
const role = message.role === 'user' ? 'user' : 'model' const role = message.role === 'user' ? 'user' : 'model'
if (file) { const parts: Part[] = [
{
type: 'text',
text: message.content
} as TextPart
]
for (const file of message.files || []) {
if (file.type === FileTypes.IMAGE) { if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext) const base64Data = await window.api.file.base64Image(file.id + file.ext)
return [ parts.push({
{ inlineData: {
role: message.role, data: base64Data.base64,
parts: [ mimeType: base64Data.mime
{ text: message.content } as TextPart,
{
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
}
} as InlineDataPart
]
} }
] } as InlineDataPart)
} }
if (file.type === FileTypes.TEXT) { if (file.type === FileTypes.TEXT) {
return [ parts.push({
{ text: await window.api.file.read(file.id + file.ext)
role: 'model', } as TextPart)
parts: [{ text: await window.api.file.read(file.id + file.ext) } as TextPart]
},
{
role,
parts: [{ text: message.content } as TextPart]
}
]
} }
} }
return [ return {
{ role,
role, parts: parts
parts: [{ text: message.content } as TextPart] }
}
]
} }
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) { public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {

View File

@ -33,49 +33,34 @@ export default class OpenAIProvider extends BaseProvider {
return true return true
} }
private async getMessageParam(message: Message): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam[]> { private async getMessageParam(message: Message): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam> {
const file = first(message.files) const parts: ChatCompletionContentPart[] = [
{
type: 'text',
text: message.content
}
]
const content: string | ChatCompletionContentPart[] = message.content for (const file of message.files || []) {
if (file) {
if (file.type === FileTypes.IMAGE) { if (file.type === FileTypes.IMAGE) {
const image = await window.api.file.base64Image(file.id + file.ext) const image = await window.api.file.base64Image(file.id + file.ext)
return [ parts.push({
{ type: 'image_url',
role: message.role, image_url: { url: image.data }
content: [ })
{ type: 'text', text: message.content },
{
type: 'image_url',
image_url: {
url: image.data
}
}
]
} as ChatCompletionMessageParam
]
} }
if (file.type === FileTypes.TEXT) { if (file.type === FileTypes.TEXT) {
return [ parts.push({
{ type: 'text',
role: 'assistant', text: await window.api.file.read(file.id + file.ext)
content: await window.api.file.read(file.id + file.ext) })
} as ChatCompletionMessageParam,
{
role: message.role,
content
} as ChatCompletionMessageParam
]
} }
} }
return [ return {
{ role: message.role,
role: message.role, content: parts
content } as ChatCompletionMessageParam
} as ChatCompletionMessageParam
]
} }
async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> { async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
@ -84,13 +69,13 @@ export default class OpenAIProvider extends BaseProvider {
const { contextCount, maxTokens } = getAssistantSettings(assistant) const { contextCount, maxTokens } = getAssistantSettings(assistant)
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
let userMessages: ChatCompletionMessageParam[] = [] const userMessages: ChatCompletionMessageParam[] = []
const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))) const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
onFilterMessages(_messages) onFilterMessages(_messages)
for (const message of _messages) { for (const message of _messages) {
userMessages = userMessages.concat(await this.getMessageParam(message)) userMessages.push(await this.getMessageParam(message))
} }
// @ts-ignore key is not typed // @ts-ignore key is not typed

View File

@ -235,3 +235,9 @@ export function getFileDirectory(filePath: string) {
const directory = parts.slice(0, -1).join('/') const directory = parts.slice(0, -1).join('/')
return directory return directory
} }
export function getFileExtension(filePath: string) {
const parts = filePath.split('.')
const extension = parts.slice(-1)[0]
return '.' + extension
}