mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 06:30:10 +08:00
refactor: match provider and model using a consistent method (#7933)
* refactor: match provider and model using a consistent method * refactor: use keywords matching across model selectors * refactor: update match, reuse getFancyProviderName * refactor: use modelSelectFilter in knowledgebase settings * refactor: use filter in ModelList * refactor: add filterModelsByKeywords * refactor: add getModelSelectOptions * style: better function names * fix: update effect dependencies in popup and panel components Adjusted dependency arrays in HtmlArtifactsPopup and QuickPanelView to ensure correct effect execution. This change improves state synchronization and prevents unnecessary updates. * refactor: use match in memory settings * refactor: add avatar to model selector * refactor: simplify utils, move select options to components * docs: add comments * refactor: move filter to SelectOptions * test: add tests for SelectOptions * test: remove type mock * refactor: use match in EditModelsPopup * refactor: use SelectOptions in SelectProviderModelPopup, add more tests * fix: api check model select * refactor: improve websearch rag model select style * refactor: add a ModelSelector * test: update tests for ModelSelector * docs: comments --------- Co-authored-by: 自由的世界人 <3196812536@qq.com>
This commit is contained in:
parent
d0649d29fb
commit
736f73a726
@ -247,7 +247,7 @@ The Enterprise Edition addresses core challenges in team collaboration by centra
|
||||
|
||||
| Feature | Community Edition | Enterprise Edition |
|
||||
| :---------------- | :----------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Open Source** | ✅ Yes | ⭕️ Partially released to customers |
|
||||
| **Open Source** | ✅ Yes | ⭕️ Partially released to customers |
|
||||
| **Cost** | Free for Personal Use / Commercial License | Buyout / Subscription Fee |
|
||||
| **Admin Backend** | — | ● Centralized **Model** Access<br>● **Employee** Management<br>● Shared **Knowledge Base**<br>● **Access** Control<br>● **Data** Backup |
|
||||
| **Server** | — | ✅ Dedicated Private Deployment |
|
||||
|
||||
@ -145,8 +145,8 @@ In a development environment, you can define environment variables to filter dis
|
||||
|
||||
Environment variables can be set in the terminal or defined in the `.env` file in the project's root directory. The available variables are as follows:
|
||||
|
||||
| Variable Name | Description |
|
||||
| ------- | ------- |
|
||||
| Variable Name | Description |
|
||||
| ------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| CSLOGGER_MAIN_LEVEL | Log level for the `main` process. Logs below this level will not be displayed. |
|
||||
| CSLOGGER_MAIN_SHOW_MODULES | Filters log modules for the `main` process. Use a comma (`,`) to separate modules. The filter is case-sensitive. Only logs from modules in this list will be displayed. |
|
||||
| CSLOGGER_RENDERER_LEVEL | Log level for the `renderer` process. Logs below this level will not be displayed. |
|
||||
@ -160,6 +160,7 @@ CSLOGGER_MAIN_SHOW_MODULES=MCPService,SelectionService
|
||||
```
|
||||
|
||||
Note:
|
||||
|
||||
- Environment variables are only effective in the development environment.
|
||||
- These variables only affect the logs displayed in the terminal or DevTools. They do not affect file logging or the `logToMain` recording logic.
|
||||
|
||||
@ -168,7 +169,7 @@ Note:
|
||||
There are many log levels. The following are the guidelines that should be followed in CherryStudio for when to use each level:
|
||||
(Arranged from highest to lowest log level)
|
||||
|
||||
| Log Level | Core Definition & Use case | Example |
|
||||
| Log Level | Core Definition & Use case | Example |
|
||||
| :------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **`error`** | **Critical error causing the program to crash or core functionality to become unusable.** <br> This is the highest-priority log, usually requiring immediate reporting or user notification. | - Main or renderer process crash. <br> - Failure to read/write critical user data files (e.g., database, configuration files), preventing the application from running. <br> - All unhandled exceptions. |
|
||||
| **`warn`** | **Potential issue or unexpected situation that does not affect the program's core functionality.** <br> The program can recover or use a fallback. | - Configuration file `settings.json` is missing; started with default settings. <br> - Auto-update check failed, but does not affect the use of the current version. <br> - A non-essential plugin failed to load. |
|
||||
|
||||
@ -145,12 +145,12 @@ const logger = loggerService.initWindowSource('Worker').withContext('LetsWork')
|
||||
|
||||
环境变量可以在终端中自行设置,或者在开发根目录的`.env`文件中进行定义,可以定义的变量如下:
|
||||
|
||||
| 变量名 | 含义 |
|
||||
| ----- | ----- |
|
||||
| CSLOGGER_MAIN_LEVEL | 用于`main`进程的日志级别,低于该级别的日志将不显示 |
|
||||
| CSLOGGER_MAIN_SHOW_MODULES | 用于`main`进程的日志module筛选,用`,`分隔,区分大小写。只有在该列表中的module的日志才会显示 |
|
||||
| CSLOGGER_RENDERER_LEVEL | 用于`renderer`进程的日志级别,低于该级别的日志将不显示 |
|
||||
| CSLOGGER_RENDERER_SHOW_MODULES | 用于`renderer`进程的日志module筛选,用`,`分隔,区分大小写。只有在该列表中的module的日志才会显示 |
|
||||
| 变量名 | 含义 |
|
||||
| ------------------------------ | ----------------------------------------------------------------------------------------------- |
|
||||
| CSLOGGER_MAIN_LEVEL | 用于`main`进程的日志级别,低于该级别的日志将不显示 |
|
||||
| CSLOGGER_MAIN_SHOW_MODULES | 用于`main`进程的日志module筛选,用`,`分隔,区分大小写。只有在该列表中的module的日志才会显示 |
|
||||
| CSLOGGER_RENDERER_LEVEL | 用于`renderer`进程的日志级别,低于该级别的日志将不显示 |
|
||||
| CSLOGGER_RENDERER_SHOW_MODULES | 用于`renderer`进程的日志module筛选,用`,`分隔,区分大小写。只有在该列表中的module的日志才会显示 |
|
||||
|
||||
示例:
|
||||
|
||||
@ -160,6 +160,7 @@ CSLOGGER_MAIN_SHOW_MODULES=MCPService,SelectionService
|
||||
```
|
||||
|
||||
注意:
|
||||
|
||||
- 环境变量仅在开发环境中生效
|
||||
- 该变量仅会改变在终端或在devTools中显示的日志,不会影响文件日志和`logToMain`的记录逻辑
|
||||
|
||||
|
||||
@ -23,7 +23,13 @@ import { useProvider } from '@renderer/hooks/useProvider'
|
||||
import FileItem from '@renderer/pages/files/FileItem'
|
||||
import { fetchModels } from '@renderer/services/ApiService'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
import { getDefaultGroupName, isFreeModel, runAsyncFunction } from '@renderer/utils'
|
||||
import {
|
||||
filterModelsByKeywords,
|
||||
getDefaultGroupName,
|
||||
getFancyProviderName,
|
||||
isFreeModel,
|
||||
runAsyncFunction
|
||||
} from '@renderer/utils'
|
||||
import { Avatar, Button, Empty, Flex, Modal, Spin, Tabs, Tooltip } from 'antd'
|
||||
import Input from 'antd/es/input/Input'
|
||||
import { groupBy, isEmpty, uniqBy } from 'lodash'
|
||||
@ -86,34 +92,30 @@ const PopupContainer: React.FC<Props> = ({ provider: _provider, resolve }) => {
|
||||
const systemModels = SYSTEM_MODELS[_provider.id] || []
|
||||
const allModels = uniqBy([...systemModels, ...listModels, ...models], 'id')
|
||||
|
||||
const list = allModels.filter((model) => {
|
||||
if (
|
||||
filterSearchText &&
|
||||
!model.id.toLocaleLowerCase().includes(filterSearchText.toLocaleLowerCase()) &&
|
||||
!model.name?.toLocaleLowerCase().includes(filterSearchText.toLocaleLowerCase())
|
||||
) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch (actualFilterType) {
|
||||
case 'reasoning':
|
||||
return isReasoningModel(model)
|
||||
case 'vision':
|
||||
return isVisionModel(model)
|
||||
case 'websearch':
|
||||
return isWebSearchModel(model)
|
||||
case 'free':
|
||||
return isFreeModel(model)
|
||||
case 'embedding':
|
||||
return isEmbeddingModel(model)
|
||||
case 'function_calling':
|
||||
return isFunctionCallingModel(model)
|
||||
case 'rerank':
|
||||
return isRerankModel(model)
|
||||
default:
|
||||
return true
|
||||
}
|
||||
})
|
||||
const list = useMemo(
|
||||
() =>
|
||||
filterModelsByKeywords(filterSearchText, allModels).filter((model) => {
|
||||
switch (actualFilterType) {
|
||||
case 'reasoning':
|
||||
return isReasoningModel(model)
|
||||
case 'vision':
|
||||
return isVisionModel(model)
|
||||
case 'websearch':
|
||||
return isWebSearchModel(model)
|
||||
case 'free':
|
||||
return isFreeModel(model)
|
||||
case 'embedding':
|
||||
return isEmbeddingModel(model)
|
||||
case 'function_calling':
|
||||
return isFunctionCallingModel(model)
|
||||
case 'rerank':
|
||||
return isRerankModel(model)
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}),
|
||||
[filterSearchText, actualFilterType, allModels]
|
||||
)
|
||||
|
||||
const modelGroups = useMemo(
|
||||
() =>
|
||||
@ -202,7 +204,7 @@ const PopupContainer: React.FC<Props> = ({ provider: _provider, resolve }) => {
|
||||
return (
|
||||
<Flex>
|
||||
<ModelHeaderTitle>
|
||||
{provider.isSystem ? t(`provider.${provider.id}`) : provider.name}
|
||||
{getFancyProviderName(provider)}
|
||||
{i18n.language.startsWith('zh') ? '' : ' '}
|
||||
{t('common.models')}
|
||||
</ModelHeaderTitle>
|
||||
|
||||
@ -12,6 +12,7 @@ import { SettingHelpLink, SettingHelpText, SettingHelpTextRow, SettingSubtitle }
|
||||
import { useAppDispatch } from '@renderer/store'
|
||||
import { setModel } from '@renderer/store/assistants'
|
||||
import { Model } from '@renderer/types'
|
||||
import { filterModelsByKeywords } from '@renderer/utils'
|
||||
import { Button, Flex, Tooltip } from 'antd'
|
||||
import { groupBy, sortBy, toPairs } from 'lodash'
|
||||
import { ListCheck, Plus } from 'lucide-react'
|
||||
@ -51,9 +52,7 @@ const ModelList: React.FC<ModelListProps> = ({ providerId }) => {
|
||||
}, [])
|
||||
|
||||
const modelGroups = useMemo(() => {
|
||||
const filteredModels = searchText
|
||||
? models.filter((model) => model.name.toLowerCase().includes(searchText.toLowerCase()))
|
||||
: models
|
||||
const filteredModels = searchText ? filterModelsByKeywords(searchText, models) : models
|
||||
return groupBy(filteredModels, 'group')
|
||||
}, [searchText, models])
|
||||
|
||||
|
||||
123
src/renderer/src/components/ModelSelector.tsx
Normal file
123
src/renderer/src/components/ModelSelector.tsx
Normal file
@ -0,0 +1,123 @@
|
||||
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
import { matchKeywordsInString } from '@renderer/utils'
|
||||
import { getFancyProviderName } from '@renderer/utils/naming'
|
||||
import { Select, SelectProps } from 'antd'
|
||||
import { sortBy } from 'lodash'
|
||||
import { BaseSelectRef } from 'rc-select'
|
||||
import { memo, useCallback, useMemo } from 'react'
|
||||
|
||||
interface ModelOption {
|
||||
label: React.ReactNode
|
||||
title: string
|
||||
value: string
|
||||
}
|
||||
|
||||
interface GroupedModelOption {
|
||||
label: string
|
||||
title: string
|
||||
options: ModelOption[]
|
||||
}
|
||||
|
||||
type SelectOption = ModelOption | GroupedModelOption
|
||||
|
||||
interface ModelSelectorProps extends SelectProps {
|
||||
providers?: Provider[]
|
||||
predicate?: (model: Model) => boolean
|
||||
grouped?: boolean
|
||||
showAvatar?: boolean
|
||||
showSuffix?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 模型选择器,封装了 antd Select
|
||||
* - 通过传入模型服务商列表和模型 predicate 来构造选项
|
||||
* - 支持按服务商分组
|
||||
* - 可以控制 avatar 和 suffix 显示与否
|
||||
* @param providers 服务商列表
|
||||
* @param predicate 模型过滤条件
|
||||
* @param grouped 是否按服务商分组
|
||||
* @param showAvatar 是否显示模型图标
|
||||
* @param showSuffix 是否在模型名称后显示服务商作为后缀
|
||||
*/
|
||||
const ModelSelector = ({
|
||||
providers,
|
||||
predicate,
|
||||
grouped = true,
|
||||
showAvatar = true,
|
||||
showSuffix = true,
|
||||
ref,
|
||||
...props
|
||||
}: ModelSelectorProps & { ref?: React.Ref<BaseSelectRef> | null }) => {
|
||||
// 单个 provider 的模型选项
|
||||
const getModelOptions = useCallback(
|
||||
(p: Provider, fancyName: string) => {
|
||||
const suffix = showSuffix ? <span style={{ opacity: 0.45 }}>{` | ${fancyName}`}</span> : null
|
||||
return sortBy(p.models, 'name')
|
||||
.filter((model) => predicate?.(model) ?? true)
|
||||
.map((m) => ({
|
||||
label: (
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
|
||||
{showAvatar && <ModelAvatar model={m} size={18} />}
|
||||
<span>
|
||||
{m.name}
|
||||
{suffix}
|
||||
</span>
|
||||
</div>
|
||||
),
|
||||
title: `${m.name} | ${fancyName}`,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
},
|
||||
[predicate, showAvatar, showSuffix]
|
||||
)
|
||||
|
||||
// 所有 provider 的模型选项
|
||||
const options = useMemo((): SelectOption[] => {
|
||||
if (!providers) return []
|
||||
|
||||
if (grouped) {
|
||||
return providers.flatMap((p) => {
|
||||
const fancyName = getFancyProviderName(p)
|
||||
const modelOptions = getModelOptions(p, fancyName)
|
||||
return modelOptions.length > 0
|
||||
? [
|
||||
{
|
||||
label: fancyName,
|
||||
title: p.name,
|
||||
options: modelOptions
|
||||
} as GroupedModelOption
|
||||
]
|
||||
: []
|
||||
})
|
||||
}
|
||||
return providers.flatMap((p) => getModelOptions(p, getFancyProviderName(p)))
|
||||
}, [providers, grouped, getModelOptions])
|
||||
|
||||
return <Select ref={ref} options={options} filterOption={modelSelectFilter} showSearch {...props} />
|
||||
}
|
||||
|
||||
export default memo(ModelSelector)
|
||||
|
||||
/**
|
||||
* 用于 antd Select 组件的 filterOption,统一搜索行为:
|
||||
* - 优先使用 title 匹配
|
||||
* - 其次使用 label 匹配
|
||||
* - 最后使用 value 匹配
|
||||
*
|
||||
* @param input 用户输入的搜索字符串
|
||||
* @param option Select 选项对象,包含 label 或 value
|
||||
* @returns 是否匹配
|
||||
*/
|
||||
export function modelSelectFilter(input: string, option: any): boolean {
|
||||
const target =
|
||||
typeof option?.title === 'string'
|
||||
? option.title
|
||||
: typeof option?.label === 'string'
|
||||
? option.label
|
||||
: typeof option?.value === 'string'
|
||||
? option.value
|
||||
: ''
|
||||
return matchKeywordsInString(input, target)
|
||||
}
|
||||
@ -7,7 +7,7 @@ import { usePinnedModels } from '@renderer/hooks/usePinnedModels'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { Model } from '@renderer/types'
|
||||
import { classNames } from '@renderer/utils/style'
|
||||
import { classNames, filterModelsByKeywords, getFancyProviderName } from '@renderer/utils'
|
||||
import { Avatar, Divider, Empty, Input, InputRef, Modal } from 'antd'
|
||||
import { first, sortBy } from 'lodash'
|
||||
import { Search } from 'lucide-react'
|
||||
@ -102,27 +102,19 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter }) => {
|
||||
let models = provider.models.filter((m) => !isEmbeddingModel(m) && !isRerankModel(m))
|
||||
|
||||
if (searchText.trim()) {
|
||||
const keywords = searchText.toLowerCase().split(/\s+/).filter(Boolean)
|
||||
models = models.filter((m) => {
|
||||
const fullName = provider.isSystem
|
||||
? `${m.name} ${provider.name} ${t('provider.' + provider.id)}`
|
||||
: `${m.name} ${provider.name}`
|
||||
|
||||
const lowerFullName = fullName.toLowerCase()
|
||||
return keywords.every((keyword) => lowerFullName.includes(keyword))
|
||||
})
|
||||
models = filterModelsByKeywords(searchText, models, provider)
|
||||
}
|
||||
|
||||
return sortBy(models, ['group', 'name'])
|
||||
},
|
||||
[searchText, t]
|
||||
[searchText]
|
||||
)
|
||||
|
||||
// 创建模型列表项
|
||||
const createModelItem = useCallback(
|
||||
(model: Model, provider: any, isPinned: boolean): FlatListItem => {
|
||||
const modelId = getModelUniqId(model)
|
||||
const groupName = provider.isSystem ? t(`provider.${provider.id}`) : provider.name
|
||||
const groupName = getFancyProviderName(provider)
|
||||
|
||||
return {
|
||||
key: isPinned ? `${modelId}_pinned` : modelId,
|
||||
@ -148,7 +140,7 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter }) => {
|
||||
isSelected: modelId === currentModelId
|
||||
}
|
||||
},
|
||||
[t, currentModelId]
|
||||
[currentModelId]
|
||||
)
|
||||
|
||||
// 构建扁平化列表数据
|
||||
@ -189,7 +181,7 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter }) => {
|
||||
items.push({
|
||||
key: `provider-${p.id}`,
|
||||
type: 'group',
|
||||
name: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
name: getFancyProviderName(p),
|
||||
isSelected: false
|
||||
})
|
||||
|
||||
|
||||
225
src/renderer/src/components/__tests__/ModelSelector.test.tsx
Normal file
225
src/renderer/src/components/__tests__/ModelSelector.test.tsx
Normal file
@ -0,0 +1,225 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// Mock the imported modules
|
||||
vi.mock('@renderer/components/Avatar/ModelAvatar', () => ({
|
||||
default: ({ model, size }: any) => (
|
||||
<div data-testid="model-avatar" style={{ width: size, height: size }}>
|
||||
{model.name.charAt(0)}
|
||||
</div>
|
||||
)
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/ModelService', () => ({
|
||||
getModelUniqId: (model: any) => `${model.provider}-${model.id}`
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils', () => ({
|
||||
matchKeywordsInString: (input: string, target: string) => target.toLowerCase().includes(input.toLowerCase())
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/naming', () => ({
|
||||
getFancyProviderName: (provider: any) => provider.name
|
||||
}))
|
||||
|
||||
// Import after mocking
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import ModelSelector, { modelSelectFilter } from '../ModelSelector'
|
||||
|
||||
describe('ModelSelector', () => {
|
||||
const mockProviders: Provider[] = [
|
||||
{
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
type: 'openai',
|
||||
apiKey: '123',
|
||||
apiHost: 'https://api.openai.com',
|
||||
models: [
|
||||
{ id: 'text-embedding-ada-002', name: 'text-embedding-ada-002', provider: 'openai', group: 'embedding' },
|
||||
{ id: 'gpt-4.1', name: 'GPT-4.1', provider: 'openai', group: 'chat' }
|
||||
]
|
||||
},
|
||||
{
|
||||
id: 'cohere',
|
||||
name: 'Cohere',
|
||||
type: 'openai',
|
||||
apiKey: '123',
|
||||
apiHost: 'https://api.cohere.com',
|
||||
models: [
|
||||
{ id: 'embed-english-v3.0', name: 'embed-english-v3.0', provider: 'cohere', group: 'embedding' },
|
||||
{ id: 'rerank-english-v2.0', name: 'rerank-english-v2.0', provider: 'cohere', group: 'rerank' }
|
||||
]
|
||||
},
|
||||
{
|
||||
id: 'empty-provider',
|
||||
name: 'EmptyProvider',
|
||||
type: 'openai',
|
||||
apiKey: '123',
|
||||
apiHost: 'https://api.cohere.com',
|
||||
models: []
|
||||
}
|
||||
]
|
||||
|
||||
describe('grouped mode (grouped=true)', () => {
|
||||
it('should render grouped options and apply predicate', () => {
|
||||
render(
|
||||
<ModelSelector
|
||||
providers={mockProviders}
|
||||
predicate={(model) => model.group === 'embedding'}
|
||||
open // Keep dropdown open for testing
|
||||
/>
|
||||
)
|
||||
|
||||
// Check for group labels
|
||||
expect(screen.getByText('OpenAI')).toBeInTheDocument()
|
||||
expect(screen.getByText('Cohere')).toBeInTheDocument()
|
||||
expect(screen.queryByText('EmptyProvider')).not.toBeInTheDocument()
|
||||
|
||||
// Check for correct models
|
||||
const ada = screen.getByText('text-embedding-ada-002')
|
||||
const cohere = screen.getByText('embed-english-v3.0')
|
||||
expect(ada).toBeInTheDocument()
|
||||
expect(cohere).toBeInTheDocument()
|
||||
// Check suffix is present by default
|
||||
expect(ada.textContent).toContain(' | OpenAI')
|
||||
expect(cohere.textContent).toContain(' | Cohere')
|
||||
|
||||
// Check that filtered models are not present
|
||||
expect(screen.queryByText('GPT-4.1')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('rerank-english-v2.0')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide suffix when showSuffix is false', () => {
|
||||
render(
|
||||
<ModelSelector
|
||||
providers={mockProviders}
|
||||
predicate={(model) => model.group === 'embedding'}
|
||||
showSuffix={false}
|
||||
open
|
||||
/>
|
||||
)
|
||||
|
||||
const ada = screen.getByText('text-embedding-ada-002')
|
||||
expect(ada.textContent).toBe('text-embedding-ada-002')
|
||||
expect(ada.textContent).not.toContain(' | OpenAI')
|
||||
})
|
||||
|
||||
it('should hide avatar when showAvatar is false', () => {
|
||||
render(<ModelSelector providers={mockProviders} showAvatar={false} open />)
|
||||
expect(screen.queryByTestId('model-avatar')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show avatar when showAvatar is true', () => {
|
||||
render(<ModelSelector providers={mockProviders} showAvatar={true} open />)
|
||||
// 4 models in total from mockProviders
|
||||
expect(screen.getAllByTestId('model-avatar')).toHaveLength(4)
|
||||
})
|
||||
})
|
||||
|
||||
describe('flat mode (grouped=false)', () => {
|
||||
it('should render flat options and apply predicate', () => {
|
||||
render(
|
||||
<ModelSelector
|
||||
providers={mockProviders}
|
||||
predicate={(model) => model.group === 'embedding'}
|
||||
grouped={false}
|
||||
open
|
||||
/>
|
||||
)
|
||||
|
||||
// In flat mode, there are no group labels in the dropdown structure
|
||||
expect(document.querySelector('.ant-select-item-option-group')).toBeNull()
|
||||
|
||||
// Check for correct models
|
||||
const ada = screen.getByText('text-embedding-ada-002')
|
||||
const cohere = screen.getByText('embed-english-v3.0')
|
||||
expect(ada).toBeInTheDocument()
|
||||
expect(cohere).toBeInTheDocument()
|
||||
// Check suffix is present by default
|
||||
expect(ada.textContent).toContain(' | OpenAI')
|
||||
expect(cohere.textContent).toContain(' | Cohere')
|
||||
|
||||
// Check that filtered models are not present
|
||||
expect(screen.queryByText('GPT-4.1')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('rerank-english-v2.0')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide suffix when showSuffix is false', () => {
|
||||
render(<ModelSelector providers={mockProviders} grouped={false} showSuffix={false} open />)
|
||||
|
||||
const gpt4 = screen.getByText('GPT-4.1')
|
||||
expect(gpt4.textContent).toBe('GPT-4.1')
|
||||
expect(gpt4.textContent).not.toContain(' | OpenAI')
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle empty providers array', () => {
|
||||
render(<ModelSelector providers={[]} open />)
|
||||
expect(document.querySelector('.ant-select-item-option')).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle undefined providers', () => {
|
||||
render(<ModelSelector providers={undefined} open />)
|
||||
expect(document.querySelector('.ant-select-item-option')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('modelSelectFilter function', () => {
|
||||
it('should filter by provider name in title', () => {
|
||||
const mockOption = {
|
||||
title: 'GPT-4.1 | OpenAI',
|
||||
value: 'openai-gpt-4.1'
|
||||
}
|
||||
expect(modelSelectFilter('openai', mockOption)).toBe(true)
|
||||
})
|
||||
|
||||
it('should filter by model name in title', () => {
|
||||
const mockOption = {
|
||||
title: 'embed-english-v3.0 | Cohere',
|
||||
value: 'cohere-embed-english-v3.0'
|
||||
}
|
||||
expect(modelSelectFilter('english', mockOption)).toBe(true)
|
||||
})
|
||||
|
||||
it('should filter by value if title is not present', () => {
|
||||
const mockOption = {
|
||||
value: 'openai-gpt-4.1'
|
||||
}
|
||||
expect(modelSelectFilter('gpt', mockOption)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for no match', () => {
|
||||
const mockOption = {
|
||||
title: 'GPT-4.1 | OpenAI',
|
||||
value: 'openai-gpt-4.1'
|
||||
}
|
||||
expect(modelSelectFilter('nonexistent', mockOption)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('integration', () => {
|
||||
it('should filter options correctly when user types in search input', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<ModelSelector providers={mockProviders} open />)
|
||||
|
||||
// Find the search input field, which is a combobox
|
||||
const searchInput = screen.getByRole('combobox')
|
||||
await user.type(searchInput, 'embed')
|
||||
|
||||
// After filtering, only embedding models should be visible
|
||||
expect(screen.getByText('text-embedding-ada-002')).toBeInTheDocument()
|
||||
expect(screen.getByText('embed-english-v3.0')).toBeInTheDocument()
|
||||
|
||||
// Other models should not be visible
|
||||
expect(screen.queryByText('GPT-4.1')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('rerank-english-v2.0')).not.toBeInTheDocument()
|
||||
|
||||
// The group titles for visible items should still be there
|
||||
expect(screen.getByText('OpenAI')).toBeInTheDocument()
|
||||
expect(screen.getByText('Cohere')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -6,6 +6,7 @@ import db from '@renderer/databases'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { FileType, Model } from '@renderer/types'
|
||||
import { getFancyProviderName } from '@renderer/utils'
|
||||
import { Avatar, Tooltip } from 'antd'
|
||||
import { useLiveQuery } from 'dexie-react-hooks'
|
||||
import { first, sortBy } from 'lodash'
|
||||
@ -62,7 +63,7 @@ const MentionModelsButton: FC<Props> = ({
|
||||
.map((m) => ({
|
||||
label: (
|
||||
<>
|
||||
<ProviderName>{p.isSystem ? t(`provider.${p.id}`) : p.name}</ProviderName>
|
||||
<ProviderName>{getFancyProviderName(p)}</ProviderName>
|
||||
<span style={{ opacity: 0.8 }}> | {m.name}</span>
|
||||
</>
|
||||
),
|
||||
@ -72,7 +73,7 @@ const MentionModelsButton: FC<Props> = ({
|
||||
{first(m.name)}
|
||||
</Avatar>
|
||||
),
|
||||
filterText: (p.isSystem ? t(`provider.${p.id}`) : p.name) + m.name,
|
||||
filterText: getFancyProviderName(p) + m.name,
|
||||
action: () => onMentionModel(m),
|
||||
isSelected: mentionedModels.some((selected) => getModelUniqId(selected) === getModelUniqId(m))
|
||||
}))
|
||||
@ -95,7 +96,7 @@ const MentionModelsButton: FC<Props> = ({
|
||||
const providerModelItems = providerModels.map((m) => ({
|
||||
label: (
|
||||
<>
|
||||
<ProviderName>{p.isSystem ? t(`provider.${p.id}`) : p.name}</ProviderName>
|
||||
<ProviderName>{getFancyProviderName(p)}</ProviderName>
|
||||
<span style={{ opacity: 0.8 }}> | {m.name}</span>
|
||||
</>
|
||||
),
|
||||
@ -105,7 +106,7 @@ const MentionModelsButton: FC<Props> = ({
|
||||
{first(m.name)}
|
||||
</Avatar>
|
||||
),
|
||||
filterText: (p.isSystem ? t(`provider.${p.id}`) : p.name) + m.name,
|
||||
filterText: getFancyProviderName(p) + m.name,
|
||||
action: () => onMentionModel(m),
|
||||
isSelected: mentionedModels.some((selected) => getModelUniqId(selected) === getModelUniqId(m))
|
||||
}))
|
||||
|
||||
@ -2,8 +2,8 @@ import CustomTag from '@renderer/components/CustomTag'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { Model } from '@renderer/types'
|
||||
import { getFancyProviderName } from '@renderer/utils'
|
||||
import { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
const MentionModelsInput: FC<{
|
||||
@ -11,11 +11,10 @@ const MentionModelsInput: FC<{
|
||||
onRemoveModel: (model: Model) => void
|
||||
}> = ({ selectedModels, onRemoveModel }) => {
|
||||
const { providers } = useProviders()
|
||||
const { t } = useTranslation()
|
||||
|
||||
const getProviderName = (model: Model) => {
|
||||
const provider = providers.find((p) => p.id === model?.provider)
|
||||
return provider ? (provider.isSystem ? t(`provider.${provider.id}`) : provider.name) : ''
|
||||
return provider ? getFancyProviderName(provider) : ''
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@ -2,11 +2,11 @@ import { InfoCircleOutlined, WarningOutlined } from '@ant-design/icons'
|
||||
import { loggerService } from '@logger'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import ModelSelector from '@renderer/components/ModelSelector'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, isMac } from '@renderer/config/constant'
|
||||
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
|
||||
import { useOcrProviders } from '@renderer/hooks/useOcr'
|
||||
import { usePreprocessProviders } from '@renderer/hooks/usePreprocess'
|
||||
@ -16,7 +16,7 @@ import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { KnowledgeBase, Model, OcrProvider, PreprocessProvider } from '@renderer/types'
|
||||
import { getErrorMessage } from '@renderer/utils/error'
|
||||
import { Alert, Input, InputNumber, Modal, Select, Slider, Switch, Tooltip } from 'antd'
|
||||
import { find, sortBy } from 'lodash'
|
||||
import { find } from 'lodash'
|
||||
import { ChevronDown } from 'lucide-react'
|
||||
import { nanoid } from 'nanoid'
|
||||
import { useEffect, useMemo, useRef, useState } from 'react'
|
||||
@ -65,41 +65,6 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
const nameInputRef = useRef<any>(null)
|
||||
const scrollContainerRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const embeddingSelectOptions = useMemo(() => {
|
||||
return providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isEmbeddingModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m),
|
||||
providerId: p.id,
|
||||
modelId: m.id
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
}, [providers, t])
|
||||
|
||||
const rerankSelectOptions = useMemo(() => {
|
||||
return providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isRerankModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
}, [providers, t])
|
||||
|
||||
const preprocessOrOcrSelectOptions = useMemo(() => {
|
||||
const preprocessOptions = {
|
||||
label: t('settings.tool.preprocess.provider'),
|
||||
@ -292,9 +257,10 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
<InfoCircleOutlined style={{ marginLeft: 8, color: 'var(--color-text-3)' }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={isEmbeddingModel}
|
||||
style={{ width: '100%' }}
|
||||
options={embeddingSelectOptions}
|
||||
placeholder={t('settings.models.empty')}
|
||||
onChange={(value) => {
|
||||
const model = value
|
||||
@ -313,9 +279,10 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
<InfoCircleOutlined style={{ marginLeft: 8, color: 'var(--color-text-3)' }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={isRerankModel}
|
||||
style={{ width: '100%' }}
|
||||
options={rerankSelectOptions}
|
||||
placeholder={t('settings.models.empty')}
|
||||
onChange={(value) => {
|
||||
const rerankModel = value
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { InfoCircleOutlined, WarningOutlined } from '@ant-design/icons'
|
||||
import { loggerService } from '@logger'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import ModelSelector from '@renderer/components/ModelSelector'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, isMac } from '@renderer/config/constant'
|
||||
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
|
||||
@ -12,7 +13,6 @@ import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { KnowledgeBase, PreprocessProvider } from '@renderer/types'
|
||||
import { Alert, Input, InputNumber, Menu, Modal, Select, Slider, Tooltip } from 'antd'
|
||||
import { sortBy } from 'lodash'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
@ -46,34 +46,6 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
return null
|
||||
}
|
||||
|
||||
const selectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isEmbeddingModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
|
||||
const rerankSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isRerankModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
|
||||
const preprocessOptions = {
|
||||
label: t('settings.tool.preprocess.provider'),
|
||||
title: t('settings.tool.preprocess.provider'),
|
||||
@ -181,9 +153,10 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
<InfoCircleOutlined style={{ marginLeft: 8, color: 'var(--color-text-3)' }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={isEmbeddingModel}
|
||||
style={{ width: '100%' }}
|
||||
options={selectOptions}
|
||||
placeholder={t('settings.models.empty')}
|
||||
defaultValue={getModelUniqId(base.model)}
|
||||
disabled
|
||||
@ -202,10 +175,11 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
<InfoCircleOutlined style={{ marginLeft: 8, color: 'var(--color-text-3)' }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={isRerankModel}
|
||||
style={{ width: '100%' }}
|
||||
defaultValue={getModelUniqId(base.rerankModel) || undefined}
|
||||
options={rerankSelectOptions}
|
||||
placeholder={t('settings.models.empty')}
|
||||
onChange={(value) => {
|
||||
const rerankModel = value
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
import { loggerService } from '@logger'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import ModelSelector from '@renderer/components/ModelSelector'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { useModel } from '@renderer/hooks/useModel'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { selectMemoryConfig, updateMemoryConfig } from '@renderer/store/memory'
|
||||
import { Model } from '@renderer/types'
|
||||
import { getErrorMessage } from '@renderer/utils/error'
|
||||
import { Form, InputNumber, Modal, Select, Switch } from 'antd'
|
||||
import { Form, InputNumber, Modal, Switch } from 'antd'
|
||||
import { t } from 'i18next'
|
||||
import { sortBy } from 'lodash'
|
||||
import { FC, useEffect, useState } from 'react'
|
||||
import { FC, useCallback, useEffect, useState } from 'react'
|
||||
import { useDispatch, useSelector } from 'react-redux'
|
||||
|
||||
interface MemoriesSettingsModalProps {
|
||||
@ -129,33 +130,9 @@ const MemoriesSettingsModal: FC<MemoriesSettingsModalProps> = ({ visible, onSubm
|
||||
}
|
||||
}
|
||||
|
||||
const llmSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => !isEmbeddingModel(model) && !isRerankModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
const llmPredicate = useCallback((m: Model) => !isEmbeddingModel(m) && !isRerankModel(m), [])
|
||||
|
||||
const embeddingSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isEmbeddingModel(model) && !isRerankModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
const embeddingPredicate = useCallback((m: Model) => isEmbeddingModel(m) && !isRerankModel(m), [])
|
||||
|
||||
return (
|
||||
<Modal
|
||||
@ -183,13 +160,21 @@ const MemoriesSettingsModal: FC<MemoriesSettingsModalProps> = ({ visible, onSubm
|
||||
label={t('memory.llm_model')}
|
||||
name="llmModel"
|
||||
rules={[{ required: true, message: t('memory.please_select_llm_model') }]}>
|
||||
<Select placeholder={t('memory.select_llm_model_placeholder')} options={llmSelectOptions} showSearch />
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={llmPredicate}
|
||||
placeholder={t('memory.select_llm_model_placeholder')}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label={t('memory.embedding_model')}
|
||||
name="embedderModel"
|
||||
rules={[{ required: true, message: t('memory.please_select_embedding_model') }]}>
|
||||
<Select placeholder={t('memory.select_embedding_model_placeholder')} options={embeddingSelectOptions} />
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={embeddingPredicate}
|
||||
placeholder={t('memory.select_embedding_model_placeholder')}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label={t('knowledge.dimensions_auto_set')}
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
import { loggerService } from '@logger'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import ModelSelector from '@renderer/components/ModelSelector'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { useModel } from '@renderer/hooks/useModel'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { selectMemoryConfig, updateMemoryConfig } from '@renderer/store/memory'
|
||||
import { Model } from '@renderer/types'
|
||||
import { getErrorMessage } from '@renderer/utils/error'
|
||||
import { Form, InputNumber, Modal, Select, Switch } from 'antd'
|
||||
import { Form, InputNumber, Modal, Switch } from 'antd'
|
||||
import { t } from 'i18next'
|
||||
import { sortBy } from 'lodash'
|
||||
import { FC, useEffect, useState } from 'react'
|
||||
import { FC, useCallback, useEffect, useState } from 'react'
|
||||
import { useDispatch, useSelector } from 'react-redux'
|
||||
|
||||
const logger = loggerService.withContext('MemoriesSettingsModal')
|
||||
@ -125,33 +126,9 @@ const MemoriesSettingsModal: FC<MemoriesSettingsModalProps> = ({ visible, onSubm
|
||||
}
|
||||
}
|
||||
|
||||
const llmSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => !isEmbeddingModel(model) && !isRerankModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
const llmPredicate = useCallback((m: Model) => !isEmbeddingModel(m) && !isRerankModel(m), [])
|
||||
|
||||
const embeddingSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isEmbeddingModel(model) && !isRerankModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
const embeddingPredicate = useCallback((m: Model) => isEmbeddingModel(m) && !isRerankModel(m), [])
|
||||
|
||||
return (
|
||||
<Modal
|
||||
@ -179,13 +156,21 @@ const MemoriesSettingsModal: FC<MemoriesSettingsModalProps> = ({ visible, onSubm
|
||||
label={t('memory.llm_model')}
|
||||
name="llmModel"
|
||||
rules={[{ required: true, message: t('memory.please_select_llm_model') }]}>
|
||||
<Select placeholder={t('memory.select_llm_model_placeholder')} options={llmSelectOptions} showSearch />
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={llmPredicate}
|
||||
placeholder={t('memory.select_llm_model_placeholder')}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label={t('memory.embedding_model')}
|
||||
name="embedderModel"
|
||||
rules={[{ required: true, message: t('memory.please_select_embedding_model') }]}>
|
||||
<Select placeholder={t('memory.select_embedding_model_placeholder')} options={embeddingSelectOptions} />
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={embeddingPredicate}
|
||||
placeholder={t('memory.select_embedding_model_placeholder')}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
label={t('knowledge.dimensions_auto_set')}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { RedoOutlined } from '@ant-design/icons'
|
||||
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import ModelSelector from '@renderer/components/ModelSelector'
|
||||
import PromptPopup from '@renderer/components/Popups/PromptPopup'
|
||||
import { isEmbeddingModel, isRerankModel, isTextToImageModel } from '@renderer/config/models'
|
||||
import { TRANSLATE_PROMPT } from '@renderer/config/prompts'
|
||||
@ -15,9 +16,9 @@ import { setQuickAssistantId } from '@renderer/store/llm'
|
||||
import { setTranslateModelPrompt } from '@renderer/store/settings'
|
||||
import { Model } from '@renderer/types'
|
||||
import { Button, Select, Tooltip } from 'antd'
|
||||
import { find, sortBy } from 'lodash'
|
||||
import { find } from 'lodash'
|
||||
import { CircleHelp, FolderPen, Languages, MessageSquareMore, Rocket, Settings2 } from 'lucide-react'
|
||||
import { FC, useMemo } from 'react'
|
||||
import { FC, useCallback, useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
@ -39,27 +40,10 @@ const ModelSettings: FC = () => {
|
||||
const dispatch = useAppDispatch()
|
||||
const { quickAssistantId } = useAppSelector((state) => state.llm)
|
||||
|
||||
const selectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.flatMap((p) => {
|
||||
const filteredModels = sortBy(p.models, 'name')
|
||||
.filter((m) => !isEmbeddingModel(m) && !isRerankModel(m) && !isTextToImageModel(m))
|
||||
.map((m) => ({
|
||||
label: `${m.name} | ${p.isSystem ? t(`provider.${p.id}`) : p.name}`,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
|
||||
if (filteredModels.length > 0) {
|
||||
return [
|
||||
{
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: filteredModels
|
||||
}
|
||||
]
|
||||
}
|
||||
return []
|
||||
})
|
||||
const modelPredicate = useCallback(
|
||||
(m: Model) => !isEmbeddingModel(m) && !isRerankModel(m) && !isTextToImageModel(m),
|
||||
[]
|
||||
)
|
||||
|
||||
const defaultModelValue = useMemo(
|
||||
() => (hasModel(defaultModel) ? getModelUniqId(defaultModel) : undefined),
|
||||
@ -105,13 +89,13 @@ const ModelSettings: FC = () => {
|
||||
</HStack>
|
||||
</SettingTitle>
|
||||
<HStack alignItems="center">
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={modelPredicate}
|
||||
value={defaultModelValue}
|
||||
defaultValue={defaultModelValue}
|
||||
style={{ width: 360 }}
|
||||
onChange={(value) => setDefaultModel(find(allModels, JSON.parse(value)) as Model)}
|
||||
options={selectOptions}
|
||||
showSearch
|
||||
placeholder={t('settings.models.empty')}
|
||||
/>
|
||||
<Button icon={<Settings2 size={16} />} style={{ marginLeft: 8 }} onClick={DefaultAssistantSettings.show} />
|
||||
@ -126,13 +110,13 @@ const ModelSettings: FC = () => {
|
||||
</HStack>
|
||||
</SettingTitle>
|
||||
<HStack alignItems="center">
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={modelPredicate}
|
||||
value={defaultTopicNamingModel}
|
||||
defaultValue={defaultTopicNamingModel}
|
||||
style={{ width: 360 }}
|
||||
onChange={(value) => setTopicNamingModel(find(allModels, JSON.parse(value)) as Model)}
|
||||
options={selectOptions}
|
||||
showSearch
|
||||
placeholder={t('settings.models.empty')}
|
||||
/>
|
||||
<Button icon={<Settings2 size={16} />} style={{ marginLeft: 8 }} onClick={TopicNamingModalPopup.show} />
|
||||
@ -147,13 +131,13 @@ const ModelSettings: FC = () => {
|
||||
</HStack>
|
||||
</SettingTitle>
|
||||
<HStack alignItems="center">
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={modelPredicate}
|
||||
value={defaultTranslateModel}
|
||||
defaultValue={defaultTranslateModel}
|
||||
style={{ width: 360 }}
|
||||
onChange={(value) => setTranslateModel(find(allModels, JSON.parse(value)) as Model)}
|
||||
options={selectOptions}
|
||||
showSearch
|
||||
placeholder={t('settings.models.empty')}
|
||||
/>
|
||||
<Button icon={<Settings2 size={16} />} style={{ marginLeft: 8 }} onClick={onUpdateTranslateModel} />
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import ModelSelector from '@renderer/components/ModelSelector'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { isRerankModel } from '@renderer/config/models'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { Provider } from '@renderer/types'
|
||||
import { Modal, Select } from 'antd'
|
||||
import { first, orderBy } from 'lodash'
|
||||
import { useState } from 'react'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
import { Modal } from 'antd'
|
||||
import { first } from 'lodash'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
|
||||
interface ShowParams {
|
||||
provider: Provider
|
||||
@ -16,10 +18,19 @@ interface Props extends ShowParams {
|
||||
}
|
||||
|
||||
const PopupContainer: React.FC<Props> = ({ provider, resolve, reject }) => {
|
||||
const models = orderBy(provider.models, 'group').filter((i) => !isEmbeddingModel(i) && !isRerankModel(i))
|
||||
const [open, setOpen] = useState(true)
|
||||
|
||||
// Keep the natural order of models
|
||||
const models = useMemo(() => provider.models.filter((m) => !isRerankModel(m)), [provider])
|
||||
|
||||
const [model, setModel] = useState(first(models))
|
||||
|
||||
const modelPredicate = useCallback((m: Model) => !isRerankModel(m), [])
|
||||
|
||||
const defaultModelValue = useMemo(() => {
|
||||
return model ? getModelUniqId(model) : undefined
|
||||
}, [model])
|
||||
|
||||
const onOk = () => {
|
||||
if (!model) {
|
||||
window.message.error({ content: i18n.t('message.error.enter.model'), key: 'api-check' })
|
||||
@ -50,14 +61,15 @@ const PopupContainer: React.FC<Props> = ({ provider, resolve, reject }) => {
|
||||
transitionName="animation-move-down"
|
||||
width={400}
|
||||
centered>
|
||||
<Select
|
||||
value={model?.id}
|
||||
<ModelSelector
|
||||
providers={[provider]}
|
||||
predicate={modelPredicate}
|
||||
grouped={false}
|
||||
defaultValue={defaultModelValue}
|
||||
placeholder={i18n.t('settings.models.empty')}
|
||||
options={models.map((m) => ({ label: m.name, value: m.id }))}
|
||||
style={{ width: '100%' }}
|
||||
showSearch
|
||||
onChange={(value) => {
|
||||
setModel(provider.models.find((m) => m.id === value)!)
|
||||
setModel(models.find((m) => value === getModelUniqId(m))!)
|
||||
}}
|
||||
/>
|
||||
</Modal>
|
||||
|
||||
@ -7,7 +7,15 @@ import { useAllProviders, useProviders } from '@renderer/hooks/useProvider'
|
||||
import ImageStorage from '@renderer/services/ImageStorage'
|
||||
import { INITIAL_PROVIDERS } from '@renderer/store/llm'
|
||||
import { Provider, ProviderType } from '@renderer/types'
|
||||
import { droppableReorder, generateColorFromChar, getFirstCharacter, uuid } from '@renderer/utils'
|
||||
import {
|
||||
droppableReorder,
|
||||
generateColorFromChar,
|
||||
getFancyProviderName,
|
||||
getFirstCharacter,
|
||||
matchKeywordsInModel,
|
||||
matchKeywordsInProvider,
|
||||
uuid
|
||||
} from '@renderer/utils'
|
||||
import { Avatar, Button, Card, Dropdown, Input, MenuProps, Tag } from 'antd'
|
||||
import { Eye, EyeOff, Search, UserPen } from 'lucide-react'
|
||||
import { FC, useEffect, useState } from 'react'
|
||||
@ -420,19 +428,9 @@ const ProvidersList: FC = () => {
|
||||
}
|
||||
|
||||
const filteredProviders = providers.filter((provider) => {
|
||||
const providerName = provider.isSystem ? t(`provider.${provider.id}`) : provider.name
|
||||
|
||||
const isProviderMatch =
|
||||
provider.id.toLowerCase().includes(searchText.toLowerCase()) ||
|
||||
providerName.toLowerCase().includes(searchText.toLowerCase())
|
||||
|
||||
const isModelMatch = provider.models.some((model) => {
|
||||
return (
|
||||
model.id.toLowerCase().includes(searchText.toLowerCase()) ||
|
||||
model.name.toLowerCase().includes(searchText.toLowerCase())
|
||||
)
|
||||
})
|
||||
|
||||
const keywords = searchText.toLowerCase().split(/\s+/).filter(Boolean)
|
||||
const isProviderMatch = matchKeywordsInProvider(keywords, provider)
|
||||
const isModelMatch = provider.models.some((model) => matchKeywordsInModel(keywords, model))
|
||||
return isProviderMatch || isModelMatch
|
||||
})
|
||||
|
||||
@ -481,7 +479,7 @@ const ProvidersList: FC = () => {
|
||||
onClick={() => setSelectedProvider(provider)}>
|
||||
{getProviderAvatar(provider)}
|
||||
<ProviderItemName className="text-nowrap">
|
||||
{provider.isSystem ? t(`provider.${provider.id}`) : provider.name}
|
||||
{getFancyProviderName(provider)}
|
||||
</ProviderItemName>
|
||||
{provider.enabled && (
|
||||
<Tag color="green" style={{ marginLeft: 'auto', marginRight: 0, borderRadius: 16 }}>
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import ModelSelector from '@renderer/components/ModelSelector'
|
||||
import { DEFAULT_WEBSEARCH_RAG_DOCUMENT_COUNT } from '@renderer/config/constant'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
@ -8,15 +9,15 @@ import { useWebSearchSettings } from '@renderer/hooks/useWebSearchProviders'
|
||||
import { SettingDivider, SettingRow, SettingRowTitle } from '@renderer/pages/settings'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { Model } from '@renderer/types'
|
||||
import { Button, InputNumber, Select, Slider, Tooltip } from 'antd'
|
||||
import { find, sortBy } from 'lodash'
|
||||
import { Button, InputNumber, Slider, Tooltip } from 'antd'
|
||||
import { find } from 'lodash'
|
||||
import { Info, RefreshCw } from 'lucide-react'
|
||||
import { useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
const logger = loggerService.withContext('RagSettings')
|
||||
|
||||
const INPUT_BOX_WIDTH = '200px'
|
||||
const INPUT_BOX_WIDTH = 'min(350px, 60%)'
|
||||
|
||||
const RagSettings = () => {
|
||||
const { t } = useTranslation()
|
||||
@ -25,53 +26,16 @@ const RagSettings = () => {
|
||||
const [loadingDimensions, setLoadingDimensions] = useState(false)
|
||||
|
||||
const embeddingModels = useMemo(() => {
|
||||
return providers
|
||||
.map((p) => p.models)
|
||||
.flat()
|
||||
.filter((model) => isEmbeddingModel(model))
|
||||
return providers.flatMap((p) => p.models).filter((model) => isEmbeddingModel(model))
|
||||
}, [providers])
|
||||
|
||||
const rerankModels = useMemo(() => {
|
||||
return providers
|
||||
.map((p) => p.models)
|
||||
.flat()
|
||||
.filter((model) => isRerankModel(model))
|
||||
return providers.flatMap((p) => p.models).filter((model) => isRerankModel(model))
|
||||
}, [providers])
|
||||
|
||||
const embeddingSelectOptions = useMemo(() => {
|
||||
return providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isEmbeddingModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m),
|
||||
providerId: p.id,
|
||||
modelId: m.id
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
}, [providers, t])
|
||||
|
||||
const rerankSelectOptions = useMemo(() => {
|
||||
return providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isRerankModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
}, [providers, t])
|
||||
const rerankProviders = useMemo(() => {
|
||||
return providers.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
}, [providers])
|
||||
|
||||
const handleEmbeddingModelChange = (modelValue: string) => {
|
||||
const selectedModel = find(embeddingModels, JSON.parse(modelValue)) as Model
|
||||
@ -125,14 +89,14 @@ const RagSettings = () => {
|
||||
<>
|
||||
<SettingRow>
|
||||
<SettingRowTitle>{t('models.embedding_model')}</SettingRowTitle>
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={isEmbeddingModel}
|
||||
value={compressionConfig?.embeddingModel ? getModelUniqId(compressionConfig.embeddingModel) : undefined}
|
||||
style={{ width: INPUT_BOX_WIDTH }}
|
||||
options={embeddingSelectOptions}
|
||||
placeholder={t('settings.models.empty')}
|
||||
onChange={handleEmbeddingModelChange}
|
||||
allowClear={false}
|
||||
showSearch
|
||||
/>
|
||||
</SettingRow>
|
||||
<SettingDivider />
|
||||
@ -166,14 +130,14 @@ const RagSettings = () => {
|
||||
|
||||
<SettingRow>
|
||||
<SettingRowTitle>{t('models.rerank_model')}</SettingRowTitle>
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={rerankProviders}
|
||||
predicate={isRerankModel}
|
||||
value={compressionConfig?.rerankModel ? getModelUniqId(compressionConfig.rerankModel) : undefined}
|
||||
style={{ width: INPUT_BOX_WIDTH }}
|
||||
options={rerankSelectOptions}
|
||||
placeholder={t('settings.models.empty')}
|
||||
onChange={handleRerankModelChange}
|
||||
allowClear
|
||||
showSearch
|
||||
/>
|
||||
</SettingRow>
|
||||
<SettingDivider />
|
||||
|
||||
@ -6,6 +6,9 @@ import { useTranslation } from 'react-i18next'
|
||||
import CutoffSettings from './CutoffSettings'
|
||||
import RagSettings from './RagSettings'
|
||||
|
||||
const INPUT_BOX_WIDTH_CUTOFF = '200px'
|
||||
const INPUT_BOX_WIDTH_RAG = 'min(350px, 60%)'
|
||||
|
||||
const CompressionSettings = () => {
|
||||
const { t } = useTranslation()
|
||||
const { compressionConfig, updateCompressionConfig } = useWebSearchSettings()
|
||||
@ -29,7 +32,7 @@ const CompressionSettings = () => {
|
||||
<SettingRowTitle>{t('settings.tool.websearch.compression.method')}</SettingRowTitle>
|
||||
<Select
|
||||
value={compressionConfig?.method || 'none'}
|
||||
style={{ width: '200px' }}
|
||||
style={{ width: compressionConfig?.method === 'rag' ? INPUT_BOX_WIDTH_RAG : INPUT_BOX_WIDTH_CUTOFF }}
|
||||
onChange={handleCompressionMethodChange}
|
||||
options={compressionMethodOptions}
|
||||
/>
|
||||
|
||||
@ -3,6 +3,7 @@ import { loggerService } from '@logger'
|
||||
import { Navbar, NavbarCenter } from '@renderer/components/app/Navbar'
|
||||
import CopyIcon from '@renderer/components/Icons/CopyIcon'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import ModelSelector from '@renderer/components/ModelSelector'
|
||||
import { isEmbeddingModel, isRerankModel, isTextToImageModel } from '@renderer/config/models'
|
||||
import { TRANSLATE_PROMPT } from '@renderer/config/prompts'
|
||||
import { LanguagesEnum, translateLanguageOptions } from '@renderer/config/translate'
|
||||
@ -29,9 +30,9 @@ import { Button, Dropdown, Empty, Flex, Modal, Popconfirm, Select, Space, Switch
|
||||
import TextArea, { TextAreaRef } from 'antd/es/input/TextArea'
|
||||
import dayjs from 'dayjs'
|
||||
import { useLiveQuery } from 'dexie-react-hooks'
|
||||
import { find, isEmpty, sortBy } from 'lodash'
|
||||
import { find, isEmpty } from 'lodash'
|
||||
import { ChevronDown, HelpCircle, Settings2, TriangleAlert } from 'lucide-react'
|
||||
import { FC, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { FC, useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
@ -54,8 +55,6 @@ const TranslateSettings: FC<{
|
||||
setBidirectionalPair: (value: [Language, Language]) => void
|
||||
translateModel: Model | undefined
|
||||
onModelChange: (model: Model) => void
|
||||
allModels: Model[]
|
||||
selectOptions: any[]
|
||||
}> = ({
|
||||
visible,
|
||||
onClose,
|
||||
@ -68,9 +67,7 @@ const TranslateSettings: FC<{
|
||||
bidirectionalPair,
|
||||
setBidirectionalPair,
|
||||
translateModel,
|
||||
onModelChange,
|
||||
allModels,
|
||||
selectOptions
|
||||
onModelChange
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { translateModelPrompt } = useSettings()
|
||||
@ -79,6 +76,14 @@ const TranslateSettings: FC<{
|
||||
const [showPrompt, setShowPrompt] = useState(false)
|
||||
const [localPrompt, setLocalPrompt] = useState(translateModelPrompt)
|
||||
|
||||
const { providers } = useProviders()
|
||||
const allModels = useMemo(() => providers.map((p) => p.models).flat(), [providers])
|
||||
|
||||
const modelPredicate = useCallback(
|
||||
(m: Model) => !isEmbeddingModel(m) && !isRerankModel(m) && !isTextToImageModel(m),
|
||||
[]
|
||||
)
|
||||
|
||||
const defaultTranslateModel = useMemo(
|
||||
() => (hasModel(translateModel) ? getModelUniqId(translateModel) : undefined),
|
||||
[translateModel]
|
||||
@ -136,7 +141,9 @@ const TranslateSettings: FC<{
|
||||
</Tooltip>
|
||||
</div>
|
||||
<HStack alignItems="center" gap={5}>
|
||||
<Select
|
||||
<ModelSelector
|
||||
providers={providers}
|
||||
predicate={modelPredicate}
|
||||
style={{ width: '100%' }}
|
||||
placeholder={t('translate.settings.model_placeholder')}
|
||||
value={defaultTranslateModel}
|
||||
@ -146,8 +153,6 @@ const TranslateSettings: FC<{
|
||||
onModelChange(selectedModel)
|
||||
}
|
||||
}}
|
||||
options={selectOptions}
|
||||
showSearch
|
||||
/>
|
||||
</HStack>
|
||||
{!translateModel && (
|
||||
@ -302,9 +307,6 @@ const TranslatePage: FC = () => {
|
||||
const outputTextRef = useRef<HTMLDivElement>(null)
|
||||
const isProgrammaticScroll = useRef(false)
|
||||
|
||||
const { providers } = useProviders()
|
||||
const allModels = useMemo(() => providers.map((p) => p.models).flat(), [providers])
|
||||
|
||||
const _translateHistory = useLiveQuery(() => db.translate_history.orderBy('createdAt').reverse().toArray(), [])
|
||||
|
||||
const translateHistory = useMemo(() => {
|
||||
@ -319,31 +321,6 @@ const TranslatePage: FC = () => {
|
||||
_result = result
|
||||
_targetLanguage = targetLanguage
|
||||
|
||||
const selectOptions = useMemo(
|
||||
() =>
|
||||
providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.flatMap((p) => {
|
||||
const filteredModels = sortBy(p.models, 'name')
|
||||
.filter((m) => !isEmbeddingModel(m) && !isRerankModel(m) && !isTextToImageModel(m))
|
||||
.map((m) => ({
|
||||
label: `${m.name} | ${p.isSystem ? t(`provider.${p.id}`) : p.name}`,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
if (filteredModels.length > 0) {
|
||||
return [
|
||||
{
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: filteredModels
|
||||
}
|
||||
]
|
||||
}
|
||||
return []
|
||||
}),
|
||||
[providers, t]
|
||||
)
|
||||
|
||||
const handleModelChange = (model: Model) => {
|
||||
setTranslateModel(model)
|
||||
db.settings.put({ id: 'translate:model', value: model.id })
|
||||
@ -760,8 +737,6 @@ const TranslatePage: FC = () => {
|
||||
setBidirectionalPair={setBidirectionalPair}
|
||||
translateModel={translateModel}
|
||||
onModelChange={handleModelChange}
|
||||
allModels={allModels}
|
||||
selectOptions={selectOptions}
|
||||
/>
|
||||
</Container>
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import store from '@renderer/store'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
import { t } from 'i18next'
|
||||
import { getFancyProviderName } from '@renderer/utils'
|
||||
import { pick } from 'lodash'
|
||||
|
||||
import { checkApi } from './ApiService'
|
||||
@ -24,7 +24,7 @@ export function getModelName(model?: Model) {
|
||||
const modelName = model?.name || model?.id || ''
|
||||
|
||||
if (provider) {
|
||||
const providerName = provider?.isSystem ? t(`provider.${provider.id}`) : provider?.name
|
||||
const providerName = getFancyProviderName(provider)
|
||||
return `${modelName} | ${providerName}`
|
||||
}
|
||||
|
||||
|
||||
140
src/renderer/src/utils/__tests__/match.test.ts
Normal file
140
src/renderer/src/utils/__tests__/match.test.ts
Normal file
@ -0,0 +1,140 @@
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { includeKeywords, matchKeywordsInModel, matchKeywordsInProvider, matchKeywordsInString } from '../match'
|
||||
|
||||
// mock i18n for getFancyProviderName
|
||||
vi.mock('@renderer/i18n', () => ({
|
||||
default: {
|
||||
t: (key: string) => `i18n:${key}`
|
||||
}
|
||||
}))
|
||||
|
||||
describe('match', () => {
|
||||
const provider: Provider = {
|
||||
id: '12345',
|
||||
type: 'openai',
|
||||
name: 'OpenAI',
|
||||
apiKey: '',
|
||||
apiHost: '',
|
||||
models: [],
|
||||
isSystem: false
|
||||
}
|
||||
const sysProvider: Provider = {
|
||||
...provider,
|
||||
id: 'sys',
|
||||
name: 'SystemProvider',
|
||||
isSystem: true
|
||||
}
|
||||
|
||||
describe('includeKeywords', () => {
|
||||
it('should return true if keywords is empty or blank', () => {
|
||||
expect(includeKeywords('hello world', '')).toBe(true)
|
||||
expect(includeKeywords('hello world', ' ')).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false if target is empty', () => {
|
||||
expect(includeKeywords('', 'hello')).toBe(false)
|
||||
expect(includeKeywords(undefined as any, 'hello')).toBe(false)
|
||||
})
|
||||
|
||||
it('should match all keywords (case-insensitive, whitespace split)', () => {
|
||||
expect(includeKeywords('Hello World', 'hello')).toBe(true)
|
||||
expect(includeKeywords('Hello World', 'world')).toBe(true)
|
||||
expect(includeKeywords('Hello World', 'hello world')).toBe(true)
|
||||
expect(includeKeywords('Hello World', 'world hello')).toBe(true)
|
||||
expect(includeKeywords('Hello World', 'HELLO')).toBe(true)
|
||||
expect(includeKeywords('Hello World', 'hello world')).toBe(true)
|
||||
expect(includeKeywords('Hello\nWorld', 'hello world')).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false if any keyword is not included', () => {
|
||||
expect(includeKeywords('Hello World', 'hello foo')).toBe(false)
|
||||
expect(includeKeywords('Hello World', 'foo')).toBe(false)
|
||||
})
|
||||
|
||||
it('should ignore blank keywords', () => {
|
||||
expect(includeKeywords('Hello World', ' hello ')).toBe(true)
|
||||
expect(includeKeywords('Hello World', 'hello ')).toBe(true)
|
||||
expect(includeKeywords('Hello World', ' ')).toBe(true)
|
||||
})
|
||||
|
||||
it('should handle keyword array', () => {
|
||||
expect(includeKeywords('Hello World', ['hello', 'world'])).toBe(true)
|
||||
expect(includeKeywords('Hello World', ['Hello', 'World'])).toBe(true)
|
||||
expect(includeKeywords('Hello World', ['hello', 'foo'])).toBe(false)
|
||||
expect(includeKeywords('Hello World', ['hello', ''])).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('matchKeywordsInString', () => {
|
||||
it('should delegate to includeKeywords with string', () => {
|
||||
expect(matchKeywordsInString('foo', 'foo bar')).toBe(true)
|
||||
expect(matchKeywordsInString('bar', 'foo bar')).toBe(true)
|
||||
expect(matchKeywordsInString('baz', 'foo bar')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('matchKeywordsInProvider', () => {
|
||||
it('should match non-system provider by name only, not id', () => {
|
||||
expect(matchKeywordsInProvider('OpenAI', provider)).toBe(true)
|
||||
expect(matchKeywordsInProvider('12345', provider)).toBe(false) // Should NOT match by id
|
||||
expect(matchKeywordsInProvider('foo', provider)).toBe(false)
|
||||
})
|
||||
|
||||
it('should match i18n name for system provider', () => {
|
||||
expect(matchKeywordsInProvider('i18n:provider.sys', sysProvider)).toBe(true)
|
||||
expect(matchKeywordsInProvider('SystemProvider', sysProvider)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('matchKeywordsInModel', () => {
|
||||
const model: Model = {
|
||||
id: 'gpt-4.1',
|
||||
provider: 'openai',
|
||||
name: 'GPT-4.1',
|
||||
group: 'gpt'
|
||||
}
|
||||
|
||||
it('should match model name only if provider not given', () => {
|
||||
expect(matchKeywordsInModel('gpt-4.1', model)).toBe(true)
|
||||
expect(matchKeywordsInModel('openai', model)).toBe(false)
|
||||
})
|
||||
|
||||
it('should match model name and provider name if provider given', () => {
|
||||
expect(matchKeywordsInModel('gpt-4.1 openai', model, provider)).toBe(true)
|
||||
expect(matchKeywordsInModel('gpt-4.1', model, provider)).toBe(true)
|
||||
expect(matchKeywordsInModel('foo', model, provider)).toBe(false)
|
||||
})
|
||||
|
||||
it('should match model name and i18n provider name for system provider', () => {
|
||||
expect(matchKeywordsInModel('gpt-4.1 i18n:provider.sys', model, sysProvider)).toBe(true)
|
||||
expect(matchKeywordsInModel('i18n:provider.sys', model, sysProvider)).toBe(true)
|
||||
expect(matchKeywordsInModel('SystemProvider', model, sysProvider)).toBe(false)
|
||||
})
|
||||
|
||||
it('should match model by id when name is customized', () => {
|
||||
const customNameModel: Model = {
|
||||
id: 'claude-3-opus-20240229',
|
||||
provider: 'anthropic',
|
||||
name: 'Opus (Custom Name)',
|
||||
group: 'claude'
|
||||
}
|
||||
|
||||
// search by parts of ID
|
||||
expect(matchKeywordsInModel('claude', customNameModel)).toBe(true)
|
||||
expect(matchKeywordsInModel('opus', customNameModel)).toBe(true)
|
||||
expect(matchKeywordsInModel('20240229', customNameModel)).toBe(true)
|
||||
|
||||
// search by parts of custom name
|
||||
expect(matchKeywordsInModel('Custom', customNameModel)).toBe(true)
|
||||
expect(matchKeywordsInModel('Opus Name', customNameModel)).toBe(true)
|
||||
|
||||
// search by both
|
||||
expect(matchKeywordsInModel('claude custom', customNameModel)).toBe(true)
|
||||
|
||||
// should not match
|
||||
expect(matchKeywordsInModel('sonnet', customNameModel)).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -231,6 +231,7 @@ export * from './api'
|
||||
export * from './file'
|
||||
export * from './image'
|
||||
export * from './json'
|
||||
export * from './match'
|
||||
export * from './naming'
|
||||
export * from './sort'
|
||||
export * from './style'
|
||||
|
||||
80
src/renderer/src/utils/match.ts
Normal file
80
src/renderer/src/utils/match.ts
Normal file
@ -0,0 +1,80 @@
|
||||
import i18n from '@renderer/i18n'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
|
||||
/**
|
||||
* 判断一个字符串是否包含由另一个字符串表示的 keywords
|
||||
* 将 keywords 按空白字符分割成多个关键词,检查目标字符串是否包含所有关键词
|
||||
* - 大小写不敏感
|
||||
* - 支持的分隔符:空格、制表符、换行符等各种空白字符
|
||||
*
|
||||
* @param target 被搜索的字符串
|
||||
* @param keywords 关键词字符串(空格分隔)或关键词数组
|
||||
* @returns 包含所有关键词或者没有有效关键词则返回 true
|
||||
*/
|
||||
export function includeKeywords(target: string, keywords: string | string[]): boolean {
|
||||
const keywordArray = Array.isArray(keywords) ? keywords : (keywords || '').split(/\s+/)
|
||||
const nonEmptyKeywords = keywordArray.filter(Boolean)
|
||||
|
||||
// 如果没有有效关键词,则视为匹配
|
||||
if (nonEmptyKeywords.length === 0) return true
|
||||
|
||||
// 如果没有搜索目标,则视为不匹配
|
||||
if (!target || typeof target !== 'string') return false
|
||||
const targetLower = target.toLowerCase()
|
||||
|
||||
return nonEmptyKeywords.every((keyword) => targetLower.includes(keyword.toLowerCase()))
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查字符串是否包含所有关键词
|
||||
* @see includeKeywords
|
||||
* @param keywords 关键词字符串(空格分隔)或关键词数组
|
||||
* @param value 被搜索的目标字符串
|
||||
* @returns 包含所有关键词则返回 true
|
||||
*/
|
||||
export function matchKeywordsInString(keywords: string | string[], value: string): boolean {
|
||||
return includeKeywords(value, keywords)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 Provider 是否匹配所有关键词
|
||||
* @param keywords 关键词字符串(空格分隔)或关键词数组
|
||||
* @param provider 被搜索的 Provider 对象
|
||||
* @returns 匹配所有关键词则返回 true
|
||||
*/
|
||||
export function matchKeywordsInProvider(keywords: string | string[], provider: Provider): boolean {
|
||||
return includeKeywords(getProviderSearchString(provider), keywords)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 Model 是否匹配所有关键词
|
||||
* @param keywords 关键词字符串(空格分隔)或关键词数组
|
||||
* @param model 被搜索的 Model 对象
|
||||
* @param provider 可选的 Provider 对象,用于生成完整模型名称
|
||||
* @returns 匹配所有关键词则返回 true
|
||||
*/
|
||||
export function matchKeywordsInModel(keywords: string | string[], model: Model, provider?: Provider): boolean {
|
||||
const fullName = `${model.name} ${model.id} ${provider ? getProviderSearchString(provider) : ''}`
|
||||
return includeKeywords(fullName, keywords)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Provider 的搜索字符串,它和 getFancyProviderName 不同
|
||||
* @param provider Provider 对象
|
||||
* @returns 搜索字符串
|
||||
*/
|
||||
function getProviderSearchString(provider: Provider) {
|
||||
return provider.isSystem ? `${i18n.t(`provider.${provider.id}`)} ${provider.id}` : provider.name
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据关键词过滤模型
|
||||
* @param keywords 关键词字符串
|
||||
* @param models 模型数组
|
||||
* @param provider 可选的 Provider 对象,用于生成完整模型名称
|
||||
* @returns 过滤后的模型数组
|
||||
*/
|
||||
export function filterModelsByKeywords(keywords: string, models: Model[], provider?: Provider): Model[] {
|
||||
const keywordsArray = keywords.toLowerCase().split(/\s+/).filter(Boolean)
|
||||
return models.filter((model) => matchKeywordsInModel(keywordsArray, model, provider))
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user