refactor: match provider and model using a consistent method (#7933)

* refactor: match provider and model using a consistent method

* refactor: use keywords matching across model selectors

* refactor: update match, reuse getFancyProviderName

* refactor: use modelSelectFilter in knowledgebase settings

* refactor: use filter in ModelList

* refactor: add filterModelsByKeywords

* refactor: add getModelSelectOptions

* style: better function names

* fix: update effect dependencies in popup and panel components

Adjusted dependency arrays in HtmlArtifactsPopup and QuickPanelView to ensure correct effect execution. This change improves state synchronization and prevents unnecessary updates.

* refactor: use match in memory settings

* refactor: add avatar to model selector

* refactor: simplify utils, move select options to components

* docs: add comments

* refactor: move filter to SelectOptions

* test: add tests for SelectOptions

* test: remove type mock

* refactor: use match in EditModelsPopup

* refactor: use SelectOptions in SelectProviderModelPopup, add more tests

* fix: api check model select

* refactor: improve websearch rag model select style

* refactor: add a ModelSelector

* test: update tests for ModelSelector

* docs: comments

---------

Co-authored-by: 自由的世界人 <3196812536@qq.com>
This commit is contained in:
one 2025-07-23 10:45:09 +08:00 committed by GitHub
parent d0649d29fb
commit 736f73a726
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 763 additions and 352 deletions

View File

@ -247,7 +247,7 @@ The Enterprise Edition addresses core challenges in team collaboration by centra
| Feature | Community Edition | Enterprise Edition |
| :---------------- | :----------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------- |
| **Open Source** | ✅ Yes | ⭕️ Partially released to customers |
| **Open Source** | ✅ Yes | ⭕️ Partially released to customers |
| **Cost** | Free for Personal Use / Commercial License | Buyout / Subscription Fee |
| **Admin Backend** | — | ● Centralized **Model** Access<br>**Employee** Management<br>● Shared **Knowledge Base**<br>**Access** Control<br>**Data** Backup |
| **Server** | — | ✅ Dedicated Private Deployment |

View File

@ -145,8 +145,8 @@ In a development environment, you can define environment variables to filter dis
Environment variables can be set in the terminal or defined in the `.env` file in the project's root directory. The available variables are as follows:
| Variable Name | Description |
| ------- | ------- |
| Variable Name | Description |
| ------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| CSLOGGER_MAIN_LEVEL | Log level for the `main` process. Logs below this level will not be displayed. |
| CSLOGGER_MAIN_SHOW_MODULES | Filters log modules for the `main` process. Use a comma (`,`) to separate modules. The filter is case-sensitive. Only logs from modules in this list will be displayed. |
| CSLOGGER_RENDERER_LEVEL | Log level for the `renderer` process. Logs below this level will not be displayed. |
@ -160,6 +160,7 @@ CSLOGGER_MAIN_SHOW_MODULES=MCPService,SelectionService
```
Note:
- Environment variables are only effective in the development environment.
- These variables only affect the logs displayed in the terminal or DevTools. They do not affect file logging or the `logToMain` recording logic.
@ -168,7 +169,7 @@ Note:
There are many log levels. The following are the guidelines that should be followed in CherryStudio for when to use each level:
(Arranged from highest to lowest log level)
| Log Level | Core Definition & Use case | Example |
| Log Level | Core Definition & Use case | Example |
| :------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **`error`** | **Critical error causing the program to crash or core functionality to become unusable.** <br> This is the highest-priority log, usually requiring immediate reporting or user notification. | - Main or renderer process crash. <br> - Failure to read/write critical user data files (e.g., database, configuration files), preventing the application from running. <br> - All unhandled exceptions. |
| **`warn`** | **Potential issue or unexpected situation that does not affect the program's core functionality.** <br> The program can recover or use a fallback. | - Configuration file `settings.json` is missing; started with default settings. <br> - Auto-update check failed, but does not affect the use of the current version. <br> - A non-essential plugin failed to load. |

View File

@ -145,12 +145,12 @@ const logger = loggerService.initWindowSource('Worker').withContext('LetsWork')
环境变量可以在终端中自行设置,或者在开发根目录的`.env`文件中进行定义,可以定义的变量如下:
| 变量名 | 含义 |
| ----- | ----- |
| CSLOGGER_MAIN_LEVEL | 用于`main`进程的日志级别,低于该级别的日志将不显示 |
| CSLOGGER_MAIN_SHOW_MODULES | 用于`main`进程的日志module筛选用`,`分隔区分大小写。只有在该列表中的module的日志才会显示 |
| CSLOGGER_RENDERER_LEVEL | 用于`renderer`进程的日志级别,低于该级别的日志将不显示 |
| CSLOGGER_RENDERER_SHOW_MODULES | 用于`renderer`进程的日志module筛选用`,`分隔区分大小写。只有在该列表中的module的日志才会显示 |
| 变量名 | 含义 |
| ------------------------------ | ----------------------------------------------------------------------------------------------- |
| CSLOGGER_MAIN_LEVEL | 用于`main`进程的日志级别,低于该级别的日志将不显示 |
| CSLOGGER_MAIN_SHOW_MODULES | 用于`main`进程的日志module筛选用`,`分隔区分大小写。只有在该列表中的module的日志才会显示 |
| CSLOGGER_RENDERER_LEVEL | 用于`renderer`进程的日志级别,低于该级别的日志将不显示 |
| CSLOGGER_RENDERER_SHOW_MODULES | 用于`renderer`进程的日志module筛选用`,`分隔区分大小写。只有在该列表中的module的日志才会显示 |
示例:
@ -160,6 +160,7 @@ CSLOGGER_MAIN_SHOW_MODULES=MCPService,SelectionService
```
注意:
- 环境变量仅在开发环境中生效
- 该变量仅会改变在终端或在devTools中显示的日志不会影响文件日志和`logToMain`的记录逻辑

View File

@ -23,7 +23,13 @@ import { useProvider } from '@renderer/hooks/useProvider'
import FileItem from '@renderer/pages/files/FileItem'
import { fetchModels } from '@renderer/services/ApiService'
import { Model, Provider } from '@renderer/types'
import { getDefaultGroupName, isFreeModel, runAsyncFunction } from '@renderer/utils'
import {
filterModelsByKeywords,
getDefaultGroupName,
getFancyProviderName,
isFreeModel,
runAsyncFunction
} from '@renderer/utils'
import { Avatar, Button, Empty, Flex, Modal, Spin, Tabs, Tooltip } from 'antd'
import Input from 'antd/es/input/Input'
import { groupBy, isEmpty, uniqBy } from 'lodash'
@ -86,34 +92,30 @@ const PopupContainer: React.FC<Props> = ({ provider: _provider, resolve }) => {
const systemModels = SYSTEM_MODELS[_provider.id] || []
const allModels = uniqBy([...systemModels, ...listModels, ...models], 'id')
const list = allModels.filter((model) => {
if (
filterSearchText &&
!model.id.toLocaleLowerCase().includes(filterSearchText.toLocaleLowerCase()) &&
!model.name?.toLocaleLowerCase().includes(filterSearchText.toLocaleLowerCase())
) {
return false
}
switch (actualFilterType) {
case 'reasoning':
return isReasoningModel(model)
case 'vision':
return isVisionModel(model)
case 'websearch':
return isWebSearchModel(model)
case 'free':
return isFreeModel(model)
case 'embedding':
return isEmbeddingModel(model)
case 'function_calling':
return isFunctionCallingModel(model)
case 'rerank':
return isRerankModel(model)
default:
return true
}
})
const list = useMemo(
() =>
filterModelsByKeywords(filterSearchText, allModels).filter((model) => {
switch (actualFilterType) {
case 'reasoning':
return isReasoningModel(model)
case 'vision':
return isVisionModel(model)
case 'websearch':
return isWebSearchModel(model)
case 'free':
return isFreeModel(model)
case 'embedding':
return isEmbeddingModel(model)
case 'function_calling':
return isFunctionCallingModel(model)
case 'rerank':
return isRerankModel(model)
default:
return true
}
}),
[filterSearchText, actualFilterType, allModels]
)
const modelGroups = useMemo(
() =>
@ -202,7 +204,7 @@ const PopupContainer: React.FC<Props> = ({ provider: _provider, resolve }) => {
return (
<Flex>
<ModelHeaderTitle>
{provider.isSystem ? t(`provider.${provider.id}`) : provider.name}
{getFancyProviderName(provider)}
{i18n.language.startsWith('zh') ? '' : ' '}
{t('common.models')}
</ModelHeaderTitle>

View File

@ -12,6 +12,7 @@ import { SettingHelpLink, SettingHelpText, SettingHelpTextRow, SettingSubtitle }
import { useAppDispatch } from '@renderer/store'
import { setModel } from '@renderer/store/assistants'
import { Model } from '@renderer/types'
import { filterModelsByKeywords } from '@renderer/utils'
import { Button, Flex, Tooltip } from 'antd'
import { groupBy, sortBy, toPairs } from 'lodash'
import { ListCheck, Plus } from 'lucide-react'
@ -51,9 +52,7 @@ const ModelList: React.FC<ModelListProps> = ({ providerId }) => {
}, [])
const modelGroups = useMemo(() => {
const filteredModels = searchText
? models.filter((model) => model.name.toLowerCase().includes(searchText.toLowerCase()))
: models
const filteredModels = searchText ? filterModelsByKeywords(searchText, models) : models
return groupBy(filteredModels, 'group')
}, [searchText, models])

View File

@ -0,0 +1,123 @@
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
import { getModelUniqId } from '@renderer/services/ModelService'
import { Model, Provider } from '@renderer/types'
import { matchKeywordsInString } from '@renderer/utils'
import { getFancyProviderName } from '@renderer/utils/naming'
import { Select, SelectProps } from 'antd'
import { sortBy } from 'lodash'
import { BaseSelectRef } from 'rc-select'
import { memo, useCallback, useMemo } from 'react'
interface ModelOption {
label: React.ReactNode
title: string
value: string
}
interface GroupedModelOption {
label: string
title: string
options: ModelOption[]
}
type SelectOption = ModelOption | GroupedModelOption
interface ModelSelectorProps extends SelectProps {
providers?: Provider[]
predicate?: (model: Model) => boolean
grouped?: boolean
showAvatar?: boolean
showSuffix?: boolean
}
/**
* antd Select
* - predicate
* -
* - avatar suffix
* @param providers
* @param predicate
* @param grouped
* @param showAvatar
* @param showSuffix
*/
const ModelSelector = ({
providers,
predicate,
grouped = true,
showAvatar = true,
showSuffix = true,
ref,
...props
}: ModelSelectorProps & { ref?: React.Ref<BaseSelectRef> | null }) => {
// 单个 provider 的模型选项
const getModelOptions = useCallback(
(p: Provider, fancyName: string) => {
const suffix = showSuffix ? <span style={{ opacity: 0.45 }}>{` | ${fancyName}`}</span> : null
return sortBy(p.models, 'name')
.filter((model) => predicate?.(model) ?? true)
.map((m) => ({
label: (
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
{showAvatar && <ModelAvatar model={m} size={18} />}
<span>
{m.name}
{suffix}
</span>
</div>
),
title: `${m.name} | ${fancyName}`,
value: getModelUniqId(m)
}))
},
[predicate, showAvatar, showSuffix]
)
// 所有 provider 的模型选项
const options = useMemo((): SelectOption[] => {
if (!providers) return []
if (grouped) {
return providers.flatMap((p) => {
const fancyName = getFancyProviderName(p)
const modelOptions = getModelOptions(p, fancyName)
return modelOptions.length > 0
? [
{
label: fancyName,
title: p.name,
options: modelOptions
} as GroupedModelOption
]
: []
})
}
return providers.flatMap((p) => getModelOptions(p, getFancyProviderName(p)))
}, [providers, grouped, getModelOptions])
return <Select ref={ref} options={options} filterOption={modelSelectFilter} showSearch {...props} />
}
export default memo(ModelSelector)
/**
* antd Select filterOption
* - 使 title
* - 使 label
* - 使 value
*
* @param input
* @param option Select label value
* @returns
*/
export function modelSelectFilter(input: string, option: any): boolean {
const target =
typeof option?.title === 'string'
? option.title
: typeof option?.label === 'string'
? option.label
: typeof option?.value === 'string'
? option.value
: ''
return matchKeywordsInString(input, target)
}

View File

@ -7,7 +7,7 @@ import { usePinnedModels } from '@renderer/hooks/usePinnedModels'
import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
import { Model } from '@renderer/types'
import { classNames } from '@renderer/utils/style'
import { classNames, filterModelsByKeywords, getFancyProviderName } from '@renderer/utils'
import { Avatar, Divider, Empty, Input, InputRef, Modal } from 'antd'
import { first, sortBy } from 'lodash'
import { Search } from 'lucide-react'
@ -102,27 +102,19 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter }) => {
let models = provider.models.filter((m) => !isEmbeddingModel(m) && !isRerankModel(m))
if (searchText.trim()) {
const keywords = searchText.toLowerCase().split(/\s+/).filter(Boolean)
models = models.filter((m) => {
const fullName = provider.isSystem
? `${m.name} ${provider.name} ${t('provider.' + provider.id)}`
: `${m.name} ${provider.name}`
const lowerFullName = fullName.toLowerCase()
return keywords.every((keyword) => lowerFullName.includes(keyword))
})
models = filterModelsByKeywords(searchText, models, provider)
}
return sortBy(models, ['group', 'name'])
},
[searchText, t]
[searchText]
)
// 创建模型列表项
const createModelItem = useCallback(
(model: Model, provider: any, isPinned: boolean): FlatListItem => {
const modelId = getModelUniqId(model)
const groupName = provider.isSystem ? t(`provider.${provider.id}`) : provider.name
const groupName = getFancyProviderName(provider)
return {
key: isPinned ? `${modelId}_pinned` : modelId,
@ -148,7 +140,7 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter }) => {
isSelected: modelId === currentModelId
}
},
[t, currentModelId]
[currentModelId]
)
// 构建扁平化列表数据
@ -189,7 +181,7 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter }) => {
items.push({
key: `provider-${p.id}`,
type: 'group',
name: p.isSystem ? t(`provider.${p.id}`) : p.name,
name: getFancyProviderName(p),
isSelected: false
})

View File

@ -0,0 +1,225 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { describe, expect, it, vi } from 'vitest'
// Mock the imported modules
vi.mock('@renderer/components/Avatar/ModelAvatar', () => ({
default: ({ model, size }: any) => (
<div data-testid="model-avatar" style={{ width: size, height: size }}>
{model.name.charAt(0)}
</div>
)
}))
vi.mock('@renderer/services/ModelService', () => ({
getModelUniqId: (model: any) => `${model.provider}-${model.id}`
}))
vi.mock('@renderer/utils', () => ({
matchKeywordsInString: (input: string, target: string) => target.toLowerCase().includes(input.toLowerCase())
}))
vi.mock('@renderer/utils/naming', () => ({
getFancyProviderName: (provider: any) => provider.name
}))
// Import after mocking
import { Provider } from '@renderer/types'
import ModelSelector, { modelSelectFilter } from '../ModelSelector'
describe('ModelSelector', () => {
const mockProviders: Provider[] = [
{
id: 'openai',
name: 'OpenAI',
type: 'openai',
apiKey: '123',
apiHost: 'https://api.openai.com',
models: [
{ id: 'text-embedding-ada-002', name: 'text-embedding-ada-002', provider: 'openai', group: 'embedding' },
{ id: 'gpt-4.1', name: 'GPT-4.1', provider: 'openai', group: 'chat' }
]
},
{
id: 'cohere',
name: 'Cohere',
type: 'openai',
apiKey: '123',
apiHost: 'https://api.cohere.com',
models: [
{ id: 'embed-english-v3.0', name: 'embed-english-v3.0', provider: 'cohere', group: 'embedding' },
{ id: 'rerank-english-v2.0', name: 'rerank-english-v2.0', provider: 'cohere', group: 'rerank' }
]
},
{
id: 'empty-provider',
name: 'EmptyProvider',
type: 'openai',
apiKey: '123',
apiHost: 'https://api.cohere.com',
models: []
}
]
describe('grouped mode (grouped=true)', () => {
it('should render grouped options and apply predicate', () => {
render(
<ModelSelector
providers={mockProviders}
predicate={(model) => model.group === 'embedding'}
open // Keep dropdown open for testing
/>
)
// Check for group labels
expect(screen.getByText('OpenAI')).toBeInTheDocument()
expect(screen.getByText('Cohere')).toBeInTheDocument()
expect(screen.queryByText('EmptyProvider')).not.toBeInTheDocument()
// Check for correct models
const ada = screen.getByText('text-embedding-ada-002')
const cohere = screen.getByText('embed-english-v3.0')
expect(ada).toBeInTheDocument()
expect(cohere).toBeInTheDocument()
// Check suffix is present by default
expect(ada.textContent).toContain(' | OpenAI')
expect(cohere.textContent).toContain(' | Cohere')
// Check that filtered models are not present
expect(screen.queryByText('GPT-4.1')).not.toBeInTheDocument()
expect(screen.queryByText('rerank-english-v2.0')).not.toBeInTheDocument()
})
it('should hide suffix when showSuffix is false', () => {
render(
<ModelSelector
providers={mockProviders}
predicate={(model) => model.group === 'embedding'}
showSuffix={false}
open
/>
)
const ada = screen.getByText('text-embedding-ada-002')
expect(ada.textContent).toBe('text-embedding-ada-002')
expect(ada.textContent).not.toContain(' | OpenAI')
})
it('should hide avatar when showAvatar is false', () => {
render(<ModelSelector providers={mockProviders} showAvatar={false} open />)
expect(screen.queryByTestId('model-avatar')).not.toBeInTheDocument()
})
it('should show avatar when showAvatar is true', () => {
render(<ModelSelector providers={mockProviders} showAvatar={true} open />)
// 4 models in total from mockProviders
expect(screen.getAllByTestId('model-avatar')).toHaveLength(4)
})
})
describe('flat mode (grouped=false)', () => {
it('should render flat options and apply predicate', () => {
render(
<ModelSelector
providers={mockProviders}
predicate={(model) => model.group === 'embedding'}
grouped={false}
open
/>
)
// In flat mode, there are no group labels in the dropdown structure
expect(document.querySelector('.ant-select-item-option-group')).toBeNull()
// Check for correct models
const ada = screen.getByText('text-embedding-ada-002')
const cohere = screen.getByText('embed-english-v3.0')
expect(ada).toBeInTheDocument()
expect(cohere).toBeInTheDocument()
// Check suffix is present by default
expect(ada.textContent).toContain(' | OpenAI')
expect(cohere.textContent).toContain(' | Cohere')
// Check that filtered models are not present
expect(screen.queryByText('GPT-4.1')).not.toBeInTheDocument()
expect(screen.queryByText('rerank-english-v2.0')).not.toBeInTheDocument()
})
it('should hide suffix when showSuffix is false', () => {
render(<ModelSelector providers={mockProviders} grouped={false} showSuffix={false} open />)
const gpt4 = screen.getByText('GPT-4.1')
expect(gpt4.textContent).toBe('GPT-4.1')
expect(gpt4.textContent).not.toContain(' | OpenAI')
})
})
describe('edge cases', () => {
it('should handle empty providers array', () => {
render(<ModelSelector providers={[]} open />)
expect(document.querySelector('.ant-select-item-option')).toBeNull()
})
it('should handle undefined providers', () => {
render(<ModelSelector providers={undefined} open />)
expect(document.querySelector('.ant-select-item-option')).toBeNull()
})
})
describe('modelSelectFilter function', () => {
it('should filter by provider name in title', () => {
const mockOption = {
title: 'GPT-4.1 | OpenAI',
value: 'openai-gpt-4.1'
}
expect(modelSelectFilter('openai', mockOption)).toBe(true)
})
it('should filter by model name in title', () => {
const mockOption = {
title: 'embed-english-v3.0 | Cohere',
value: 'cohere-embed-english-v3.0'
}
expect(modelSelectFilter('english', mockOption)).toBe(true)
})
it('should filter by value if title is not present', () => {
const mockOption = {
value: 'openai-gpt-4.1'
}
expect(modelSelectFilter('gpt', mockOption)).toBe(true)
})
it('should return false for no match', () => {
const mockOption = {
title: 'GPT-4.1 | OpenAI',
value: 'openai-gpt-4.1'
}
expect(modelSelectFilter('nonexistent', mockOption)).toBe(false)
})
})
describe('integration', () => {
it('should filter options correctly when user types in search input', async () => {
const user = userEvent.setup()
render(<ModelSelector providers={mockProviders} open />)
// Find the search input field, which is a combobox
const searchInput = screen.getByRole('combobox')
await user.type(searchInput, 'embed')
// After filtering, only embedding models should be visible
expect(screen.getByText('text-embedding-ada-002')).toBeInTheDocument()
expect(screen.getByText('embed-english-v3.0')).toBeInTheDocument()
// Other models should not be visible
expect(screen.queryByText('GPT-4.1')).not.toBeInTheDocument()
expect(screen.queryByText('rerank-english-v2.0')).not.toBeInTheDocument()
// The group titles for visible items should still be there
expect(screen.getByText('OpenAI')).toBeInTheDocument()
expect(screen.getByText('Cohere')).toBeInTheDocument()
})
})
})

View File

@ -6,6 +6,7 @@ import db from '@renderer/databases'
import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
import { FileType, Model } from '@renderer/types'
import { getFancyProviderName } from '@renderer/utils'
import { Avatar, Tooltip } from 'antd'
import { useLiveQuery } from 'dexie-react-hooks'
import { first, sortBy } from 'lodash'
@ -62,7 +63,7 @@ const MentionModelsButton: FC<Props> = ({
.map((m) => ({
label: (
<>
<ProviderName>{p.isSystem ? t(`provider.${p.id}`) : p.name}</ProviderName>
<ProviderName>{getFancyProviderName(p)}</ProviderName>
<span style={{ opacity: 0.8 }}> | {m.name}</span>
</>
),
@ -72,7 +73,7 @@ const MentionModelsButton: FC<Props> = ({
{first(m.name)}
</Avatar>
),
filterText: (p.isSystem ? t(`provider.${p.id}`) : p.name) + m.name,
filterText: getFancyProviderName(p) + m.name,
action: () => onMentionModel(m),
isSelected: mentionedModels.some((selected) => getModelUniqId(selected) === getModelUniqId(m))
}))
@ -95,7 +96,7 @@ const MentionModelsButton: FC<Props> = ({
const providerModelItems = providerModels.map((m) => ({
label: (
<>
<ProviderName>{p.isSystem ? t(`provider.${p.id}`) : p.name}</ProviderName>
<ProviderName>{getFancyProviderName(p)}</ProviderName>
<span style={{ opacity: 0.8 }}> | {m.name}</span>
</>
),
@ -105,7 +106,7 @@ const MentionModelsButton: FC<Props> = ({
{first(m.name)}
</Avatar>
),
filterText: (p.isSystem ? t(`provider.${p.id}`) : p.name) + m.name,
filterText: getFancyProviderName(p) + m.name,
action: () => onMentionModel(m),
isSelected: mentionedModels.some((selected) => getModelUniqId(selected) === getModelUniqId(m))
}))

View File

@ -2,8 +2,8 @@ import CustomTag from '@renderer/components/CustomTag'
import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
import { Model } from '@renderer/types'
import { getFancyProviderName } from '@renderer/utils'
import { FC } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
const MentionModelsInput: FC<{
@ -11,11 +11,10 @@ const MentionModelsInput: FC<{
onRemoveModel: (model: Model) => void
}> = ({ selectedModels, onRemoveModel }) => {
const { providers } = useProviders()
const { t } = useTranslation()
const getProviderName = (model: Model) => {
const provider = providers.find((p) => p.id === model?.provider)
return provider ? (provider.isSystem ? t(`provider.${provider.id}`) : provider.name) : ''
return provider ? getFancyProviderName(provider) : ''
}
return (

View File

@ -2,11 +2,11 @@ import { InfoCircleOutlined, WarningOutlined } from '@ant-design/icons'
import { loggerService } from '@logger'
import AiProvider from '@renderer/aiCore'
import { HStack } from '@renderer/components/Layout'
import ModelSelector from '@renderer/components/ModelSelector'
import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, isMac } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useOcrProviders } from '@renderer/hooks/useOcr'
import { usePreprocessProviders } from '@renderer/hooks/usePreprocess'
@ -16,7 +16,7 @@ import { getModelUniqId } from '@renderer/services/ModelService'
import { KnowledgeBase, Model, OcrProvider, PreprocessProvider } from '@renderer/types'
import { getErrorMessage } from '@renderer/utils/error'
import { Alert, Input, InputNumber, Modal, Select, Slider, Switch, Tooltip } from 'antd'
import { find, sortBy } from 'lodash'
import { find } from 'lodash'
import { ChevronDown } from 'lucide-react'
import { nanoid } from 'nanoid'
import { useEffect, useMemo, useRef, useState } from 'react'
@ -65,41 +65,6 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
const nameInputRef = useRef<any>(null)
const scrollContainerRef = useRef<HTMLDivElement>(null)
const embeddingSelectOptions = useMemo(() => {
return providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isEmbeddingModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m),
providerId: p.id,
modelId: m.id
}))
}))
.filter((group) => group.options.length > 0)
}, [providers, t])
const rerankSelectOptions = useMemo(() => {
return providers
.filter((p) => p.models.length > 0)
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isRerankModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
}, [providers, t])
const preprocessOrOcrSelectOptions = useMemo(() => {
const preprocessOptions = {
label: t('settings.tool.preprocess.provider'),
@ -292,9 +257,10 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
<InfoCircleOutlined style={{ marginLeft: 8, color: 'var(--color-text-3)' }} />
</Tooltip>
</div>
<Select
<ModelSelector
providers={providers}
predicate={isEmbeddingModel}
style={{ width: '100%' }}
options={embeddingSelectOptions}
placeholder={t('settings.models.empty')}
onChange={(value) => {
const model = value
@ -313,9 +279,10 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
<InfoCircleOutlined style={{ marginLeft: 8, color: 'var(--color-text-3)' }} />
</Tooltip>
</div>
<Select
<ModelSelector
providers={providers}
predicate={isRerankModel}
style={{ width: '100%' }}
options={rerankSelectOptions}
placeholder={t('settings.models.empty')}
onChange={(value) => {
const rerankModel = value

View File

@ -1,6 +1,7 @@
import { InfoCircleOutlined, WarningOutlined } from '@ant-design/icons'
import { loggerService } from '@logger'
import { HStack } from '@renderer/components/Layout'
import ModelSelector from '@renderer/components/ModelSelector'
import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, isMac } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
@ -12,7 +13,6 @@ import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
import { KnowledgeBase, PreprocessProvider } from '@renderer/types'
import { Alert, Input, InputNumber, Menu, Modal, Select, Slider, Tooltip } from 'antd'
import { sortBy } from 'lodash'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -46,34 +46,6 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
return null
}
const selectOptions = providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isEmbeddingModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
const rerankSelectOptions = providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isRerankModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
const preprocessOptions = {
label: t('settings.tool.preprocess.provider'),
title: t('settings.tool.preprocess.provider'),
@ -181,9 +153,10 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
<InfoCircleOutlined style={{ marginLeft: 8, color: 'var(--color-text-3)' }} />
</Tooltip>
</div>
<Select
<ModelSelector
providers={providers}
predicate={isEmbeddingModel}
style={{ width: '100%' }}
options={selectOptions}
placeholder={t('settings.models.empty')}
defaultValue={getModelUniqId(base.model)}
disabled
@ -202,10 +175,11 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
<InfoCircleOutlined style={{ marginLeft: 8, color: 'var(--color-text-3)' }} />
</Tooltip>
</div>
<Select
<ModelSelector
providers={providers}
predicate={isRerankModel}
style={{ width: '100%' }}
defaultValue={getModelUniqId(base.rerankModel) || undefined}
options={rerankSelectOptions}
placeholder={t('settings.models.empty')}
onChange={(value) => {
const rerankModel = value

View File

@ -1,15 +1,16 @@
import { loggerService } from '@logger'
import AiProvider from '@renderer/aiCore'
import ModelSelector from '@renderer/components/ModelSelector'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { useModel } from '@renderer/hooks/useModel'
import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
import { selectMemoryConfig, updateMemoryConfig } from '@renderer/store/memory'
import { Model } from '@renderer/types'
import { getErrorMessage } from '@renderer/utils/error'
import { Form, InputNumber, Modal, Select, Switch } from 'antd'
import { Form, InputNumber, Modal, Switch } from 'antd'
import { t } from 'i18next'
import { sortBy } from 'lodash'
import { FC, useEffect, useState } from 'react'
import { FC, useCallback, useEffect, useState } from 'react'
import { useDispatch, useSelector } from 'react-redux'
interface MemoriesSettingsModalProps {
@ -129,33 +130,9 @@ const MemoriesSettingsModal: FC<MemoriesSettingsModalProps> = ({ visible, onSubm
}
}
const llmSelectOptions = providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => !isEmbeddingModel(model) && !isRerankModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
const llmPredicate = useCallback((m: Model) => !isEmbeddingModel(m) && !isRerankModel(m), [])
const embeddingSelectOptions = providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isEmbeddingModel(model) && !isRerankModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
const embeddingPredicate = useCallback((m: Model) => isEmbeddingModel(m) && !isRerankModel(m), [])
return (
<Modal
@ -183,13 +160,21 @@ const MemoriesSettingsModal: FC<MemoriesSettingsModalProps> = ({ visible, onSubm
label={t('memory.llm_model')}
name="llmModel"
rules={[{ required: true, message: t('memory.please_select_llm_model') }]}>
<Select placeholder={t('memory.select_llm_model_placeholder')} options={llmSelectOptions} showSearch />
<ModelSelector
providers={providers}
predicate={llmPredicate}
placeholder={t('memory.select_llm_model_placeholder')}
/>
</Form.Item>
<Form.Item
label={t('memory.embedding_model')}
name="embedderModel"
rules={[{ required: true, message: t('memory.please_select_embedding_model') }]}>
<Select placeholder={t('memory.select_embedding_model_placeholder')} options={embeddingSelectOptions} />
<ModelSelector
providers={providers}
predicate={embeddingPredicate}
placeholder={t('memory.select_embedding_model_placeholder')}
/>
</Form.Item>
<Form.Item
label={t('knowledge.dimensions_auto_set')}

View File

@ -1,15 +1,16 @@
import { loggerService } from '@logger'
import AiProvider from '@renderer/aiCore'
import ModelSelector from '@renderer/components/ModelSelector'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { useModel } from '@renderer/hooks/useModel'
import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
import { selectMemoryConfig, updateMemoryConfig } from '@renderer/store/memory'
import { Model } from '@renderer/types'
import { getErrorMessage } from '@renderer/utils/error'
import { Form, InputNumber, Modal, Select, Switch } from 'antd'
import { Form, InputNumber, Modal, Switch } from 'antd'
import { t } from 'i18next'
import { sortBy } from 'lodash'
import { FC, useEffect, useState } from 'react'
import { FC, useCallback, useEffect, useState } from 'react'
import { useDispatch, useSelector } from 'react-redux'
const logger = loggerService.withContext('MemoriesSettingsModal')
@ -125,33 +126,9 @@ const MemoriesSettingsModal: FC<MemoriesSettingsModalProps> = ({ visible, onSubm
}
}
const llmSelectOptions = providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => !isEmbeddingModel(model) && !isRerankModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
const llmPredicate = useCallback((m: Model) => !isEmbeddingModel(m) && !isRerankModel(m), [])
const embeddingSelectOptions = providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isEmbeddingModel(model) && !isRerankModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
const embeddingPredicate = useCallback((m: Model) => isEmbeddingModel(m) && !isRerankModel(m), [])
return (
<Modal
@ -179,13 +156,21 @@ const MemoriesSettingsModal: FC<MemoriesSettingsModalProps> = ({ visible, onSubm
label={t('memory.llm_model')}
name="llmModel"
rules={[{ required: true, message: t('memory.please_select_llm_model') }]}>
<Select placeholder={t('memory.select_llm_model_placeholder')} options={llmSelectOptions} showSearch />
<ModelSelector
providers={providers}
predicate={llmPredicate}
placeholder={t('memory.select_llm_model_placeholder')}
/>
</Form.Item>
<Form.Item
label={t('memory.embedding_model')}
name="embedderModel"
rules={[{ required: true, message: t('memory.please_select_embedding_model') }]}>
<Select placeholder={t('memory.select_embedding_model_placeholder')} options={embeddingSelectOptions} />
<ModelSelector
providers={providers}
predicate={embeddingPredicate}
placeholder={t('memory.select_embedding_model_placeholder')}
/>
</Form.Item>
<Form.Item
label={t('knowledge.dimensions_auto_set')}

View File

@ -1,6 +1,7 @@
import { RedoOutlined } from '@ant-design/icons'
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
import { HStack } from '@renderer/components/Layout'
import ModelSelector from '@renderer/components/ModelSelector'
import PromptPopup from '@renderer/components/Popups/PromptPopup'
import { isEmbeddingModel, isRerankModel, isTextToImageModel } from '@renderer/config/models'
import { TRANSLATE_PROMPT } from '@renderer/config/prompts'
@ -15,9 +16,9 @@ import { setQuickAssistantId } from '@renderer/store/llm'
import { setTranslateModelPrompt } from '@renderer/store/settings'
import { Model } from '@renderer/types'
import { Button, Select, Tooltip } from 'antd'
import { find, sortBy } from 'lodash'
import { find } from 'lodash'
import { CircleHelp, FolderPen, Languages, MessageSquareMore, Rocket, Settings2 } from 'lucide-react'
import { FC, useMemo } from 'react'
import { FC, useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -39,27 +40,10 @@ const ModelSettings: FC = () => {
const dispatch = useAppDispatch()
const { quickAssistantId } = useAppSelector((state) => state.llm)
const selectOptions = providers
.filter((p) => p.models.length > 0)
.flatMap((p) => {
const filteredModels = sortBy(p.models, 'name')
.filter((m) => !isEmbeddingModel(m) && !isRerankModel(m) && !isTextToImageModel(m))
.map((m) => ({
label: `${m.name} | ${p.isSystem ? t(`provider.${p.id}`) : p.name}`,
value: getModelUniqId(m)
}))
if (filteredModels.length > 0) {
return [
{
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: filteredModels
}
]
}
return []
})
const modelPredicate = useCallback(
(m: Model) => !isEmbeddingModel(m) && !isRerankModel(m) && !isTextToImageModel(m),
[]
)
const defaultModelValue = useMemo(
() => (hasModel(defaultModel) ? getModelUniqId(defaultModel) : undefined),
@ -105,13 +89,13 @@ const ModelSettings: FC = () => {
</HStack>
</SettingTitle>
<HStack alignItems="center">
<Select
<ModelSelector
providers={providers}
predicate={modelPredicate}
value={defaultModelValue}
defaultValue={defaultModelValue}
style={{ width: 360 }}
onChange={(value) => setDefaultModel(find(allModels, JSON.parse(value)) as Model)}
options={selectOptions}
showSearch
placeholder={t('settings.models.empty')}
/>
<Button icon={<Settings2 size={16} />} style={{ marginLeft: 8 }} onClick={DefaultAssistantSettings.show} />
@ -126,13 +110,13 @@ const ModelSettings: FC = () => {
</HStack>
</SettingTitle>
<HStack alignItems="center">
<Select
<ModelSelector
providers={providers}
predicate={modelPredicate}
value={defaultTopicNamingModel}
defaultValue={defaultTopicNamingModel}
style={{ width: 360 }}
onChange={(value) => setTopicNamingModel(find(allModels, JSON.parse(value)) as Model)}
options={selectOptions}
showSearch
placeholder={t('settings.models.empty')}
/>
<Button icon={<Settings2 size={16} />} style={{ marginLeft: 8 }} onClick={TopicNamingModalPopup.show} />
@ -147,13 +131,13 @@ const ModelSettings: FC = () => {
</HStack>
</SettingTitle>
<HStack alignItems="center">
<Select
<ModelSelector
providers={providers}
predicate={modelPredicate}
value={defaultTranslateModel}
defaultValue={defaultTranslateModel}
style={{ width: 360 }}
onChange={(value) => setTranslateModel(find(allModels, JSON.parse(value)) as Model)}
options={selectOptions}
showSearch
placeholder={t('settings.models.empty')}
/>
<Button icon={<Settings2 size={16} />} style={{ marginLeft: 8 }} onClick={onUpdateTranslateModel} />

View File

@ -1,10 +1,12 @@
import ModelSelector from '@renderer/components/ModelSelector'
import { TopView } from '@renderer/components/TopView'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { isRerankModel } from '@renderer/config/models'
import i18n from '@renderer/i18n'
import { Provider } from '@renderer/types'
import { Modal, Select } from 'antd'
import { first, orderBy } from 'lodash'
import { useState } from 'react'
import { getModelUniqId } from '@renderer/services/ModelService'
import { Model, Provider } from '@renderer/types'
import { Modal } from 'antd'
import { first } from 'lodash'
import { useCallback, useMemo, useState } from 'react'
interface ShowParams {
provider: Provider
@ -16,10 +18,19 @@ interface Props extends ShowParams {
}
const PopupContainer: React.FC<Props> = ({ provider, resolve, reject }) => {
const models = orderBy(provider.models, 'group').filter((i) => !isEmbeddingModel(i) && !isRerankModel(i))
const [open, setOpen] = useState(true)
// Keep the natural order of models
const models = useMemo(() => provider.models.filter((m) => !isRerankModel(m)), [provider])
const [model, setModel] = useState(first(models))
const modelPredicate = useCallback((m: Model) => !isRerankModel(m), [])
const defaultModelValue = useMemo(() => {
return model ? getModelUniqId(model) : undefined
}, [model])
const onOk = () => {
if (!model) {
window.message.error({ content: i18n.t('message.error.enter.model'), key: 'api-check' })
@ -50,14 +61,15 @@ const PopupContainer: React.FC<Props> = ({ provider, resolve, reject }) => {
transitionName="animation-move-down"
width={400}
centered>
<Select
value={model?.id}
<ModelSelector
providers={[provider]}
predicate={modelPredicate}
grouped={false}
defaultValue={defaultModelValue}
placeholder={i18n.t('settings.models.empty')}
options={models.map((m) => ({ label: m.name, value: m.id }))}
style={{ width: '100%' }}
showSearch
onChange={(value) => {
setModel(provider.models.find((m) => m.id === value)!)
setModel(models.find((m) => value === getModelUniqId(m))!)
}}
/>
</Modal>

View File

@ -7,7 +7,15 @@ import { useAllProviders, useProviders } from '@renderer/hooks/useProvider'
import ImageStorage from '@renderer/services/ImageStorage'
import { INITIAL_PROVIDERS } from '@renderer/store/llm'
import { Provider, ProviderType } from '@renderer/types'
import { droppableReorder, generateColorFromChar, getFirstCharacter, uuid } from '@renderer/utils'
import {
droppableReorder,
generateColorFromChar,
getFancyProviderName,
getFirstCharacter,
matchKeywordsInModel,
matchKeywordsInProvider,
uuid
} from '@renderer/utils'
import { Avatar, Button, Card, Dropdown, Input, MenuProps, Tag } from 'antd'
import { Eye, EyeOff, Search, UserPen } from 'lucide-react'
import { FC, useEffect, useState } from 'react'
@ -420,19 +428,9 @@ const ProvidersList: FC = () => {
}
const filteredProviders = providers.filter((provider) => {
const providerName = provider.isSystem ? t(`provider.${provider.id}`) : provider.name
const isProviderMatch =
provider.id.toLowerCase().includes(searchText.toLowerCase()) ||
providerName.toLowerCase().includes(searchText.toLowerCase())
const isModelMatch = provider.models.some((model) => {
return (
model.id.toLowerCase().includes(searchText.toLowerCase()) ||
model.name.toLowerCase().includes(searchText.toLowerCase())
)
})
const keywords = searchText.toLowerCase().split(/\s+/).filter(Boolean)
const isProviderMatch = matchKeywordsInProvider(keywords, provider)
const isModelMatch = provider.models.some((model) => matchKeywordsInModel(keywords, model))
return isProviderMatch || isModelMatch
})
@ -481,7 +479,7 @@ const ProvidersList: FC = () => {
onClick={() => setSelectedProvider(provider)}>
{getProviderAvatar(provider)}
<ProviderItemName className="text-nowrap">
{provider.isSystem ? t(`provider.${provider.id}`) : provider.name}
{getFancyProviderName(provider)}
</ProviderItemName>
{provider.enabled && (
<Tag color="green" style={{ marginLeft: 'auto', marginRight: 0, borderRadius: 16 }}>

View File

@ -1,5 +1,6 @@
import { loggerService } from '@logger'
import AiProvider from '@renderer/aiCore'
import ModelSelector from '@renderer/components/ModelSelector'
import { DEFAULT_WEBSEARCH_RAG_DOCUMENT_COUNT } from '@renderer/config/constant'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
@ -8,15 +9,15 @@ import { useWebSearchSettings } from '@renderer/hooks/useWebSearchProviders'
import { SettingDivider, SettingRow, SettingRowTitle } from '@renderer/pages/settings'
import { getModelUniqId } from '@renderer/services/ModelService'
import { Model } from '@renderer/types'
import { Button, InputNumber, Select, Slider, Tooltip } from 'antd'
import { find, sortBy } from 'lodash'
import { Button, InputNumber, Slider, Tooltip } from 'antd'
import { find } from 'lodash'
import { Info, RefreshCw } from 'lucide-react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
const logger = loggerService.withContext('RagSettings')
const INPUT_BOX_WIDTH = '200px'
const INPUT_BOX_WIDTH = 'min(350px, 60%)'
const RagSettings = () => {
const { t } = useTranslation()
@ -25,53 +26,16 @@ const RagSettings = () => {
const [loadingDimensions, setLoadingDimensions] = useState(false)
const embeddingModels = useMemo(() => {
return providers
.map((p) => p.models)
.flat()
.filter((model) => isEmbeddingModel(model))
return providers.flatMap((p) => p.models).filter((model) => isEmbeddingModel(model))
}, [providers])
const rerankModels = useMemo(() => {
return providers
.map((p) => p.models)
.flat()
.filter((model) => isRerankModel(model))
return providers.flatMap((p) => p.models).filter((model) => isRerankModel(model))
}, [providers])
const embeddingSelectOptions = useMemo(() => {
return providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isEmbeddingModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m),
providerId: p.id,
modelId: m.id
}))
}))
.filter((group) => group.options.length > 0)
}, [providers, t])
const rerankSelectOptions = useMemo(() => {
return providers
.filter((p) => p.models.length > 0)
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isRerankModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
}, [providers, t])
const rerankProviders = useMemo(() => {
return providers.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
}, [providers])
const handleEmbeddingModelChange = (modelValue: string) => {
const selectedModel = find(embeddingModels, JSON.parse(modelValue)) as Model
@ -125,14 +89,14 @@ const RagSettings = () => {
<>
<SettingRow>
<SettingRowTitle>{t('models.embedding_model')}</SettingRowTitle>
<Select
<ModelSelector
providers={providers}
predicate={isEmbeddingModel}
value={compressionConfig?.embeddingModel ? getModelUniqId(compressionConfig.embeddingModel) : undefined}
style={{ width: INPUT_BOX_WIDTH }}
options={embeddingSelectOptions}
placeholder={t('settings.models.empty')}
onChange={handleEmbeddingModelChange}
allowClear={false}
showSearch
/>
</SettingRow>
<SettingDivider />
@ -166,14 +130,14 @@ const RagSettings = () => {
<SettingRow>
<SettingRowTitle>{t('models.rerank_model')}</SettingRowTitle>
<Select
<ModelSelector
providers={rerankProviders}
predicate={isRerankModel}
value={compressionConfig?.rerankModel ? getModelUniqId(compressionConfig.rerankModel) : undefined}
style={{ width: INPUT_BOX_WIDTH }}
options={rerankSelectOptions}
placeholder={t('settings.models.empty')}
onChange={handleRerankModelChange}
allowClear
showSearch
/>
</SettingRow>
<SettingDivider />

View File

@ -6,6 +6,9 @@ import { useTranslation } from 'react-i18next'
import CutoffSettings from './CutoffSettings'
import RagSettings from './RagSettings'
const INPUT_BOX_WIDTH_CUTOFF = '200px'
const INPUT_BOX_WIDTH_RAG = 'min(350px, 60%)'
const CompressionSettings = () => {
const { t } = useTranslation()
const { compressionConfig, updateCompressionConfig } = useWebSearchSettings()
@ -29,7 +32,7 @@ const CompressionSettings = () => {
<SettingRowTitle>{t('settings.tool.websearch.compression.method')}</SettingRowTitle>
<Select
value={compressionConfig?.method || 'none'}
style={{ width: '200px' }}
style={{ width: compressionConfig?.method === 'rag' ? INPUT_BOX_WIDTH_RAG : INPUT_BOX_WIDTH_CUTOFF }}
onChange={handleCompressionMethodChange}
options={compressionMethodOptions}
/>

View File

@ -3,6 +3,7 @@ import { loggerService } from '@logger'
import { Navbar, NavbarCenter } from '@renderer/components/app/Navbar'
import CopyIcon from '@renderer/components/Icons/CopyIcon'
import { HStack } from '@renderer/components/Layout'
import ModelSelector from '@renderer/components/ModelSelector'
import { isEmbeddingModel, isRerankModel, isTextToImageModel } from '@renderer/config/models'
import { TRANSLATE_PROMPT } from '@renderer/config/prompts'
import { LanguagesEnum, translateLanguageOptions } from '@renderer/config/translate'
@ -29,9 +30,9 @@ import { Button, Dropdown, Empty, Flex, Modal, Popconfirm, Select, Space, Switch
import TextArea, { TextAreaRef } from 'antd/es/input/TextArea'
import dayjs from 'dayjs'
import { useLiveQuery } from 'dexie-react-hooks'
import { find, isEmpty, sortBy } from 'lodash'
import { find, isEmpty } from 'lodash'
import { ChevronDown, HelpCircle, Settings2, TriangleAlert } from 'lucide-react'
import { FC, useEffect, useMemo, useRef, useState } from 'react'
import { FC, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -54,8 +55,6 @@ const TranslateSettings: FC<{
setBidirectionalPair: (value: [Language, Language]) => void
translateModel: Model | undefined
onModelChange: (model: Model) => void
allModels: Model[]
selectOptions: any[]
}> = ({
visible,
onClose,
@ -68,9 +67,7 @@ const TranslateSettings: FC<{
bidirectionalPair,
setBidirectionalPair,
translateModel,
onModelChange,
allModels,
selectOptions
onModelChange
}) => {
const { t } = useTranslation()
const { translateModelPrompt } = useSettings()
@ -79,6 +76,14 @@ const TranslateSettings: FC<{
const [showPrompt, setShowPrompt] = useState(false)
const [localPrompt, setLocalPrompt] = useState(translateModelPrompt)
const { providers } = useProviders()
const allModels = useMemo(() => providers.map((p) => p.models).flat(), [providers])
const modelPredicate = useCallback(
(m: Model) => !isEmbeddingModel(m) && !isRerankModel(m) && !isTextToImageModel(m),
[]
)
const defaultTranslateModel = useMemo(
() => (hasModel(translateModel) ? getModelUniqId(translateModel) : undefined),
[translateModel]
@ -136,7 +141,9 @@ const TranslateSettings: FC<{
</Tooltip>
</div>
<HStack alignItems="center" gap={5}>
<Select
<ModelSelector
providers={providers}
predicate={modelPredicate}
style={{ width: '100%' }}
placeholder={t('translate.settings.model_placeholder')}
value={defaultTranslateModel}
@ -146,8 +153,6 @@ const TranslateSettings: FC<{
onModelChange(selectedModel)
}
}}
options={selectOptions}
showSearch
/>
</HStack>
{!translateModel && (
@ -302,9 +307,6 @@ const TranslatePage: FC = () => {
const outputTextRef = useRef<HTMLDivElement>(null)
const isProgrammaticScroll = useRef(false)
const { providers } = useProviders()
const allModels = useMemo(() => providers.map((p) => p.models).flat(), [providers])
const _translateHistory = useLiveQuery(() => db.translate_history.orderBy('createdAt').reverse().toArray(), [])
const translateHistory = useMemo(() => {
@ -319,31 +321,6 @@ const TranslatePage: FC = () => {
_result = result
_targetLanguage = targetLanguage
const selectOptions = useMemo(
() =>
providers
.filter((p) => p.models.length > 0)
.flatMap((p) => {
const filteredModels = sortBy(p.models, 'name')
.filter((m) => !isEmbeddingModel(m) && !isRerankModel(m) && !isTextToImageModel(m))
.map((m) => ({
label: `${m.name} | ${p.isSystem ? t(`provider.${p.id}`) : p.name}`,
value: getModelUniqId(m)
}))
if (filteredModels.length > 0) {
return [
{
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: filteredModels
}
]
}
return []
}),
[providers, t]
)
const handleModelChange = (model: Model) => {
setTranslateModel(model)
db.settings.put({ id: 'translate:model', value: model.id })
@ -760,8 +737,6 @@ const TranslatePage: FC = () => {
setBidirectionalPair={setBidirectionalPair}
translateModel={translateModel}
onModelChange={handleModelChange}
allModels={allModels}
selectOptions={selectOptions}
/>
</Container>
)

View File

@ -1,6 +1,6 @@
import store from '@renderer/store'
import { Model, Provider } from '@renderer/types'
import { t } from 'i18next'
import { getFancyProviderName } from '@renderer/utils'
import { pick } from 'lodash'
import { checkApi } from './ApiService'
@ -24,7 +24,7 @@ export function getModelName(model?: Model) {
const modelName = model?.name || model?.id || ''
if (provider) {
const providerName = provider?.isSystem ? t(`provider.${provider.id}`) : provider?.name
const providerName = getFancyProviderName(provider)
return `${modelName} | ${providerName}`
}

View File

@ -0,0 +1,140 @@
import type { Model, Provider } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { includeKeywords, matchKeywordsInModel, matchKeywordsInProvider, matchKeywordsInString } from '../match'
// mock i18n for getFancyProviderName
vi.mock('@renderer/i18n', () => ({
default: {
t: (key: string) => `i18n:${key}`
}
}))
describe('match', () => {
const provider: Provider = {
id: '12345',
type: 'openai',
name: 'OpenAI',
apiKey: '',
apiHost: '',
models: [],
isSystem: false
}
const sysProvider: Provider = {
...provider,
id: 'sys',
name: 'SystemProvider',
isSystem: true
}
describe('includeKeywords', () => {
it('should return true if keywords is empty or blank', () => {
expect(includeKeywords('hello world', '')).toBe(true)
expect(includeKeywords('hello world', ' ')).toBe(true)
})
it('should return false if target is empty', () => {
expect(includeKeywords('', 'hello')).toBe(false)
expect(includeKeywords(undefined as any, 'hello')).toBe(false)
})
it('should match all keywords (case-insensitive, whitespace split)', () => {
expect(includeKeywords('Hello World', 'hello')).toBe(true)
expect(includeKeywords('Hello World', 'world')).toBe(true)
expect(includeKeywords('Hello World', 'hello world')).toBe(true)
expect(includeKeywords('Hello World', 'world hello')).toBe(true)
expect(includeKeywords('Hello World', 'HELLO')).toBe(true)
expect(includeKeywords('Hello World', 'hello world')).toBe(true)
expect(includeKeywords('Hello\nWorld', 'hello world')).toBe(true)
})
it('should return false if any keyword is not included', () => {
expect(includeKeywords('Hello World', 'hello foo')).toBe(false)
expect(includeKeywords('Hello World', 'foo')).toBe(false)
})
it('should ignore blank keywords', () => {
expect(includeKeywords('Hello World', ' hello ')).toBe(true)
expect(includeKeywords('Hello World', 'hello ')).toBe(true)
expect(includeKeywords('Hello World', ' ')).toBe(true)
})
it('should handle keyword array', () => {
expect(includeKeywords('Hello World', ['hello', 'world'])).toBe(true)
expect(includeKeywords('Hello World', ['Hello', 'World'])).toBe(true)
expect(includeKeywords('Hello World', ['hello', 'foo'])).toBe(false)
expect(includeKeywords('Hello World', ['hello', ''])).toBe(true)
})
})
describe('matchKeywordsInString', () => {
it('should delegate to includeKeywords with string', () => {
expect(matchKeywordsInString('foo', 'foo bar')).toBe(true)
expect(matchKeywordsInString('bar', 'foo bar')).toBe(true)
expect(matchKeywordsInString('baz', 'foo bar')).toBe(false)
})
})
describe('matchKeywordsInProvider', () => {
it('should match non-system provider by name only, not id', () => {
expect(matchKeywordsInProvider('OpenAI', provider)).toBe(true)
expect(matchKeywordsInProvider('12345', provider)).toBe(false) // Should NOT match by id
expect(matchKeywordsInProvider('foo', provider)).toBe(false)
})
it('should match i18n name for system provider', () => {
expect(matchKeywordsInProvider('i18n:provider.sys', sysProvider)).toBe(true)
expect(matchKeywordsInProvider('SystemProvider', sysProvider)).toBe(false)
})
})
describe('matchKeywordsInModel', () => {
const model: Model = {
id: 'gpt-4.1',
provider: 'openai',
name: 'GPT-4.1',
group: 'gpt'
}
it('should match model name only if provider not given', () => {
expect(matchKeywordsInModel('gpt-4.1', model)).toBe(true)
expect(matchKeywordsInModel('openai', model)).toBe(false)
})
it('should match model name and provider name if provider given', () => {
expect(matchKeywordsInModel('gpt-4.1 openai', model, provider)).toBe(true)
expect(matchKeywordsInModel('gpt-4.1', model, provider)).toBe(true)
expect(matchKeywordsInModel('foo', model, provider)).toBe(false)
})
it('should match model name and i18n provider name for system provider', () => {
expect(matchKeywordsInModel('gpt-4.1 i18n:provider.sys', model, sysProvider)).toBe(true)
expect(matchKeywordsInModel('i18n:provider.sys', model, sysProvider)).toBe(true)
expect(matchKeywordsInModel('SystemProvider', model, sysProvider)).toBe(false)
})
it('should match model by id when name is customized', () => {
const customNameModel: Model = {
id: 'claude-3-opus-20240229',
provider: 'anthropic',
name: 'Opus (Custom Name)',
group: 'claude'
}
// search by parts of ID
expect(matchKeywordsInModel('claude', customNameModel)).toBe(true)
expect(matchKeywordsInModel('opus', customNameModel)).toBe(true)
expect(matchKeywordsInModel('20240229', customNameModel)).toBe(true)
// search by parts of custom name
expect(matchKeywordsInModel('Custom', customNameModel)).toBe(true)
expect(matchKeywordsInModel('Opus Name', customNameModel)).toBe(true)
// search by both
expect(matchKeywordsInModel('claude custom', customNameModel)).toBe(true)
// should not match
expect(matchKeywordsInModel('sonnet', customNameModel)).toBe(false)
})
})
})

View File

@ -231,6 +231,7 @@ export * from './api'
export * from './file'
export * from './image'
export * from './json'
export * from './match'
export * from './naming'
export * from './sort'
export * from './style'

View File

@ -0,0 +1,80 @@
import i18n from '@renderer/i18n'
import { Model, Provider } from '@renderer/types'
/**
* keywords
* keywords
* -
* -
*
* @param target
* @param keywords
* @returns true
*/
export function includeKeywords(target: string, keywords: string | string[]): boolean {
const keywordArray = Array.isArray(keywords) ? keywords : (keywords || '').split(/\s+/)
const nonEmptyKeywords = keywordArray.filter(Boolean)
// 如果没有有效关键词,则视为匹配
if (nonEmptyKeywords.length === 0) return true
// 如果没有搜索目标,则视为不匹配
if (!target || typeof target !== 'string') return false
const targetLower = target.toLowerCase()
return nonEmptyKeywords.every((keyword) => targetLower.includes(keyword.toLowerCase()))
}
/**
*
* @see includeKeywords
* @param keywords
* @param value
* @returns true
*/
export function matchKeywordsInString(keywords: string | string[], value: string): boolean {
return includeKeywords(value, keywords)
}
/**
* Provider
* @param keywords
* @param provider Provider
* @returns true
*/
export function matchKeywordsInProvider(keywords: string | string[], provider: Provider): boolean {
return includeKeywords(getProviderSearchString(provider), keywords)
}
/**
* Model
* @param keywords
* @param model Model
* @param provider Provider
* @returns true
*/
export function matchKeywordsInModel(keywords: string | string[], model: Model, provider?: Provider): boolean {
const fullName = `${model.name} ${model.id} ${provider ? getProviderSearchString(provider) : ''}`
return includeKeywords(fullName, keywords)
}
/**
* Provider getFancyProviderName
* @param provider Provider
* @returns
*/
function getProviderSearchString(provider: Provider) {
return provider.isSystem ? `${i18n.t(`provider.${provider.id}`)} ${provider.id}` : provider.name
}
/**
*
* @param keywords
* @param models
* @param provider Provider
* @returns
*/
export function filterModelsByKeywords(keywords: string, models: Model[], provider?: Provider): Model[] {
const keywordsArray = keywords.toLowerCase().split(/\s+/).filter(Boolean)
return models.filter((model) => matchKeywordsInModel(keywordsArray, model, provider))
}