From af2186e62935a2d8f648578092e2f60be59cdf1b Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Tue, 3 Sep 2024 19:00:24 +0800 Subject: [PATCH] refactor: provider sdk --- src/renderer/src/App.tsx | 4 +- .../src/components/EmojiPicker/index.tsx | 2 +- src/renderer/src/components/MinApp/index.tsx | 2 +- src/renderer/src/components/app/Sidebar.tsx | 23 +- .../{providers => context}/AntdProvider.tsx | 0 .../{providers => context}/ThemeProvider.tsx | 0 src/renderer/src/pages/apps/App.tsx | 2 +- src/renderer/src/pages/home/Assistants.tsx | 28 +- src/renderer/src/pages/home/Chat.tsx | 2 +- src/renderer/src/pages/home/HomePage.tsx | 45 ++- .../pages/home/Inputbar/AttachmentButton.tsx | 24 +- .../src/pages/home/Inputbar/Inputbar.tsx | 14 +- .../src/pages/home/Markdown/CodeBlock.tsx | 2 +- src/renderer/src/pages/home/Settings.tsx | 24 +- .../ProviderSettings/ProviderSetting.tsx | 2 +- src/renderer/src/providers/AiProvider.ts | 40 ++ .../src/providers/AnthropicProvider.ts | 143 +++++++ src/renderer/src/providers/BaseProvider.ts | 33 ++ src/renderer/src/providers/GeminiProvider.ts | 170 +++++++++ src/renderer/src/providers/OpenAIProvider.ts | 185 +++++++++ src/renderer/src/providers/ProviderFactory.ts | 19 + src/renderer/src/services/ProviderSDK.ts | 358 ------------------ src/renderer/src/services/api.ts | 28 +- src/renderer/src/types/index.ts | 7 +- src/renderer/src/utils/index.ts | 15 + 25 files changed, 718 insertions(+), 454 deletions(-) rename src/renderer/src/{providers => context}/AntdProvider.tsx (100%) rename src/renderer/src/{providers => context}/ThemeProvider.tsx (100%) create mode 100644 src/renderer/src/providers/AiProvider.ts create mode 100644 src/renderer/src/providers/AnthropicProvider.ts create mode 100644 src/renderer/src/providers/BaseProvider.ts create mode 100644 src/renderer/src/providers/GeminiProvider.ts create mode 100644 src/renderer/src/providers/OpenAIProvider.ts create mode 100644 src/renderer/src/providers/ProviderFactory.ts delete mode 100644 src/renderer/src/services/ProviderSDK.ts diff --git a/src/renderer/src/App.tsx b/src/renderer/src/App.tsx index ec07c95aa1..d1abfc44f2 100644 --- a/src/renderer/src/App.tsx +++ b/src/renderer/src/App.tsx @@ -5,13 +5,13 @@ import { PersistGate } from 'redux-persist/integration/react' import Sidebar from './components/app/Sidebar' import TopViewContainer from './components/TopView' +import AntdProvider from './context/AntdProvider' +import { ThemeProvider } from './context/ThemeProvider' import AgentsPage from './pages/agents/AgentsPage' import AppsPage from './pages/apps/AppsPage' import HomePage from './pages/home/HomePage' import SettingsPage from './pages/settings/SettingsPage' import TranslatePage from './pages/translate/TranslatePage' -import AntdProvider from './providers/AntdProvider' -import { ThemeProvider } from './providers/ThemeProvider' function App(): JSX.Element { return ( diff --git a/src/renderer/src/components/EmojiPicker/index.tsx b/src/renderer/src/components/EmojiPicker/index.tsx index 1668b884e0..c9345c7253 100644 --- a/src/renderer/src/components/EmojiPicker/index.tsx +++ b/src/renderer/src/components/EmojiPicker/index.tsx @@ -1,4 +1,4 @@ -import { useTheme } from '@renderer/providers/ThemeProvider' +import { useTheme } from '@renderer/context/ThemeProvider' import { FC, useEffect, useRef } from 'react' interface Props { diff --git a/src/renderer/src/components/MinApp/index.tsx b/src/renderer/src/components/MinApp/index.tsx index a3b4e4dc4e..0564d62cf9 100644 --- a/src/renderer/src/components/MinApp/index.tsx +++ b/src/renderer/src/components/MinApp/index.tsx @@ -95,7 +95,7 @@ const PopupContainer: React.FC = ({ app, resolve }) => { maskClosable={false} closeIcon={null} style={{ marginLeft: 'var(--sidebar-width)' }}> - + ) } diff --git a/src/renderer/src/components/app/Sidebar.tsx b/src/renderer/src/components/app/Sidebar.tsx index ead585a2da..6d3fdd68fa 100644 --- a/src/renderer/src/components/app/Sidebar.tsx +++ b/src/renderer/src/components/app/Sidebar.tsx @@ -6,7 +6,7 @@ import { useRuntime, useShowAssistants } from '@renderer/hooks/useStore' import { Avatar } from 'antd' import { FC } from 'react' import { useTranslation } from 'react-i18next' -import { Link, useLocation } from 'react-router-dom' +import { useLocation, useNavigate } from 'react-router-dom' import styled from 'styled-components' import UserPopup from '../Popups/UserPopup' @@ -20,6 +20,7 @@ const Sidebar: FC = () => { const { toggleShowAssistants } = useShowAssistants() const { generating } = useRuntime() const { t } = useTranslation() + const navigate = useNavigate() const isRoute = (path: string): string => (pathname === path ? 'active' : '') @@ -28,15 +29,13 @@ const Sidebar: FC = () => { const to = (path: string) => { if (generating) { window.message.warning({ content: t('message.switch.disabled'), key: 'switch-assistant' }) - return '/' + return } - return path + navigate(path) } const onToggleShowAssistants = () => { - if (pathname === '/') { - toggleShowAssistants() - } + pathname === '/' ? toggleShowAssistants() : navigate('/') } return ( @@ -44,22 +43,22 @@ const Sidebar: FC = () => { - + - + to('/agents')}> - + to('/translate')}> - + to('/apps')}> @@ -67,7 +66,7 @@ const Sidebar: FC = () => { - + to(isLocalAi ? '/settings/assistant' : '/settings/provider')}> @@ -149,7 +148,7 @@ const Icon = styled.div` } ` -const StyledLink = styled(Link)` +const StyledLink = styled.div` text-decoration: none; -webkit-app-region: none; &* { diff --git a/src/renderer/src/providers/AntdProvider.tsx b/src/renderer/src/context/AntdProvider.tsx similarity index 100% rename from src/renderer/src/providers/AntdProvider.tsx rename to src/renderer/src/context/AntdProvider.tsx diff --git a/src/renderer/src/providers/ThemeProvider.tsx b/src/renderer/src/context/ThemeProvider.tsx similarity index 100% rename from src/renderer/src/providers/ThemeProvider.tsx rename to src/renderer/src/context/ThemeProvider.tsx diff --git a/src/renderer/src/pages/apps/App.tsx b/src/renderer/src/pages/apps/App.tsx index 1df2d1e390..0ef3292a4d 100644 --- a/src/renderer/src/pages/apps/App.tsx +++ b/src/renderer/src/pages/apps/App.tsx @@ -1,5 +1,5 @@ import MinApp from '@renderer/components/MinApp' -import { useTheme } from '@renderer/providers/ThemeProvider' +import { useTheme } from '@renderer/context/ThemeProvider' import { MinAppType } from '@renderer/types' import { FC } from 'react' import styled from 'styled-components' diff --git a/src/renderer/src/pages/home/Assistants.tsx b/src/renderer/src/pages/home/Assistants.tsx index b238d81986..18bf879032 100644 --- a/src/renderer/src/pages/home/Assistants.tsx +++ b/src/renderer/src/pages/home/Assistants.tsx @@ -1,5 +1,4 @@ import { ArrowRightOutlined, CopyOutlined, DeleteOutlined, EditOutlined } from '@ant-design/icons' -import { ArrowLeftOutlined } from '@ant-design/icons' import DragableList from '@renderer/components/DragableList' import { HStack } from '@renderer/components/Layout' import AssistantSettingPopup from '@renderer/components/Popups/AssistantSettingPopup' @@ -104,10 +103,6 @@ const Assistants: FC = ({ if (showTopics) { return ( - setShowTopics(false)}> - - {t('common.back')} - ) @@ -142,6 +137,7 @@ const Container = styled.div` height: calc(100vh - var(--navbar-height)); overflow-y: auto; padding: 10px 0; + padding-bottom: 0; ` const AssistantItem = styled.div` @@ -155,40 +151,24 @@ const AssistantItem = styled.div` cursor: pointer; font-family: Ubuntu; .anticon { - display: none; + opacity: 0; color: var(--color-text-3); + transition: opacity 0.2s ease-in-out; } &:hover { background-color: var(--color-background-soft); - .count { - display: none; - } .anticon { - display: block; + opacity: 1; } } &.active { background-color: var(--color-background-mute); - cursor: pointer; .name { font-weight: 500; } } ` -const NavigtaionHeader = styled.div` - display: flex; - flex-direction: row; - align-items: center; - justify-content: flex-start; - gap: 10px; - padding: 0 5px; - cursor: pointer; - color: var(--color-text-3); - margin: 10px; - margin-top: 0; -` - const AssistantName = styled.div` color: var(--color-text); display: -webkit-box; diff --git a/src/renderer/src/pages/home/Chat.tsx b/src/renderer/src/pages/home/Chat.tsx index b2d5f76188..d314545100 100644 --- a/src/renderer/src/pages/home/Chat.tsx +++ b/src/renderer/src/pages/home/Chat.tsx @@ -29,7 +29,7 @@ const Chat: FC = (props) => { setShowSetting={setShowSetting} /> - {showSetting && } + {showSetting && setShowSetting(false)} />} ) } diff --git a/src/renderer/src/pages/home/HomePage.tsx b/src/renderer/src/pages/home/HomePage.tsx index f7ffb8cf03..6ca6c90f25 100644 --- a/src/renderer/src/pages/home/HomePage.tsx +++ b/src/renderer/src/pages/home/HomePage.tsx @@ -1,9 +1,11 @@ +import { ArrowLeftOutlined } from '@ant-design/icons' import { Navbar, NavbarCenter, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar' import { isMac, isWindows } from '@renderer/config/constant' -import { useAssistants, useDefaultAssistant } from '@renderer/hooks/useAssistant' +import { useTheme } from '@renderer/context/ThemeProvider' +import { useAssistant, useAssistants, useDefaultAssistant } from '@renderer/hooks/useAssistant' import { useShowAssistants } from '@renderer/hooks/useStore' import { useActiveTopic } from '@renderer/hooks/useTopic' -import { useTheme } from '@renderer/providers/ThemeProvider' +import { getDefaultTopic } from '@renderer/services/assistant' import { Assistant, Topic } from '@renderer/types' import { uuid } from '@renderer/utils' import { Switch } from 'antd' @@ -29,6 +31,7 @@ const HomePage: FC = () => { const { t } = useTranslation() const { activeTopic, setActiveTopic } = useActiveTopic(activeAssistant) + const { addTopic } = useAssistant(activeAssistant.id) _activeAssistant = activeAssistant _showTopics = showTopics @@ -39,9 +42,15 @@ const HomePage: FC = () => { setActiveAssistant(assistant) } - const onCreateAssistant = async () => { - const assistant = await AddAssistantPopup.show() - assistant && setActiveAssistant(assistant) + const onCreate = async () => { + if (showTopics) { + const topic = getDefaultTopic() + addTopic(topic) + setActiveTopic(topic) + } else { + const assistant = await AddAssistantPopup.show() + assistant && setActiveAssistant(assistant) + } } const onSetActiveTopic = (topic: Topic) => { @@ -53,8 +62,13 @@ const HomePage: FC = () => { {showAssistants && ( - - + + setShowTopics(false)} style={{ opacity: showTopics ? 1 : 0 }}> + + {t('common.back')} + + @@ -103,6 +117,23 @@ const ContentContainer = styled.div` background-color: var(--color-background); ` +const NavigtaionBack = styled.div` + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 10px; + cursor: pointer; + margin-left: ${isMac ? '16px' : 0}; + -webkit-app-region: none; + transition: all 0.2s ease-in-out; + color: var(--color-icon); + transition: opacity 0.2s ease-in-out; + &:hover { + color: var(--color-text); + } +` + const AssistantName = styled.span` margin-left: 5px; margin-right: 10px; diff --git a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx index db805eabf4..33c3f195ea 100644 --- a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx @@ -4,12 +4,12 @@ import { FC } from 'react' import { useTranslation } from 'react-i18next' interface Props { - images: string[] - setImages: (images: string[]) => void + files: File[] + setFiles: (files: File[]) => void ToolbarButton: any } -const AttachmentButton: FC = ({ images, setImages, ToolbarButton }) => { +const AttachmentButton: FC = ({ files, setFiles, ToolbarButton }) => { const { t } = useTranslation() return ( @@ -19,22 +19,8 @@ const AttachmentButton: FC = ({ images, setImages, ToolbarButton }) => { accept="image/*" itemRender={() => null} maxCount={1} - onChange={async ({ file }) => { - try { - const _file = file.originFileObj as File - const reader = new FileReader() - reader.onload = (e: ProgressEvent) => { - const result = e.target?.result - if (typeof result === 'string') { - setImages([result]) - } - } - reader.readAsDataURL(_file) - } catch (error: any) { - window.message.error(error.message) - } - }}> - + onChange={async ({ file }) => file?.originFileObj && setFiles([file.originFileObj as File])}> + diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index 103f68993d..57cfefb6ce 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -25,7 +25,6 @@ import { CSSProperties, FC, useCallback, useEffect, useMemo, useRef, useState } import { useTranslation } from 'react-i18next' import styled from 'styled-components' -import AttachmentButton from './AttachmentButton' import SendMessageButton from './SendMessageButton' interface Props { @@ -46,7 +45,7 @@ const Inputbar: FC = ({ assistant, setActiveTopic, showSetting, setShowSe const [estimateTokenCount, setEstimateTokenCount] = useState(0) const generating = useAppSelector((state) => state.runtime.generating) const textareaRef = useRef(null) - const [images, setImages] = useState([]) + const [files, setFiles] = useState([]) const { t } = useTranslation() const containerRef = useRef(null) @@ -71,18 +70,19 @@ const Inputbar: FC = ({ assistant, setActiveTopic, showSetting, setShowSe status: 'success' } - if (images.length > 0) { - message.images = images + if (files.length > 0) { + message.files = files } EventEmitter.emit(EVENT_NAMES.SEND_MESSAGE, message) setText('') - setImages([]) + setFiles([]) setTimeout(() => setText(''), 500) + setTimeout(() => resizeTextArea(), 0) setExpend(false) - }, [assistant.id, assistant.topics, generating, images, text]) + }, [assistant.id, assistant.topics, generating, files, text]) const inputTokenCount = useMemo(() => estimateInputTokenCount(text), [text]) @@ -226,7 +226,7 @@ const Inputbar: FC = ({ assistant, setActiveTopic, showSetting, setShowSe - + {/* */} {expended ? : } diff --git a/src/renderer/src/pages/home/Markdown/CodeBlock.tsx b/src/renderer/src/pages/home/Markdown/CodeBlock.tsx index b3bac4fdaf..ecf4d698f1 100644 --- a/src/renderer/src/pages/home/Markdown/CodeBlock.tsx +++ b/src/renderer/src/pages/home/Markdown/CodeBlock.tsx @@ -1,7 +1,7 @@ import { CheckOutlined } from '@ant-design/icons' import CopyIcon from '@renderer/components/Icons/CopyIcon' +import { useTheme } from '@renderer/context/ThemeProvider' import { initMermaid } from '@renderer/init' -import { useTheme } from '@renderer/providers/ThemeProvider' import { ThemeMode } from '@renderer/store/settings' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' diff --git a/src/renderer/src/pages/home/Settings.tsx b/src/renderer/src/pages/home/Settings.tsx index 5338b6758b..a5baaf73ff 100644 --- a/src/renderer/src/pages/home/Settings.tsx +++ b/src/renderer/src/pages/home/Settings.tsx @@ -1,4 +1,4 @@ -import { CheckOutlined, QuestionCircleOutlined, ReloadOutlined } from '@ant-design/icons' +import { CheckOutlined, CloseOutlined, QuestionCircleOutlined, ReloadOutlined } from '@ant-design/icons' import { HStack } from '@renderer/components/Layout' import { DEFAULT_CONEXTCOUNT, DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE } from '@renderer/config/constant' import { useAssistant } from '@renderer/hooks/useAssistant' @@ -19,6 +19,7 @@ import styled from 'styled-components' interface Props { assistant: Assistant + onClose: () => void } const SettingsTab: FC = (props) => { @@ -87,6 +88,10 @@ const SettingsTab: FC = (props) => { return ( + + {t('settings.title')} + + {t('settings.messages.model.title')}{' '} @@ -259,4 +264,21 @@ const SettingRowTitleSmall = styled(SettingRowTitle)` font-size: 13px; ` +const SettingsHeader = styled.div` + display: flex; + flex-direction: row; + align-items: center; + justify-content: space-between; + padding: 10px 15px; + border-bottom: 0.5px solid var(--color-border); + margin-left: -15px; + margin-right: -15px; +` + +const CloseIcon = styled(CloseOutlined)` + font-size: 14px; + cursor: pointer; + color: var(--color-text-3); +` + export default SettingsTab diff --git a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx index c4ea8bd991..92c3e2898b 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx @@ -8,8 +8,8 @@ import { } from '@ant-design/icons' import { getModelLogo } from '@renderer/config/provider' import { PROVIDER_CONFIG } from '@renderer/config/provider' +import { useTheme } from '@renderer/context/ThemeProvider' import { useProvider } from '@renderer/hooks/useProvider' -import { useTheme } from '@renderer/providers/ThemeProvider' import { checkApi } from '@renderer/services/api' import { Provider } from '@renderer/types' import { Avatar, Button, Card, Divider, Flex, Input, Space, Switch } from 'antd' diff --git a/src/renderer/src/providers/AiProvider.ts b/src/renderer/src/providers/AiProvider.ts new file mode 100644 index 0000000000..b986e21595 --- /dev/null +++ b/src/renderer/src/providers/AiProvider.ts @@ -0,0 +1,40 @@ +import BaseProvider from '@renderer/providers/BaseProvider' +import ProviderFactory from '@renderer/providers/ProviderFactory' +import { Assistant, Message, Provider, Suggestion } from '@renderer/types' +import OpenAI from 'openai' + +export default class AiProvider { + private sdk: BaseProvider + + constructor(provider: Provider) { + this.sdk = ProviderFactory.create(provider) + } + + public async completions( + messages: Message[], + assistant: Assistant, + onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void + ): Promise { + return this.sdk.completions(messages, assistant, onChunk) + } + + public async translate(message: Message, assistant: Assistant): Promise { + return this.sdk.translate(message, assistant) + } + + public async summaries(messages: Message[], assistant: Assistant): Promise { + return this.sdk.summaries(messages, assistant) + } + + public async suggestions(messages: Message[], assistant: Assistant): Promise { + return this.sdk.suggestions(messages, assistant) + } + + public async check(): Promise<{ valid: boolean; error: Error | null }> { + return this.sdk.check() + } + + public async models(): Promise { + return this.sdk.models() + } +} diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts new file mode 100644 index 0000000000..8af370d892 --- /dev/null +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -0,0 +1,143 @@ +import Anthropic from '@anthropic-ai/sdk' +import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources' +import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' +import { EVENT_NAMES } from '@renderer/services/event' +import { Assistant, Message, Provider, Suggestion } from '@renderer/types' +import { sum, takeRight } from 'lodash' +import OpenAI from 'openai' + +import BaseProvider from './BaseProvider' + +export default class AnthropicProvider extends BaseProvider { + private sdk: Anthropic + + constructor(provider: Provider) { + super(provider) + this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() }) + } + + public async completions( + messages: Message[], + assistant: Assistant, + onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void + ) { + const defaultModel = getDefaultModel() + const model = assistant.model || defaultModel + const { contextCount, maxTokens } = getAssistantSettings(assistant) + + const userMessages = takeRight(messages, contextCount + 1).map((message) => { + return { + role: message.role, + content: message.content + } + }) + + return new Promise((resolve, reject) => { + const stream = this.sdk.messages + .stream({ + model: model.id, + messages: userMessages.filter(Boolean) as MessageParam[], + max_tokens: maxTokens || DEFAULT_MAX_TOKENS, + temperature: assistant?.settings?.temperature, + system: assistant.prompt, + stream: true + }) + .on('text', (text) => { + if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { + resolve() + return stream.controller.abort() + } + onChunk({ text }) + }) + .on('finalMessage', (message) => { + onChunk({ + text: '', + usage: { + prompt_tokens: message.usage.input_tokens, + completion_tokens: message.usage.output_tokens, + total_tokens: sum(Object.values(message.usage)) + } + }) + resolve() + }) + .on('error', (error) => reject(error)) + }) + } + + public async translate(message: Message, assistant: Assistant) { + const defaultModel = getDefaultModel() + const model = assistant.model || defaultModel + const messages = [ + { role: 'system', content: assistant.prompt }, + { role: 'user', content: message.content } + ] + + const response = await this.sdk.messages.create({ + model: model.id, + messages: messages.filter((m) => m.role === 'user') as MessageParam[], + max_tokens: 4096, + temperature: assistant?.settings?.temperature, + system: assistant.prompt, + stream: false + }) + + return response.content[0].type === 'text' ? response.content[0].text : '' + } + + public async summaries(messages: Message[], assistant: Assistant): Promise { + const model = getTopNamingModel() || assistant.model || getDefaultModel() + + const userMessages = takeRight(messages, 5).map((message) => ({ + role: message.role, + content: message.content + })) + + const systemMessage = { + role: 'system', + content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。' + } + + const message = await this.sdk.messages.create({ + messages: userMessages as Anthropic.Messages.MessageParam[], + model: model.id, + system: systemMessage.content, + stream: false, + max_tokens: 4096 + }) + + return message.content[0].type === 'text' ? message.content[0].text : null + } + + public async suggestions(): Promise { + return [] + } + + public async check(): Promise<{ valid: boolean; error: Error | null }> { + const model = this.provider.models[0] + + const body = { + model: model.id, + messages: [{ role: 'user', content: 'hi' }], + max_tokens: 100, + stream: false + } + + try { + const message = await this.sdk.messages.create(body as MessageCreateParamsNonStreaming) + return { + valid: message.content.length > 0, + error: null + } + } catch (error: any) { + return { + valid: false, + error + } + } + } + + public async models(): Promise { + return [] + } +} diff --git a/src/renderer/src/providers/BaseProvider.ts b/src/renderer/src/providers/BaseProvider.ts new file mode 100644 index 0000000000..c1d83de2d3 --- /dev/null +++ b/src/renderer/src/providers/BaseProvider.ts @@ -0,0 +1,33 @@ +import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama' +import { Assistant, Message, Provider, Suggestion } from '@renderer/types' +import OpenAI from 'openai' + +export default abstract class BaseProvider { + protected provider: Provider + protected host: string + + constructor(provider: Provider) { + this.provider = provider + this.host = this.getBaseURL() + } + + public getBaseURL(): string { + const host = this.provider.apiHost + return host.endsWith('/') ? host : `${host}/v1/` + } + + public get keepAliveTime() { + return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined + } + + abstract completions( + messages: Message[], + assistant: Assistant, + onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void + ): Promise + abstract translate(message: Message, assistant: Assistant): Promise + abstract summaries(messages: Message[], assistant: Assistant): Promise + abstract suggestions(messages: Message[], assistant: Assistant): Promise + abstract check(): Promise<{ valid: boolean; error: Error | null }> + abstract models(): Promise +} diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts new file mode 100644 index 0000000000..1fbebbd384 --- /dev/null +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -0,0 +1,170 @@ +import { GoogleGenerativeAI } from '@google/generative-ai' +import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' +import { EVENT_NAMES } from '@renderer/services/event' +import { Assistant, Message, Provider, Suggestion } from '@renderer/types' +import axios from 'axios' +import { isEmpty, takeRight } from 'lodash' +import OpenAI from 'openai' + +import BaseProvider from './BaseProvider' + +export default class GeminiProvider extends BaseProvider { + private sdk: GoogleGenerativeAI + + constructor(provider: Provider) { + super(provider) + this.sdk = new GoogleGenerativeAI(provider.apiKey) + } + + public async completions( + messages: Message[], + assistant: Assistant, + onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void + ) { + const defaultModel = getDefaultModel() + const model = assistant.model || defaultModel + const { contextCount, maxTokens } = getAssistantSettings(assistant) + + const userMessages = takeRight(messages, contextCount + 1).map((message) => { + return { + role: message.role, + content: message.content + } + }) + + const geminiModel = this.sdk.getGenerativeModel({ + model: model.id, + systemInstruction: assistant.prompt, + generationConfig: { + maxOutputTokens: maxTokens, + temperature: assistant?.settings?.temperature + } + }) + + const userLastMessage = userMessages.pop() + + const chat = geminiModel.startChat({ + history: userMessages.map((message) => ({ + role: message.role === 'user' ? 'user' : 'model', + parts: [{ text: message.content }] + })) + }) + + const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!) + + for await (const chunk of userMessagesStream.stream) { + if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break + onChunk({ + text: chunk.text(), + usage: { + prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, + completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0, + total_tokens: chunk.usageMetadata?.totalTokenCount || 0 + } + }) + } + } + + async translate(message: Message, assistant: Assistant) { + const defaultModel = getDefaultModel() + const { maxTokens } = getAssistantSettings(assistant) + const model = assistant.model || defaultModel + + const geminiModel = this.sdk.getGenerativeModel({ + model: model.id, + systemInstruction: assistant.prompt, + generationConfig: { + maxOutputTokens: maxTokens, + temperature: assistant?.settings?.temperature + } + }) + + const { response } = await geminiModel.generateContent(message.content) + + return response.text() + } + + public async summaries(messages: Message[], assistant: Assistant): Promise { + const model = getTopNamingModel() || assistant.model || getDefaultModel() + + const userMessages = takeRight(messages, 5).map((message) => ({ + role: message.role, + content: message.content + })) + + const systemMessage = { + role: 'system', + content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。' + } + + const geminiModel = this.sdk.getGenerativeModel({ + model: model.id, + systemInstruction: systemMessage.content, + generationConfig: { + temperature: assistant?.settings?.temperature + } + }) + + const lastUserMessage = userMessages.pop() + + const chat = await geminiModel.startChat({ + history: userMessages.map((message) => ({ + role: message.role === 'user' ? 'user' : 'model', + parts: [{ text: message.content }] + })) + }) + + const { response } = await chat.sendMessage(lastUserMessage?.content!) + + return response.text() + } + + public async suggestions(): Promise { + return [] + } + + public async check(): Promise<{ valid: boolean; error: Error | null }> { + const model = this.provider.models[0] + + const body = { + model: model.id, + messages: [{ role: 'user', content: 'hi' }], + max_tokens: 100, + stream: false + } + + try { + const geminiModel = this.sdk.getGenerativeModel({ model: body.model }) + const result = await geminiModel.generateContent(body.messages[0].content) + return { + valid: !isEmpty(result.response.text()), + error: null + } + } catch (error: any) { + return { + valid: false, + error + } + } + } + + public async models(): Promise { + try { + const api = this.provider.apiHost + '/v1beta/models' + const { data } = await axios.get(api, { params: { key: this.provider.apiKey } }) + return data.models.map( + (m: any) => + ({ + id: m.name.replace('models/', ''), + name: m.displayName, + description: m.description, + object: 'model', + created: Date.now(), + owned_by: 'gemini' + }) as OpenAI.Models.Model + ) + } catch (error) { + return [] + } + } +} diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts new file mode 100644 index 0000000000..b9f7180961 --- /dev/null +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -0,0 +1,185 @@ +import { isLocalAi } from '@renderer/config/env' +import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' +import { EVENT_NAMES } from '@renderer/services/event' +import { Assistant, Message, Provider, Suggestion } from '@renderer/types' +import { fileToBase64, removeQuotes } from '@renderer/utils' +import { first, takeRight } from 'lodash' +import OpenAI from 'openai' +import { + ChatCompletionContentPart, + ChatCompletionCreateParamsNonStreaming, + ChatCompletionMessageParam +} from 'openai/resources' + +import BaseProvider from './BaseProvider' + +export default class OpenAIProvider extends BaseProvider { + private sdk: OpenAI + + constructor(provider: Provider) { + super(provider) + this.sdk = new OpenAI({ + dangerouslyAllowBrowser: true, + apiKey: provider.apiKey, + baseURL: this.getBaseURL() + }) + } + + private async getMessageContent(message: Message): Promise { + const file = first(message.files) + + if (!file) { + return message.content + } + + if (file.type.includes('image')) { + return [ + { type: 'text', text: message.content }, + { + type: 'image_url', + image_url: { + url: await fileToBase64(file) + } + } + ] + } + + return message.content + } + + async completions( + messages: Message[], + assistant: Assistant, + onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void + ): Promise { + const defaultModel = getDefaultModel() + const model = assistant.model || defaultModel + const { contextCount, maxTokens } = getAssistantSettings(assistant) + + const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined + + const userMessages: ChatCompletionMessageParam[] = [] + + for (const message of takeRight(messages, contextCount + 1)) { + userMessages.push({ + role: message.role, + content: await this.getMessageContent(message) + } as ChatCompletionMessageParam) + } + + // @ts-ignore key is not typed + const stream = await this.sdk.chat.completions.create({ + model: model.id, + messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[], + stream: true, + temperature: assistant?.settings?.temperature, + max_tokens: maxTokens, + keep_alive: this.keepAliveTime + }) + + for await (const chunk of stream) { + if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break + onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage }) + } + } + + async translate(message: Message, assistant: Assistant) { + const defaultModel = getDefaultModel() + const model = assistant.model || defaultModel + const messages = [ + { role: 'system', content: assistant.prompt }, + { role: 'user', content: message.content } + ] + + // @ts-ignore key is not typed + const response = await this.sdk.chat.completions.create({ + model: model.id, + messages: messages as ChatCompletionMessageParam[], + stream: false, + keep_alive: this.keepAliveTime + }) + + return response.choices[0].message?.content || '' + } + + public async summaries(messages: Message[], assistant: Assistant): Promise { + const model = getTopNamingModel() || assistant.model || getDefaultModel() + + const userMessages = takeRight(messages, 5).map((message) => ({ + role: message.role, + content: message.content + })) + + const systemMessage = { + role: 'system', + content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。' + } + + // @ts-ignore key is not typed + const response = await this.sdk.chat.completions.create({ + model: model.id, + messages: [systemMessage, ...(isLocalAi ? [first(userMessages)] : userMessages)] as ChatCompletionMessageParam[], + stream: false, + max_tokens: 50, + keep_alive: this.keepAliveTime + }) + + return removeQuotes(response.choices[0].message?.content?.substring(0, 50) || '') + } + + async suggestions(messages: Message[], assistant: Assistant): Promise { + const model = assistant.model + + if (!model) { + return [] + } + + const response: any = await this.sdk.request({ + method: 'post', + path: '/advice_questions', + body: { + messages: messages.filter((m) => m.role === 'user').map((m) => ({ role: m.role, content: m.content })), + model: model.id, + max_tokens: 0, + temperature: 0, + n: 0 + } + }) + + return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || [] + } + + public async check(): Promise<{ valid: boolean; error: Error | null }> { + const model = this.provider.models[0] + + const body = { + model: model.id, + messages: [{ role: 'user', content: 'hi' }], + max_tokens: 100, + stream: false + } + + try { + const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming) + + return { + valid: Boolean(response?.choices[0].message), + error: null + } + } catch (error: any) { + return { + valid: false, + error + } + } + } + + public async models(): Promise { + try { + const response = await this.sdk.models.list() + return response.data + } catch (error) { + return [] + } + } +} diff --git a/src/renderer/src/providers/ProviderFactory.ts b/src/renderer/src/providers/ProviderFactory.ts new file mode 100644 index 0000000000..53e50fdb9f --- /dev/null +++ b/src/renderer/src/providers/ProviderFactory.ts @@ -0,0 +1,19 @@ +import { Provider } from '@renderer/types' + +import AnthropicProvider from './AnthropicProvider' +import BaseProvider from './BaseProvider' +import GeminiProvider from './GeminiProvider' +import OpenAIProvider from './OpenAIProvider' + +export default class ProviderFactory { + static create(provider: Provider): BaseProvider { + switch (provider.id) { + case 'anthropic': + return new AnthropicProvider(provider) + case 'gemini': + return new GeminiProvider(provider) + default: + return new OpenAIProvider(provider) + } + } +} diff --git a/src/renderer/src/services/ProviderSDK.ts b/src/renderer/src/services/ProviderSDK.ts deleted file mode 100644 index 1916ac722f..0000000000 --- a/src/renderer/src/services/ProviderSDK.ts +++ /dev/null @@ -1,358 +0,0 @@ -import Anthropic from '@anthropic-ai/sdk' -import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources' -import { GoogleGenerativeAI } from '@google/generative-ai' -import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' -import { isLocalAi } from '@renderer/config/env' -import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama' -import { Assistant, Message, Provider, Suggestion } from '@renderer/types' -import { removeQuotes } from '@renderer/utils' -import axios from 'axios' -import { first, isEmpty, sum, takeRight } from 'lodash' -import OpenAI from 'openai' -import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources' - -import { getAssistantSettings, getDefaultModel, getTopNamingModel } from './assistant' -import { EVENT_NAMES } from './event' - -export default class ProviderSDK { - provider: Provider - openaiSdk: OpenAI - anthropicSdk: Anthropic - geminiSdk: GoogleGenerativeAI - - constructor(provider: Provider) { - this.provider = provider - const host = provider.apiHost - const baseURL = host.endsWith('/') ? host : `${provider.apiHost}/v1/` - this.anthropicSdk = new Anthropic({ apiKey: provider.apiKey, baseURL }) - this.openaiSdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL }) - this.geminiSdk = new GoogleGenerativeAI(provider.apiKey) - } - - private get isAnthropic() { - return this.provider.id === 'anthropic' - } - - private get isGemini() { - return this.provider.id === 'gemini' - } - - private get keepAliveTime() { - return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined - } - - public async completions( - messages: Message[], - assistant: Assistant, - onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void - ) { - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - const { contextCount, maxTokens } = getAssistantSettings(assistant) - - const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined - const userMessages = takeRight(messages, contextCount + 1).map((message) => { - return { - role: message.role, - content: message.content - } - }) - - if (this.isAnthropic) { - return new Promise((resolve, reject) => { - const stream = this.anthropicSdk.messages - .stream({ - model: model.id, - messages: userMessages.filter(Boolean) as MessageParam[], - max_tokens: maxTokens || DEFAULT_MAX_TOKENS, - temperature: assistant?.settings?.temperature, - system: assistant.prompt, - stream: true - }) - .on('text', (text) => { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { - resolve() - return stream.controller.abort() - } - onChunk({ text }) - }) - .on('finalMessage', (message) => { - onChunk({ - text: '', - usage: { - prompt_tokens: message.usage.input_tokens, - completion_tokens: message.usage.output_tokens, - total_tokens: sum(Object.values(message.usage)) - } - }) - resolve() - }) - .on('error', (error) => reject(error)) - }) - } - - if (this.isGemini) { - const geminiModel = this.geminiSdk.getGenerativeModel({ - model: model.id, - systemInstruction: assistant.prompt, - generationConfig: { - maxOutputTokens: maxTokens, - temperature: assistant?.settings?.temperature - } - }) - - const userLastMessage = userMessages.pop() - - const chat = geminiModel.startChat({ - history: userMessages.map((message) => ({ - role: message.role === 'user' ? 'user' : 'model', - parts: [{ text: message.content }] - })) - }) - - const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!) - - for await (const chunk of userMessagesStream.stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break - onChunk({ - text: chunk.text(), - usage: { - prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, - completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0, - total_tokens: chunk.usageMetadata?.totalTokenCount || 0 - } - }) - } - - return - } - - const _userMessages = takeRight(messages, contextCount + 1).map((message) => { - return { - role: message.role, - content: message.images - ? [ - { type: 'text', text: message.content }, - ...message.images!.map((image) => ({ type: 'image_url', image_url: image })) - ] - : message.content - } - }) - - // @ts-ignore key is not typed - const stream = await this.openaiSdk.chat.completions.create({ - model: model.id, - messages: [systemMessage, ..._userMessages].filter(Boolean) as ChatCompletionMessageParam[], - stream: true, - temperature: assistant?.settings?.temperature, - max_tokens: maxTokens, - keep_alive: this.keepAliveTime - }) - - for await (const chunk of stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break - onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage }) - } - } - - public async translate(message: Message, assistant: Assistant) { - const defaultModel = getDefaultModel() - const { maxTokens } = getAssistantSettings(assistant) - const model = assistant.model || defaultModel - const messages = [ - { role: 'system', content: assistant.prompt }, - { role: 'user', content: message.content } - ] - - if (this.isAnthropic) { - const response = await this.anthropicSdk.messages.create({ - model: model.id, - messages: messages.filter((m) => m.role === 'user') as MessageParam[], - max_tokens: 4096, - temperature: assistant?.settings?.temperature, - system: assistant.prompt, - stream: false - }) - - return response.content[0].type === 'text' ? response.content[0].text : '' - } - - if (this.isGemini) { - const geminiModel = this.geminiSdk.getGenerativeModel({ - model: model.id, - systemInstruction: assistant.prompt, - generationConfig: { - maxOutputTokens: maxTokens, - temperature: assistant?.settings?.temperature - } - }) - - const { response } = await geminiModel.generateContent(message.content) - - return response.text() - } - - // @ts-ignore key is not typed - const response = await this.openaiSdk.chat.completions.create({ - model: model.id, - messages: messages as ChatCompletionMessageParam[], - stream: false, - keep_alive: this.keepAliveTime - }) - - return response.choices[0].message?.content || '' - } - - public async summaries(messages: Message[], assistant: Assistant): Promise { - const model = getTopNamingModel() || assistant.model || getDefaultModel() - - const userMessages = takeRight(messages, 5).map((message) => ({ - role: message.role, - content: message.content - })) - - const systemMessage = { - role: 'system', - content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。' - } - - if (this.isAnthropic) { - const message = await this.anthropicSdk.messages.create({ - messages: userMessages as Anthropic.Messages.MessageParam[], - model: model.id, - system: systemMessage.content, - stream: false, - max_tokens: 4096 - }) - - return message.content[0].type === 'text' ? message.content[0].text : null - } - - if (this.isGemini) { - const geminiModel = this.geminiSdk.getGenerativeModel({ - model: model.id, - systemInstruction: systemMessage.content, - generationConfig: { - temperature: assistant?.settings?.temperature - } - }) - - const lastUserMessage = userMessages.pop() - - const chat = await geminiModel.startChat({ - history: userMessages.map((message) => ({ - role: message.role === 'user' ? 'user' : 'model', - parts: [{ text: message.content }] - })) - }) - - const { response } = await chat.sendMessage(lastUserMessage?.content!) - - return response.text() - } - - // @ts-ignore key is not typed - const response = await this.openaiSdk.chat.completions.create({ - model: model.id, - messages: [systemMessage, ...(isLocalAi ? [first(userMessages)] : userMessages)] as ChatCompletionMessageParam[], - stream: false, - max_tokens: 50, - keep_alive: this.keepAliveTime - }) - - return removeQuotes(response.choices[0].message?.content?.substring(0, 50) || '') - } - - public async suggestions(messages: Message[], assistant: Assistant): Promise { - const model = assistant.model - - if (!model) { - return [] - } - - const response: any = await this.openaiSdk.request({ - method: 'post', - path: '/advice_questions', - body: { - messages: messages.filter((m) => m.role === 'user').map((m) => ({ role: m.role, content: m.content })), - model: model.id, - max_tokens: 0, - temperature: 0, - n: 0 - } - }) - - return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || [] - } - - public async check(): Promise<{ valid: boolean; error: Error | null }> { - const model = this.provider.models[0] - - const body = { - model: model.id, - messages: [{ role: 'user', content: 'hi' }], - max_tokens: 100, - stream: false - } - - try { - if (this.isAnthropic) { - const message = await this.anthropicSdk.messages.create(body as MessageCreateParamsNonStreaming) - return { - valid: message.content.length > 0, - error: null - } - } - - if (this.isGemini) { - const geminiModel = this.geminiSdk.getGenerativeModel({ model: body.model }) - const result = await geminiModel.generateContent(body.messages[0].content) - return { - valid: !isEmpty(result.response.text()), - error: null - } - } - - const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming) - - return { - valid: Boolean(response?.choices[0].message), - error: null - } - } catch (error: any) { - return { - valid: false, - error - } - } - } - - public async models(): Promise { - try { - if (this.isAnthropic) { - return [] - } - - if (this.isGemini) { - const api = this.provider.apiHost + '/v1beta/models' - const { data } = await axios.get(api, { params: { key: this.provider.apiKey } }) - return data.models.map( - (m: any) => - ({ - id: m.name.replace('models/', ''), - name: m.displayName, - description: m.description, - object: 'model', - created: Date.now(), - owned_by: 'gemini' - }) as OpenAI.Models.Model - ) - } - - const response = await this.openaiSdk.models.list() - return response.data - } catch (error) { - return [] - } - } -} diff --git a/src/renderer/src/services/api.ts b/src/renderer/src/services/api.ts index a4a7166bd0..b5fc8da11b 100644 --- a/src/renderer/src/services/api.ts +++ b/src/renderer/src/services/api.ts @@ -6,6 +6,7 @@ import { uuid } from '@renderer/utils' import dayjs from 'dayjs' import { isEmpty } from 'lodash' +import AiProvider from '../providers/AiProvider' import { getAssistantProvider, getDefaultModel, @@ -15,7 +16,6 @@ import { } from './assistant' import { EVENT_NAMES, EventEmitter } from './event' import { filterMessages } from './messages' -import ProviderSDK from './ProviderSDK' export async function fetchChatCompletion({ messages, @@ -33,7 +33,7 @@ export async function fetchChatCompletion({ const provider = getAssistantProvider(assistant) const defaultModel = getDefaultModel() const model = assistant.model || defaultModel - const providerSdk = new ProviderSDK(provider) + const AI = new AiProvider(provider) store.dispatch(setGenerating(true)) @@ -61,7 +61,7 @@ export async function fetchChatCompletion({ }, 1000) try { - await providerSdk.completions(filterMessages(messages), assistant, ({ text, usage }) => { + await AI.completions(filterMessages(messages), assistant, ({ text, usage }) => { message.content = message.content + text || '' message.usage = usage onResponse({ ...message, status: 'pending' }) @@ -103,10 +103,10 @@ export async function fetchTranslate({ message, assistant }: { message: Message; return '' } - const providerSdk = new ProviderSDK(provider) + const AI = new AiProvider(provider) try { - return await providerSdk.translate(message, assistant) + return await AI.translate(message, assistant) } catch (error: any) { return '' } @@ -120,10 +120,10 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages: return null } - const providerSdk = new ProviderSDK(provider) + const AI = new AiProvider(provider) try { - return await providerSdk.summaries(filterMessages(messages), assistant) + return await AI.summaries(filterMessages(messages), assistant) } catch (error: any) { return null } @@ -136,10 +136,8 @@ export async function fetchSuggestions({ messages: Message[] assistant: Assistant }): Promise { - console.debug('fetchSuggestions', messages, assistant) const provider = getAssistantProvider(assistant) - const providerSdk = new ProviderSDK(provider) - console.debug('fetchSuggestions', provider) + const AI = new AiProvider(provider) const model = assistant.model if (!model) { @@ -155,7 +153,7 @@ export async function fetchSuggestions({ } try { - return await providerSdk.suggestions(messages, assistant) + return await AI.suggestions(messages, assistant) } catch (error: any) { return [] } @@ -183,9 +181,9 @@ export async function checkApi(provider: Provider) { return false } - const providerSdk = new ProviderSDK(provider) + const AI = new AiProvider(provider) - const { valid } = await providerSdk.check() + const { valid } = await AI.check() window.message[valid ? 'success' : 'error']({ key: 'api-check', @@ -204,10 +202,10 @@ function hasApiKey(provider: Provider) { } export async function fetchModels(provider: Provider) { - const providerSdk = new ProviderSDK(provider) + const AI = new AiProvider(provider) try { - return await providerSdk.models() + return await AI.models() } catch (error) { return [] } diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 31527309d8..19170242f4 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -20,14 +20,15 @@ export type AssistantSettings = { export type Message = { id: string + assistantId: string role: 'user' | 'assistant' content: string - images?: string[] - assistantId: string topicId: string - modelId?: string createdAt: string status: 'sending' | 'pending' | 'success' | 'paused' | 'error' + modelId?: string + files?: File[] + images?: string[] usage?: OpenAI.Completions.CompletionUsage type?: 'text' | '@' } diff --git a/src/renderer/src/utils/index.ts b/src/renderer/src/utils/index.ts index 1c0ca0182d..dbdbcc7731 100644 --- a/src/renderer/src/utils/index.ts +++ b/src/renderer/src/utils/index.ts @@ -223,3 +223,18 @@ export function getBriefInfo(text: string, maxLength: number = 50): string { // 截取前面的内容,并在末尾添加 "..." return truncatedText + '...' } + +export async function fileToBase64(file: File): Promise { + return new Promise((resolve, reject) => { + try { + const reader = new FileReader() + reader.onload = (e: ProgressEvent) => { + const result = e.target?.result + resolve(typeof result === 'string' ? result : '') + } + reader.readAsDataURL(file) + } catch (error: any) { + reject(error) + } + }) +}