diff --git a/src/renderer/src/assets/images/providers/ollama.png b/src/renderer/src/assets/images/providers/ollama.png
new file mode 100644
index 0000000000..4ca6ec0057
Binary files /dev/null and b/src/renderer/src/assets/images/providers/ollama.png differ
diff --git a/src/renderer/src/pages/settings/ProviderSettings.tsx b/src/renderer/src/pages/settings/ProviderSettings.tsx
index be787be5b1..5a84a48ef4 100644
--- a/src/renderer/src/pages/settings/ProviderSettings.tsx
+++ b/src/renderer/src/pages/settings/ProviderSettings.tsx
@@ -2,9 +2,9 @@ import { useSystemProviders } from '@renderer/hooks/useProvider'
import { Provider } from '@renderer/types'
import { FC, useState } from 'react'
import styled from 'styled-components'
-import ProviderModals from './components/ProviderModals'
import { Avatar } from 'antd'
import { getProviderLogo } from '@renderer/services/provider'
+import ProviderModels from './components/ProviderModels'
const ProviderSettings: FC = () => {
const providers = useSystemProviders()
@@ -23,7 +23,7 @@ const ProviderSettings: FC = () => {
))}
-
+
)
}
diff --git a/src/renderer/src/pages/settings/components/ModelAddPopup.tsx b/src/renderer/src/pages/settings/components/ModelAddPopup.tsx
new file mode 100644
index 0000000000..a29f57c38b
--- /dev/null
+++ b/src/renderer/src/pages/settings/components/ModelAddPopup.tsx
@@ -0,0 +1,124 @@
+import { TopView } from '@renderer/components/TopView'
+import { useProvider } from '@renderer/hooks/useProvider'
+import { Model, Provider } from '@renderer/types'
+import { getDefaultGroupName } from '@renderer/utils'
+import { Button, Form, FormProps, Input, Modal } from 'antd'
+import { find } from 'lodash'
+import { useState } from 'react'
+
+interface ShowParams {
+ title: string
+ provider: Provider
+}
+
+interface Props extends ShowParams {
+ resolve: (data: any) => void
+}
+
+type FieldType = {
+ provider: string
+ id: string
+ name?: string
+ group?: string
+}
+
+const PopupContainer: React.FC = ({ title, provider, resolve }) => {
+ const [open, setOpen] = useState(true)
+ const [form] = Form.useForm()
+ const { addModel, models } = useProvider(provider.id)
+
+ const onOk = () => {
+ setOpen(false)
+ }
+
+ const onCancel = () => {
+ setOpen(false)
+ }
+
+ const onClose = () => {
+ resolve({})
+ }
+
+ const onFinish: FormProps['onFinish'] = (values) => {
+ if (find(models, { id: values.id })) {
+ Modal.error({ title: 'Error', content: 'Model ID already exists' })
+ return
+ }
+
+ const model: Model = {
+ id: values.id,
+ provider: provider.id,
+ name: values.name ? values.name : values.id.toUpperCase(),
+ group: getDefaultGroupName(values.group || values.id),
+ temperature: 0.7
+ }
+
+ addModel(model)
+
+ resolve(model)
+ }
+
+ return (
+
+
+
+
+
+ {
+ form.setFieldValue('name', e.target.value.toUpperCase())
+ form.setFieldValue('group', getDefaultGroupName(e.target.value))
+ }}
+ />
+
+
+
+
+
+
+
+
+
+
+
+
+ )
+}
+
+export default class ModalAddPopup {
+ static topviewId = 0
+ static hide() {
+ TopView.hide(this.topviewId)
+ }
+ static show(props: ShowParams) {
+ return new Promise((resolve) => {
+ this.topviewId = TopView.show(
+ {
+ resolve(v)
+ this.hide()
+ }}
+ />
+ )
+ })
+ }
+}
diff --git a/src/renderer/src/components/Popups/ModalListPopup.tsx b/src/renderer/src/pages/settings/components/ModelListPopup.tsx
similarity index 85%
rename from src/renderer/src/components/Popups/ModalListPopup.tsx
rename to src/renderer/src/pages/settings/components/ModelListPopup.tsx
index 14e6f76bba..1cf80cfb8b 100644
--- a/src/renderer/src/components/Popups/ModalListPopup.tsx
+++ b/src/renderer/src/pages/settings/components/ModelListPopup.tsx
@@ -1,8 +1,8 @@
-import { Avatar, Button, Modal } from 'antd'
+import { Avatar, Button, Empty, Modal } from 'antd'
import { useState } from 'react'
-import { TopView } from '../TopView'
+import { TopView } from '../../../components/TopView'
import { Model, Provider } from '@renderer/types'
-import { groupBy } from 'lodash'
+import { groupBy, isEmpty, uniqBy } from 'lodash'
import styled from 'styled-components'
import { MinusOutlined, PlusOutlined } from '@ant-design/icons'
import { useProvider } from '@renderer/hooks/useProvider'
@@ -19,10 +19,11 @@ interface Props extends ShowParams {
const PopupContainer: React.FC = ({ provider: _provider, resolve }) => {
const [open, setOpen] = useState(true)
- const { provider, addModel, removeModel } = useProvider(_provider.id)
+ const { provider, models, addModel, removeModel } = useProvider(_provider.id)
- const systemModels = SYSTEM_MODELS[_provider.id]
- const systemModelGroups = groupBy(systemModels, 'group')
+ const systemModels = SYSTEM_MODELS[_provider.id] || []
+ const allModels = uniqBy([...systemModels, ...models], 'id')
+ const systemModelGroups = groupBy(allModels, 'group')
const onOk = () => {
setOpen(false)
@@ -79,6 +80,7 @@ const PopupContainer: React.FC = ({ provider: _provider, resolve }) => {
})}
))}
+ {isEmpty(allModels) && }
)
@@ -124,7 +126,7 @@ const ListItemName = styled.div`
margin-left: 6px;
`
-export default class ModalListPopup {
+export default class ModelListPopup {
static topviewId = 0
static hide() {
TopView.hide(this.topviewId)
diff --git a/src/renderer/src/pages/settings/components/ProviderModals.tsx b/src/renderer/src/pages/settings/components/ProviderModels.tsx
similarity index 74%
rename from src/renderer/src/pages/settings/components/ProviderModals.tsx
rename to src/renderer/src/pages/settings/components/ProviderModels.tsx
index 1fbb86e15f..fef0541b1c 100644
--- a/src/renderer/src/pages/settings/components/ProviderModals.tsx
+++ b/src/renderer/src/pages/settings/components/ProviderModels.tsx
@@ -1,18 +1,20 @@
import { Provider } from '@renderer/types'
import { FC, useEffect, useState } from 'react'
import styled from 'styled-components'
-import { Avatar, Button, Card, Divider, Input } from 'antd'
+import { Avatar, Button, Card, Divider, Flex, Input } from 'antd'
import { useProvider } from '@renderer/hooks/useProvider'
-import ModalListPopup from '@renderer/components/Popups/ModalListPopup'
import { groupBy } from 'lodash'
import { SettingContainer, SettingSubtitle, SettingTitle } from './SettingComponent'
import { getModelLogo } from '@renderer/services/provider'
+import { EditOutlined, PlusOutlined } from '@ant-design/icons'
+import ModalAddPopup from './ModelAddPopup'
+import ModelListPopup from './ModelListPopup'
interface Props {
provider: Provider
}
-const ProviderModals: FC = ({ provider }) => {
+const ProviderModels: FC = ({ provider }) => {
const [apiKey, setApiKey] = useState(provider.apiKey)
const [apiHost, setApiHost] = useState(provider.apiHost)
const { updateProvider, models } = useProvider(provider.id)
@@ -32,8 +34,12 @@ const ProviderModals: FC = ({ provider }) => {
updateProvider({ ...provider, apiHost })
}
- const onAddModal = () => {
- ModalListPopup.show({ provider })
+ const onManageModel = () => {
+ ModelListPopup.show({ provider })
+ }
+
+ const onAddModel = () => {
+ ModalAddPopup.show({ title: 'Add Model', provider })
}
return (
@@ -66,9 +72,14 @@ const ProviderModals: FC = ({ provider }) => {
))}
))}
-
+
+ }>
+ Manage
+
+ }>
+ Add
+
+
)
}
@@ -81,4 +92,4 @@ const ModelListItem = styled.div`
padding: 5px 0;
`
-export default ProviderModals
+export default ProviderModels
diff --git a/src/renderer/src/services/provider.ts b/src/renderer/src/services/provider.ts
index 8efdfc8df8..c1d36e527f 100644
--- a/src/renderer/src/services/provider.ts
+++ b/src/renderer/src/services/provider.ts
@@ -4,6 +4,7 @@ import DeepSeekProviderLogo from '@renderer/assets/images/providers/deepseek.png
import YiProviderLogo from '@renderer/assets/images/providers/yi.svg'
import GroqProviderLogo from '@renderer/assets/images/providers/groq.png'
import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
+import OllamaProviderLogo from '@renderer/assets/images/providers/ollama.png'
import ChatGPTModelLogo from '@renderer/assets/images/models/chatgpt.jpeg'
import ChatGLMModelLogo from '@renderer/assets/images/models/chatglm.jpeg'
import DeepSeekModelLogo from '@renderer/assets/images/models/deepseek.png'
@@ -14,66 +15,42 @@ import LlamaModelLogo from '@renderer/assets/images/models/llama.jpeg'
import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg'
export function getProviderLogo(providerId: string) {
- if (providerId === 'openai') {
- return OpenAiProviderLogo
+ switch (providerId) {
+ case 'openai':
+ return OpenAiProviderLogo
+ case 'silicon':
+ return SiliconFlowProviderLogo
+ case 'deepseek':
+ return DeepSeekProviderLogo
+ case 'yi':
+ return YiProviderLogo
+ case 'groq':
+ return GroqProviderLogo
+ case 'zhipu':
+ return ZhipuProviderLogo
+ case 'ollama':
+ return OllamaProviderLogo
+ default:
+ return ''
}
-
- if (providerId === 'silicon') {
- return SiliconFlowProviderLogo
- }
-
- if (providerId === 'deepseek') {
- return DeepSeekProviderLogo
- }
-
- if (providerId === 'yi') {
- return YiProviderLogo
- }
-
- if (providerId === 'groq') {
- return GroqProviderLogo
- }
-
- if (providerId === 'zhipu') {
- return ZhipuProviderLogo
- }
-
- return ''
}
export function getModelLogo(modelId: string) {
- const _modelId = modelId.toLowerCase()
-
- if (_modelId.includes('gpt')) {
- return ChatGPTModelLogo
+ const logoMap = {
+ gpt: ChatGPTModelLogo,
+ glm: ChatGLMModelLogo,
+ deepseek: DeepSeekModelLogo,
+ qwen: QwenModelLogo,
+ gemma: GemmaModelLogo,
+ 'yi-': YiModelLogo,
+ llama: LlamaModelLogo,
+ mixtral: MixtralModelLogo
}
- if (_modelId.includes('glm')) {
- return ChatGLMModelLogo
- }
-
- if (_modelId.includes('deepseek')) {
- return DeepSeekModelLogo
- }
-
- if (_modelId.includes('qwen')) {
- return QwenModelLogo
- }
-
- if (_modelId.includes('gemma')) {
- return GemmaModelLogo
- }
-
- if (_modelId.includes('yi-')) {
- return YiModelLogo
- }
-
- if (_modelId.includes('llama')) {
- return LlamaModelLogo
- }
-
- if (_modelId.includes('mixtral')) {
- return MixtralModelLogo
+ for (const key in logoMap) {
+ if (modelId.toLowerCase().includes(key)) {
+ return logoMap[key]
+ }
}
return ''
diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts
index e43047a30a..6a29d7675c 100644
--- a/src/renderer/src/store/index.ts
+++ b/src/renderer/src/store/index.ts
@@ -19,7 +19,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
- version: 3,
+ version: 4,
blacklist: ['runtime'],
migrate
},
diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts
index 98fbb2d898..23562de6a2 100644
--- a/src/renderer/src/store/llm.ts
+++ b/src/renderer/src/store/llm.ts
@@ -60,6 +60,14 @@ const initialState: LlmState = {
apiHost: 'https://api.groq.com/openai',
isSystem: true,
models: SYSTEM_MODELS.groq.filter((m) => m.defaultEnabled)
+ },
+ {
+ id: 'ollama',
+ name: 'Ollama',
+ apiKey: '',
+ apiHost: 'http://localhost:11434/v1/',
+ isSystem: true,
+ models: []
}
]
}
diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts
index e8b98b1082..be98a91516 100644
--- a/src/renderer/src/store/migrate.ts
+++ b/src/renderer/src/store/migrate.ts
@@ -42,6 +42,26 @@ const migrate = createMigrate({
]
}
}
+ },
+ // @ts-ignore store type is unknown
+ '4': (state: RootState) => {
+ return {
+ ...state,
+ llm: {
+ ...state.llm,
+ providers: [
+ ...state.llm.providers,
+ {
+ id: 'ollama',
+ name: 'Ollama',
+ apiKey: '',
+ apiHost: 'http://localhost:11434/v1/',
+ isSystem: true,
+ models: []
+ }
+ ]
+ }
+ }
}
})
diff --git a/src/renderer/src/utils/index.ts b/src/renderer/src/utils/index.ts
index 0aebc309f6..c51e40ef40 100644
--- a/src/renderer/src/utils/index.ts
+++ b/src/renderer/src/utils/index.ts
@@ -66,3 +66,18 @@ export const compressImage = async (file: File) => {
useWebWorker: false
})
}
+
+// Converts 'gpt-3.5-turbo-16k-0613' to 'GPT-3.5-Turbo'
+// Converts 'qwen2:1.5b' to 'QWEN2'
+export const getDefaultGroupName = (id: string) => {
+ if (id.includes(':')) {
+ return id.split(':')[0].toUpperCase()
+ }
+
+ if (id.includes('-')) {
+ const parts = id.split('-')
+ return parts[0].toUpperCase() + '-' + parts[1].toUpperCase()
+ }
+
+ return id.toUpperCase()
+}