diff --git a/src/renderer/src/assets/images/providers/ollama.png b/src/renderer/src/assets/images/providers/ollama.png new file mode 100644 index 0000000000..4ca6ec0057 Binary files /dev/null and b/src/renderer/src/assets/images/providers/ollama.png differ diff --git a/src/renderer/src/pages/settings/ProviderSettings.tsx b/src/renderer/src/pages/settings/ProviderSettings.tsx index be787be5b1..5a84a48ef4 100644 --- a/src/renderer/src/pages/settings/ProviderSettings.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings.tsx @@ -2,9 +2,9 @@ import { useSystemProviders } from '@renderer/hooks/useProvider' import { Provider } from '@renderer/types' import { FC, useState } from 'react' import styled from 'styled-components' -import ProviderModals from './components/ProviderModals' import { Avatar } from 'antd' import { getProviderLogo } from '@renderer/services/provider' +import ProviderModels from './components/ProviderModels' const ProviderSettings: FC = () => { const providers = useSystemProviders() @@ -23,7 +23,7 @@ const ProviderSettings: FC = () => { ))} - + ) } diff --git a/src/renderer/src/pages/settings/components/ModelAddPopup.tsx b/src/renderer/src/pages/settings/components/ModelAddPopup.tsx new file mode 100644 index 0000000000..a29f57c38b --- /dev/null +++ b/src/renderer/src/pages/settings/components/ModelAddPopup.tsx @@ -0,0 +1,124 @@ +import { TopView } from '@renderer/components/TopView' +import { useProvider } from '@renderer/hooks/useProvider' +import { Model, Provider } from '@renderer/types' +import { getDefaultGroupName } from '@renderer/utils' +import { Button, Form, FormProps, Input, Modal } from 'antd' +import { find } from 'lodash' +import { useState } from 'react' + +interface ShowParams { + title: string + provider: Provider +} + +interface Props extends ShowParams { + resolve: (data: any) => void +} + +type FieldType = { + provider: string + id: string + name?: string + group?: string +} + +const PopupContainer: React.FC = ({ title, provider, resolve }) => { + const [open, setOpen] = useState(true) + const [form] = Form.useForm() + const { addModel, models } = useProvider(provider.id) + + const onOk = () => { + setOpen(false) + } + + const onCancel = () => { + setOpen(false) + } + + const onClose = () => { + resolve({}) + } + + const onFinish: FormProps['onFinish'] = (values) => { + if (find(models, { id: values.id })) { + Modal.error({ title: 'Error', content: 'Model ID already exists' }) + return + } + + const model: Model = { + id: values.id, + provider: provider.id, + name: values.name ? values.name : values.id.toUpperCase(), + group: getDefaultGroupName(values.group || values.id), + temperature: 0.7 + } + + addModel(model) + + resolve(model) + } + + return ( + +
+ + + + + { + form.setFieldValue('name', e.target.value.toUpperCase()) + form.setFieldValue('group', getDefaultGroupName(e.target.value)) + }} + /> + + + + + + + + + + +
+
+ ) +} + +export default class ModalAddPopup { + static topviewId = 0 + static hide() { + TopView.hide(this.topviewId) + } + static show(props: ShowParams) { + return new Promise((resolve) => { + this.topviewId = TopView.show( + { + resolve(v) + this.hide() + }} + /> + ) + }) + } +} diff --git a/src/renderer/src/components/Popups/ModalListPopup.tsx b/src/renderer/src/pages/settings/components/ModelListPopup.tsx similarity index 85% rename from src/renderer/src/components/Popups/ModalListPopup.tsx rename to src/renderer/src/pages/settings/components/ModelListPopup.tsx index 14e6f76bba..1cf80cfb8b 100644 --- a/src/renderer/src/components/Popups/ModalListPopup.tsx +++ b/src/renderer/src/pages/settings/components/ModelListPopup.tsx @@ -1,8 +1,8 @@ -import { Avatar, Button, Modal } from 'antd' +import { Avatar, Button, Empty, Modal } from 'antd' import { useState } from 'react' -import { TopView } from '../TopView' +import { TopView } from '../../../components/TopView' import { Model, Provider } from '@renderer/types' -import { groupBy } from 'lodash' +import { groupBy, isEmpty, uniqBy } from 'lodash' import styled from 'styled-components' import { MinusOutlined, PlusOutlined } from '@ant-design/icons' import { useProvider } from '@renderer/hooks/useProvider' @@ -19,10 +19,11 @@ interface Props extends ShowParams { const PopupContainer: React.FC = ({ provider: _provider, resolve }) => { const [open, setOpen] = useState(true) - const { provider, addModel, removeModel } = useProvider(_provider.id) + const { provider, models, addModel, removeModel } = useProvider(_provider.id) - const systemModels = SYSTEM_MODELS[_provider.id] - const systemModelGroups = groupBy(systemModels, 'group') + const systemModels = SYSTEM_MODELS[_provider.id] || [] + const allModels = uniqBy([...systemModels, ...models], 'id') + const systemModelGroups = groupBy(allModels, 'group') const onOk = () => { setOpen(false) @@ -79,6 +80,7 @@ const PopupContainer: React.FC = ({ provider: _provider, resolve }) => { })} ))} + {isEmpty(allModels) && } ) @@ -124,7 +126,7 @@ const ListItemName = styled.div` margin-left: 6px; ` -export default class ModalListPopup { +export default class ModelListPopup { static topviewId = 0 static hide() { TopView.hide(this.topviewId) diff --git a/src/renderer/src/pages/settings/components/ProviderModals.tsx b/src/renderer/src/pages/settings/components/ProviderModels.tsx similarity index 74% rename from src/renderer/src/pages/settings/components/ProviderModals.tsx rename to src/renderer/src/pages/settings/components/ProviderModels.tsx index 1fbb86e15f..fef0541b1c 100644 --- a/src/renderer/src/pages/settings/components/ProviderModals.tsx +++ b/src/renderer/src/pages/settings/components/ProviderModels.tsx @@ -1,18 +1,20 @@ import { Provider } from '@renderer/types' import { FC, useEffect, useState } from 'react' import styled from 'styled-components' -import { Avatar, Button, Card, Divider, Input } from 'antd' +import { Avatar, Button, Card, Divider, Flex, Input } from 'antd' import { useProvider } from '@renderer/hooks/useProvider' -import ModalListPopup from '@renderer/components/Popups/ModalListPopup' import { groupBy } from 'lodash' import { SettingContainer, SettingSubtitle, SettingTitle } from './SettingComponent' import { getModelLogo } from '@renderer/services/provider' +import { EditOutlined, PlusOutlined } from '@ant-design/icons' +import ModalAddPopup from './ModelAddPopup' +import ModelListPopup from './ModelListPopup' interface Props { provider: Provider } -const ProviderModals: FC = ({ provider }) => { +const ProviderModels: FC = ({ provider }) => { const [apiKey, setApiKey] = useState(provider.apiKey) const [apiHost, setApiHost] = useState(provider.apiHost) const { updateProvider, models } = useProvider(provider.id) @@ -32,8 +34,12 @@ const ProviderModals: FC = ({ provider }) => { updateProvider({ ...provider, apiHost }) } - const onAddModal = () => { - ModalListPopup.show({ provider }) + const onManageModel = () => { + ModelListPopup.show({ provider }) + } + + const onAddModel = () => { + ModalAddPopup.show({ title: 'Add Model', provider }) } return ( @@ -66,9 +72,14 @@ const ProviderModals: FC = ({ provider }) => { ))} ))} - + + + + ) } @@ -81,4 +92,4 @@ const ModelListItem = styled.div` padding: 5px 0; ` -export default ProviderModals +export default ProviderModels diff --git a/src/renderer/src/services/provider.ts b/src/renderer/src/services/provider.ts index 8efdfc8df8..c1d36e527f 100644 --- a/src/renderer/src/services/provider.ts +++ b/src/renderer/src/services/provider.ts @@ -4,6 +4,7 @@ import DeepSeekProviderLogo from '@renderer/assets/images/providers/deepseek.png import YiProviderLogo from '@renderer/assets/images/providers/yi.svg' import GroqProviderLogo from '@renderer/assets/images/providers/groq.png' import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png' +import OllamaProviderLogo from '@renderer/assets/images/providers/ollama.png' import ChatGPTModelLogo from '@renderer/assets/images/models/chatgpt.jpeg' import ChatGLMModelLogo from '@renderer/assets/images/models/chatglm.jpeg' import DeepSeekModelLogo from '@renderer/assets/images/models/deepseek.png' @@ -14,66 +15,42 @@ import LlamaModelLogo from '@renderer/assets/images/models/llama.jpeg' import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg' export function getProviderLogo(providerId: string) { - if (providerId === 'openai') { - return OpenAiProviderLogo + switch (providerId) { + case 'openai': + return OpenAiProviderLogo + case 'silicon': + return SiliconFlowProviderLogo + case 'deepseek': + return DeepSeekProviderLogo + case 'yi': + return YiProviderLogo + case 'groq': + return GroqProviderLogo + case 'zhipu': + return ZhipuProviderLogo + case 'ollama': + return OllamaProviderLogo + default: + return '' } - - if (providerId === 'silicon') { - return SiliconFlowProviderLogo - } - - if (providerId === 'deepseek') { - return DeepSeekProviderLogo - } - - if (providerId === 'yi') { - return YiProviderLogo - } - - if (providerId === 'groq') { - return GroqProviderLogo - } - - if (providerId === 'zhipu') { - return ZhipuProviderLogo - } - - return '' } export function getModelLogo(modelId: string) { - const _modelId = modelId.toLowerCase() - - if (_modelId.includes('gpt')) { - return ChatGPTModelLogo + const logoMap = { + gpt: ChatGPTModelLogo, + glm: ChatGLMModelLogo, + deepseek: DeepSeekModelLogo, + qwen: QwenModelLogo, + gemma: GemmaModelLogo, + 'yi-': YiModelLogo, + llama: LlamaModelLogo, + mixtral: MixtralModelLogo } - if (_modelId.includes('glm')) { - return ChatGLMModelLogo - } - - if (_modelId.includes('deepseek')) { - return DeepSeekModelLogo - } - - if (_modelId.includes('qwen')) { - return QwenModelLogo - } - - if (_modelId.includes('gemma')) { - return GemmaModelLogo - } - - if (_modelId.includes('yi-')) { - return YiModelLogo - } - - if (_modelId.includes('llama')) { - return LlamaModelLogo - } - - if (_modelId.includes('mixtral')) { - return MixtralModelLogo + for (const key in logoMap) { + if (modelId.toLowerCase().includes(key)) { + return logoMap[key] + } } return '' diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index e43047a30a..6a29d7675c 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -19,7 +19,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 3, + version: 4, blacklist: ['runtime'], migrate }, diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index 98fbb2d898..23562de6a2 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -60,6 +60,14 @@ const initialState: LlmState = { apiHost: 'https://api.groq.com/openai', isSystem: true, models: SYSTEM_MODELS.groq.filter((m) => m.defaultEnabled) + }, + { + id: 'ollama', + name: 'Ollama', + apiKey: '', + apiHost: 'http://localhost:11434/v1/', + isSystem: true, + models: [] } ] } diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index e8b98b1082..be98a91516 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -42,6 +42,26 @@ const migrate = createMigrate({ ] } } + }, + // @ts-ignore store type is unknown + '4': (state: RootState) => { + return { + ...state, + llm: { + ...state.llm, + providers: [ + ...state.llm.providers, + { + id: 'ollama', + name: 'Ollama', + apiKey: '', + apiHost: 'http://localhost:11434/v1/', + isSystem: true, + models: [] + } + ] + } + } } }) diff --git a/src/renderer/src/utils/index.ts b/src/renderer/src/utils/index.ts index 0aebc309f6..c51e40ef40 100644 --- a/src/renderer/src/utils/index.ts +++ b/src/renderer/src/utils/index.ts @@ -66,3 +66,18 @@ export const compressImage = async (file: File) => { useWebWorker: false }) } + +// Converts 'gpt-3.5-turbo-16k-0613' to 'GPT-3.5-Turbo' +// Converts 'qwen2:1.5b' to 'QWEN2' +export const getDefaultGroupName = (id: string) => { + if (id.includes(':')) { + return id.split(':')[0].toUpperCase() + } + + if (id.includes('-')) { + const parts = id.split('-') + return parts[0].toUpperCase() + '-' + parts[1].toUpperCase() + } + + return id.toUpperCase() +}