mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-25 11:20:07 +08:00
refactor: simplify select model popup (#9630)
* refactor: simplify SelectModelPopup * refactor: extract useModelTagFilter * refactor: improve filter naming * refactor: extract TagFilterSection, improve tag style * test: add tests for filters * refactor: focus on tag selection * refactor: suppress react key warning * refactor: add log to TagFilterSection * refactor: add initialTagSelection * refactor: use objectEntries * test: add tests for TagFilterSection * refactor: improve group action icon style * refactor: ease CustomTag opacity change * fix: GroupItem alignment
This commit is contained in:
parent
a1f5c12a96
commit
9e567ace4e
@ -15,7 +15,7 @@ type Props = {
|
||||
|
||||
const ModelSelectButton = ({ model, onSelectModel, modelFilter, noTooltip, tooltipProps }: Props) => {
|
||||
const onClick = useCallback(async () => {
|
||||
const selectedModel = await SelectModelPopup.show({ model, modelFilter })
|
||||
const selectedModel = await SelectModelPopup.show({ model, filter: modelFilter })
|
||||
if (selectedModel) {
|
||||
onSelectModel?.(selectedModel)
|
||||
}
|
||||
|
||||
@ -0,0 +1,83 @@
|
||||
import { loggerService } from '@logger'
|
||||
import {
|
||||
EmbeddingTag,
|
||||
FreeTag,
|
||||
ReasoningTag,
|
||||
RerankerTag,
|
||||
ToolsCallingTag,
|
||||
VisionTag,
|
||||
WebSearchTag
|
||||
} from '@renderer/components/Tags/Model'
|
||||
import { ModelTag } from '@renderer/types'
|
||||
import { Flex } from 'antd'
|
||||
import React, { startTransition, useCallback, useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
const logger = loggerService.withContext('TagFilterSection')
|
||||
|
||||
interface TagFilterSectionProps {
|
||||
availableTags: ModelTag[]
|
||||
tagSelection: Record<ModelTag, boolean>
|
||||
onToggleTag: (tag: ModelTag) => void
|
||||
}
|
||||
|
||||
const TagFilterSection: React.FC<TagFilterSectionProps> = ({ availableTags, tagSelection, onToggleTag }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const handleTagClick = useCallback(
|
||||
(tag: ModelTag) => {
|
||||
startTransition(() => onToggleTag(tag))
|
||||
},
|
||||
[onToggleTag]
|
||||
)
|
||||
|
||||
// 标签组件
|
||||
const tagComponents = useMemo(
|
||||
() => ({
|
||||
vision: VisionTag,
|
||||
embedding: EmbeddingTag,
|
||||
reasoning: ReasoningTag,
|
||||
function_calling: ToolsCallingTag,
|
||||
web_search: WebSearchTag,
|
||||
rerank: RerankerTag,
|
||||
free: FreeTag
|
||||
}),
|
||||
[]
|
||||
)
|
||||
|
||||
return (
|
||||
<FilterContainer>
|
||||
<Flex wrap="wrap" gap={4}>
|
||||
<FilterText>{t('models.filter.by_tag')}</FilterText>
|
||||
{availableTags.map((tag) => {
|
||||
const TagElement = tagComponents[tag]
|
||||
if (!TagElement) {
|
||||
logger.error(`Tag element not found for tag: ${tag}`)
|
||||
return null
|
||||
}
|
||||
return (
|
||||
<TagElement
|
||||
key={`tag-${tag}`}
|
||||
onClick={() => handleTagClick(tag)}
|
||||
inactive={!tagSelection[tag]}
|
||||
showLabel
|
||||
/>
|
||||
)
|
||||
})}
|
||||
</Flex>
|
||||
</FilterContainer>
|
||||
)
|
||||
}
|
||||
|
||||
const FilterContainer = styled.div`
|
||||
padding: 8px;
|
||||
padding-left: 18px;
|
||||
`
|
||||
|
||||
const FilterText = styled.span`
|
||||
color: var(--color-text-3);
|
||||
font-size: 12px;
|
||||
`
|
||||
|
||||
export default TagFilterSection
|
||||
@ -0,0 +1,110 @@
|
||||
import type { ModelTag } from '@renderer/types'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import TagFilterSection from '../TagFilterSection'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
t: vi.fn((key: string) => key),
|
||||
createTagComponent: (name: string) => {
|
||||
// Create a simple button component exposing props for assertions
|
||||
return ({ onClick, inactive, showLabel }: { onClick?: () => void; inactive?: boolean; showLabel?: boolean }) => {
|
||||
const React = require('react')
|
||||
return React.createElement(
|
||||
'button',
|
||||
{
|
||||
type: 'button',
|
||||
'aria-label': `tag-${name}`,
|
||||
'data-inactive': String(Boolean(inactive)),
|
||||
onClick
|
||||
},
|
||||
showLabel ? name : ''
|
||||
)
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({ t: mocks.t })
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/components/Tags/Model', () => ({
|
||||
VisionTag: mocks.createTagComponent('vision'),
|
||||
EmbeddingTag: mocks.createTagComponent('embedding'),
|
||||
ReasoningTag: mocks.createTagComponent('reasoning'),
|
||||
ToolsCallingTag: mocks.createTagComponent('function_calling'),
|
||||
WebSearchTag: mocks.createTagComponent('web_search'),
|
||||
RerankerTag: mocks.createTagComponent('rerank'),
|
||||
FreeTag: mocks.createTagComponent('free')
|
||||
}))
|
||||
|
||||
vi.mock('antd', () => ({
|
||||
Flex: ({ children }: { children: React.ReactNode }) => children
|
||||
}))
|
||||
|
||||
function createSelection(overrides: Partial<Record<ModelTag, boolean>> = {}): Record<ModelTag, boolean> {
|
||||
const base: Record<ModelTag, boolean> = {
|
||||
vision: true,
|
||||
embedding: true,
|
||||
reasoning: true,
|
||||
function_calling: true,
|
||||
web_search: true,
|
||||
rerank: true,
|
||||
free: true
|
||||
}
|
||||
return { ...base, ...overrides }
|
||||
}
|
||||
|
||||
const allTags: ModelTag[] = ['vision', 'embedding', 'reasoning', 'function_calling', 'web_search', 'rerank', 'free']
|
||||
|
||||
describe('TagFilterSection', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('rendering', () => {
|
||||
it('should match snapshot', () => {
|
||||
const { container } = render(
|
||||
<TagFilterSection availableTags={allTags} tagSelection={createSelection()} onToggleTag={vi.fn()} />
|
||||
)
|
||||
expect(container).toMatchSnapshot()
|
||||
})
|
||||
|
||||
it('should reflect inactive state based on tagSelection', () => {
|
||||
render(
|
||||
<TagFilterSection
|
||||
availableTags={['vision']}
|
||||
tagSelection={createSelection({ vision: false })}
|
||||
onToggleTag={vi.fn()}
|
||||
/>
|
||||
)
|
||||
const visionBtn = screen.getByRole('button', { name: 'tag-vision' })
|
||||
expect(visionBtn).toHaveAttribute('data-inactive', 'true')
|
||||
})
|
||||
|
||||
it('should skip unknown tags', () => {
|
||||
render(
|
||||
<TagFilterSection
|
||||
availableTags={['unknown' as unknown as ModelTag, 'vision']}
|
||||
tagSelection={createSelection()}
|
||||
onToggleTag={vi.fn()}
|
||||
/>
|
||||
)
|
||||
expect(screen.queryByRole('button', { name: 'tag-unknown' })).not.toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'tag-vision' })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('functionality', () => {
|
||||
it('should call onToggleTag when a tag is clicked', () => {
|
||||
const handleToggle = vi.fn()
|
||||
render(<TagFilterSection availableTags={allTags} tagSelection={createSelection()} onToggleTag={handleToggle} />)
|
||||
|
||||
const visionBtn = screen.getByRole('button', { name: 'tag-vision' })
|
||||
fireEvent.click(visionBtn)
|
||||
|
||||
expect(handleToggle).toHaveBeenCalledTimes(1)
|
||||
expect(handleToggle).toHaveBeenCalledWith('vision')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,74 @@
|
||||
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
|
||||
|
||||
exports[`TagFilterSection > rendering > should match snapshot 1`] = `
|
||||
.c0 {
|
||||
padding: 8px;
|
||||
padding-left: 18px;
|
||||
}
|
||||
|
||||
.c1 {
|
||||
color: var(--color-text-3);
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
<div>
|
||||
<div
|
||||
class="c0"
|
||||
>
|
||||
<span
|
||||
class="c1"
|
||||
>
|
||||
models.filter.by_tag
|
||||
</span>
|
||||
<button
|
||||
aria-label="tag-vision"
|
||||
data-inactive="false"
|
||||
type="button"
|
||||
>
|
||||
vision
|
||||
</button>
|
||||
<button
|
||||
aria-label="tag-embedding"
|
||||
data-inactive="false"
|
||||
type="button"
|
||||
>
|
||||
embedding
|
||||
</button>
|
||||
<button
|
||||
aria-label="tag-reasoning"
|
||||
data-inactive="false"
|
||||
type="button"
|
||||
>
|
||||
reasoning
|
||||
</button>
|
||||
<button
|
||||
aria-label="tag-function_calling"
|
||||
data-inactive="false"
|
||||
type="button"
|
||||
>
|
||||
function_calling
|
||||
</button>
|
||||
<button
|
||||
aria-label="tag-web_search"
|
||||
data-inactive="false"
|
||||
type="button"
|
||||
>
|
||||
web_search
|
||||
</button>
|
||||
<button
|
||||
aria-label="tag-rerank"
|
||||
data-inactive="false"
|
||||
type="button"
|
||||
>
|
||||
rerank
|
||||
</button>
|
||||
<button
|
||||
aria-label="tag-free"
|
||||
data-inactive="false"
|
||||
type="button"
|
||||
>
|
||||
free
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
@ -0,0 +1,122 @@
|
||||
import type { Model } from '@renderer/types'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { useModelTagFilter } from '../filters'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
isVisionModel: vi.fn(),
|
||||
isEmbeddingModel: vi.fn(),
|
||||
isReasoningModel: vi.fn(),
|
||||
isFunctionCallingModel: vi.fn(),
|
||||
isWebSearchModel: vi.fn(),
|
||||
isRerankModel: vi.fn(),
|
||||
isFreeModel: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
isEmbeddingModel: mocks.isEmbeddingModel,
|
||||
isFunctionCallingModel: mocks.isFunctionCallingModel,
|
||||
isReasoningModel: mocks.isReasoningModel,
|
||||
isRerankModel: mocks.isRerankModel,
|
||||
isVisionModel: mocks.isVisionModel,
|
||||
isWebSearchModel: mocks.isWebSearchModel
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/model', () => ({
|
||||
isFreeModel: mocks.isFreeModel
|
||||
}))
|
||||
|
||||
function createModel(overrides: Partial<Model> = {}): Model {
|
||||
return {
|
||||
id: 'm1',
|
||||
provider: 'openai',
|
||||
name: 'Model-1',
|
||||
group: 'default',
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
describe('useModelTagFilter', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should have all tags unselected initially', () => {
|
||||
const { result } = renderHook(() => useModelTagFilter())
|
||||
|
||||
expect(result.current.tagSelection).toEqual({
|
||||
vision: false,
|
||||
embedding: false,
|
||||
reasoning: false,
|
||||
function_calling: false,
|
||||
web_search: false,
|
||||
rerank: false,
|
||||
free: false
|
||||
})
|
||||
expect(result.current.selectedTags).toEqual([])
|
||||
})
|
||||
|
||||
it('should toggle a tag state', () => {
|
||||
const { result } = renderHook(() => useModelTagFilter())
|
||||
|
||||
act(() => result.current.toggleTag('vision'))
|
||||
expect(result.current.tagSelection.vision).toBe(true)
|
||||
expect(result.current.selectedTags).toEqual(['vision'])
|
||||
|
||||
act(() => result.current.toggleTag('vision'))
|
||||
expect(result.current.tagSelection.vision).toBe(false)
|
||||
expect(result.current.selectedTags).toEqual([])
|
||||
})
|
||||
|
||||
it('should reset all tags to false', () => {
|
||||
const { result } = renderHook(() => useModelTagFilter())
|
||||
|
||||
act(() => result.current.toggleTag('vision'))
|
||||
act(() => result.current.toggleTag('embedding'))
|
||||
expect(result.current.selectedTags.sort()).toEqual(['embedding', 'vision'])
|
||||
|
||||
act(() => result.current.resetTags())
|
||||
expect(result.current.selectedTags).toEqual([])
|
||||
expect(Object.values(result.current.tagSelection).every((v) => v === false)).toBe(true)
|
||||
})
|
||||
|
||||
it('tagFilter returns true when no tags selected', () => {
|
||||
const { result } = renderHook(() => useModelTagFilter())
|
||||
const model = createModel()
|
||||
const passed = result.current.tagFilter(model)
|
||||
expect(passed).toBe(true)
|
||||
expect(mocks.isVisionModel).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('tagFilter uses single selected tag predicate', () => {
|
||||
const { result } = renderHook(() => useModelTagFilter())
|
||||
const model = createModel()
|
||||
|
||||
mocks.isVisionModel.mockReturnValueOnce(true)
|
||||
act(() => result.current.toggleTag('vision'))
|
||||
|
||||
const ok = result.current.tagFilter(model)
|
||||
expect(ok).toBe(true)
|
||||
expect(mocks.isVisionModel).toHaveBeenCalledTimes(1)
|
||||
expect(mocks.isVisionModel).toHaveBeenCalledWith(model)
|
||||
})
|
||||
|
||||
it('tagFilter requires all selected tags to match (AND logic)', () => {
|
||||
const { result } = renderHook(() => useModelTagFilter())
|
||||
const model = createModel()
|
||||
|
||||
act(() => result.current.toggleTag('vision'))
|
||||
act(() => result.current.toggleTag('embedding'))
|
||||
|
||||
// 第一次:vision=true, embedding=false => 应为 false
|
||||
mocks.isVisionModel.mockReturnValueOnce(true)
|
||||
mocks.isEmbeddingModel.mockReturnValueOnce(false)
|
||||
expect(result.current.tagFilter(model)).toBe(false)
|
||||
|
||||
// 第二次:vision=true, embedding=true => 应为 true
|
||||
mocks.isVisionModel.mockReturnValueOnce(true)
|
||||
mocks.isEmbeddingModel.mockReturnValueOnce(true)
|
||||
expect(result.current.tagFilter(model)).toBe(true)
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,79 @@
|
||||
import {
|
||||
isEmbeddingModel,
|
||||
isFunctionCallingModel,
|
||||
isReasoningModel,
|
||||
isRerankModel,
|
||||
isVisionModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { Model, ModelTag, objectEntries } from '@renderer/types'
|
||||
import { isFreeModel } from '@renderer/utils/model'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
|
||||
type ModelPredict = (m: Model) => boolean
|
||||
|
||||
const initialTagSelection: Record<ModelTag, boolean> = {
|
||||
vision: false,
|
||||
embedding: false,
|
||||
reasoning: false,
|
||||
function_calling: false,
|
||||
web_search: false,
|
||||
rerank: false,
|
||||
free: false
|
||||
}
|
||||
|
||||
/**
|
||||
* 标签筛选 hook,仅关注标签过滤逻辑
|
||||
*/
|
||||
export function useModelTagFilter() {
|
||||
const filterConfig: Record<ModelTag, ModelPredict> = useMemo(
|
||||
() => ({
|
||||
vision: isVisionModel,
|
||||
embedding: isEmbeddingModel,
|
||||
reasoning: isReasoningModel,
|
||||
function_calling: isFunctionCallingModel,
|
||||
web_search: isWebSearchModel,
|
||||
rerank: isRerankModel,
|
||||
free: isFreeModel
|
||||
}),
|
||||
[]
|
||||
)
|
||||
|
||||
const [tagSelection, setTagSelection] = useState<Record<ModelTag, boolean>>(initialTagSelection)
|
||||
|
||||
// 已选中的标签
|
||||
const selectedTags = useMemo(
|
||||
() =>
|
||||
objectEntries(tagSelection)
|
||||
.filter(([, state]) => state)
|
||||
.map(([tag]) => tag),
|
||||
[tagSelection]
|
||||
)
|
||||
|
||||
// 切换标签
|
||||
const toggleTag = useCallback((tag: ModelTag) => {
|
||||
setTagSelection((prev) => ({ ...prev, [tag]: !prev[tag] }))
|
||||
}, [])
|
||||
|
||||
// 重置标签
|
||||
const resetTags = useCallback(() => {
|
||||
setTagSelection(initialTagSelection)
|
||||
}, [])
|
||||
|
||||
// 根据标签过滤模型
|
||||
const tagFilter = useCallback(
|
||||
(model: Model) => {
|
||||
if (selectedTags.length === 0) return true
|
||||
return selectedTags.map((tag) => filterConfig[tag]).every((predict) => predict(model))
|
||||
},
|
||||
[filterConfig, selectedTags]
|
||||
)
|
||||
|
||||
return {
|
||||
tagSelection,
|
||||
selectedTags,
|
||||
tagFilter,
|
||||
toggleTag,
|
||||
resetTags
|
||||
}
|
||||
}
|
||||
@ -1,37 +1,19 @@
|
||||
import { PushpinOutlined } from '@ant-design/icons'
|
||||
import { FreeTrialModelTag } from '@renderer/components/FreeTrialModelTag'
|
||||
import ModelTagsWithLabel from '@renderer/components/ModelTagsWithLabel'
|
||||
import {
|
||||
EmbeddingTag,
|
||||
FreeTag,
|
||||
ReasoningTag,
|
||||
RerankerTag,
|
||||
ToolsCallingTag,
|
||||
VisionTag,
|
||||
WebSearchTag
|
||||
} from '@renderer/components/Tags/Model'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DynamicVirtualList, type DynamicVirtualListRef } from '@renderer/components/VirtualList'
|
||||
import {
|
||||
getModelLogo,
|
||||
isEmbeddingModel,
|
||||
isFunctionCallingModel,
|
||||
isReasoningModel,
|
||||
isRerankModel,
|
||||
isVisionModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getModelLogo } from '@renderer/config/models'
|
||||
import { usePinnedModels } from '@renderer/hooks/usePinnedModels'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { Model, ModelTag, ModelType, objectEntries, Provider } from '@renderer/types'
|
||||
import { Model, ModelType, objectEntries, Provider } from '@renderer/types'
|
||||
import { classNames, filterModelsByKeywords, getFancyProviderName } from '@renderer/utils'
|
||||
import { getModelTags, isFreeModel } from '@renderer/utils/model'
|
||||
import { Avatar, Button, Divider, Empty, Flex, Modal, Tooltip } from 'antd'
|
||||
import { getModelTags } from '@renderer/utils/model'
|
||||
import { Avatar, Divider, Empty, Modal, Tooltip } from 'antd'
|
||||
import { first, sortBy } from 'lodash'
|
||||
import { Settings2 } from 'lucide-react'
|
||||
import React, {
|
||||
ReactNode,
|
||||
startTransition,
|
||||
useCallback,
|
||||
useDeferredValue,
|
||||
@ -44,18 +26,20 @@ import React, {
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import { useModelTagFilter } from './filters'
|
||||
import SelectModelSearchBar from './searchbar'
|
||||
import TagFilterSection from './TagFilterSection'
|
||||
import { FlatListItem, FlatListModel } from './types'
|
||||
|
||||
const PAGE_SIZE = 12
|
||||
const ITEM_HEIGHT = 36
|
||||
|
||||
type ModelPredict = (m: Model) => boolean
|
||||
|
||||
interface PopupParams {
|
||||
model?: Model
|
||||
modelFilter?: (model: Model) => boolean
|
||||
userFilterDisabled?: boolean
|
||||
/** Basic model filter */
|
||||
filter?: (model: Model) => boolean
|
||||
/** Show tag filter section */
|
||||
showTagFilter?: boolean
|
||||
}
|
||||
|
||||
interface Props extends PopupParams {
|
||||
@ -66,7 +50,7 @@ export type FilterType = Exclude<ModelType, 'text'> | 'free'
|
||||
|
||||
// const logger = loggerService.withContext('SelectModelPopup')
|
||||
|
||||
const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilterDisabled }) => {
|
||||
const PopupContainer: React.FC<Props> = ({ model, filter: baseFilter, showTagFilter = true, resolve }) => {
|
||||
const { t } = useTranslation()
|
||||
const { providers } = useProviders()
|
||||
const { pinnedModels, togglePinnedModel, loading } = usePinnedModels()
|
||||
@ -75,11 +59,6 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
const [_searchText, setSearchText] = useState('')
|
||||
const searchText = useDeferredValue(_searchText)
|
||||
|
||||
const allModels: Model[] = useMemo(
|
||||
() => providers.flatMap((p) => p.models).filter(modelFilter ?? (() => true)),
|
||||
[modelFilter, providers]
|
||||
)
|
||||
|
||||
// 当前选中的模型ID
|
||||
const currentModelId = model ? getModelUniqId(model) : ''
|
||||
|
||||
@ -94,95 +73,16 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
})
|
||||
}, [])
|
||||
|
||||
// 管理用户筛选状态
|
||||
/** 从模型列表获取的需要显示的标签 */
|
||||
const availableTags = useMemo(
|
||||
() =>
|
||||
objectEntries(getModelTags(allModels))
|
||||
.filter(([, state]) => state)
|
||||
.map(([tag]) => tag),
|
||||
[allModels]
|
||||
)
|
||||
const { tagSelection, selectedTags, tagFilter, toggleTag } = useModelTagFilter()
|
||||
|
||||
const filterConfig: Record<ModelTag, ModelPredict> = useMemo(
|
||||
() => ({
|
||||
vision: isVisionModel,
|
||||
embedding: isEmbeddingModel,
|
||||
reasoning: isReasoningModel,
|
||||
function_calling: isFunctionCallingModel,
|
||||
web_search: isWebSearchModel,
|
||||
rerank: isRerankModel,
|
||||
free: isFreeModel
|
||||
}),
|
||||
[]
|
||||
)
|
||||
// 计算要显示的可用标签列表
|
||||
const availableTags = useMemo(() => {
|
||||
const models = providers.flatMap((p) => p.models).filter(baseFilter ?? (() => true))
|
||||
return objectEntries(getModelTags(models))
|
||||
.filter(([, state]) => state)
|
||||
.map(([tag]) => tag)
|
||||
}, [providers, baseFilter])
|
||||
|
||||
/** 当前选择的标签,表示是否启用特定tag的筛选 */
|
||||
const [filterTags, setFilterTags] = useState<Record<ModelTag, boolean>>({
|
||||
vision: false,
|
||||
embedding: false,
|
||||
reasoning: false,
|
||||
function_calling: false,
|
||||
web_search: false,
|
||||
rerank: false,
|
||||
free: false
|
||||
})
|
||||
const selectedFilterTags = useMemo(
|
||||
() =>
|
||||
objectEntries(filterTags)
|
||||
.filter(([, state]) => state)
|
||||
.map(([tag]) => tag),
|
||||
[filterTags]
|
||||
)
|
||||
|
||||
const userFilter = useCallback(
|
||||
(model: Model) => {
|
||||
return selectedFilterTags
|
||||
.map((tag) => [tag, filterConfig[tag]] as const)
|
||||
.reduce((prev, [tag, predict]) => {
|
||||
return prev && (!filterTags[tag] || predict(model))
|
||||
}, true)
|
||||
},
|
||||
[filterConfig, filterTags, selectedFilterTags]
|
||||
)
|
||||
|
||||
const onClickTag = useCallback((type: ModelTag) => {
|
||||
startTransition(() => {
|
||||
setFilterTags((prev) => ({ ...prev, [type]: !prev[type] }))
|
||||
})
|
||||
}, [])
|
||||
|
||||
// 筛选项列表
|
||||
const tagsItems: Record<ModelTag, ReactNode> = useMemo(
|
||||
() => ({
|
||||
vision: <VisionTag showLabel inactive={!filterTags.vision} onClick={() => onClickTag('vision')} />,
|
||||
embedding: <EmbeddingTag inactive={!filterTags.embedding} onClick={() => onClickTag('embedding')} />,
|
||||
reasoning: <ReasoningTag showLabel inactive={!filterTags.reasoning} onClick={() => onClickTag('reasoning')} />,
|
||||
function_calling: (
|
||||
<ToolsCallingTag
|
||||
showLabel
|
||||
inactive={!filterTags.function_calling}
|
||||
onClick={() => onClickTag('function_calling')}
|
||||
/>
|
||||
),
|
||||
web_search: <WebSearchTag showLabel inactive={!filterTags.web_search} onClick={() => onClickTag('web_search')} />,
|
||||
rerank: <RerankerTag inactive={!filterTags.rerank} onClick={() => onClickTag('rerank')} />,
|
||||
free: <FreeTag inactive={!filterTags.free} onClick={() => onClickTag('free')} />
|
||||
}),
|
||||
[
|
||||
filterTags.embedding,
|
||||
filterTags.free,
|
||||
filterTags.function_calling,
|
||||
filterTags.reasoning,
|
||||
filterTags.rerank,
|
||||
filterTags.vision,
|
||||
filterTags.web_search,
|
||||
onClickTag
|
||||
]
|
||||
)
|
||||
|
||||
// 要显示的筛选项
|
||||
const displayedTags = useMemo(() => availableTags.map((tag) => tagsItems[tag]), [availableTags, tagsItems])
|
||||
// 根据输入的文本筛选模型
|
||||
const searchFilter = useCallback(
|
||||
(provider: Provider) => {
|
||||
@ -237,9 +137,9 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
const items: FlatListItem[] = []
|
||||
const pinnedModelIds = new Set(pinnedModels)
|
||||
const finalModelFilter = (model: Model) => {
|
||||
const _userFilter = userFilterDisabled || userFilter(model)
|
||||
const _modelFilter = modelFilter === undefined || modelFilter(model)
|
||||
return _userFilter && _modelFilter
|
||||
const _tagFilter = !showTagFilter || tagFilter(model)
|
||||
const _baseFilter = baseFilter === undefined || baseFilter(model)
|
||||
return _tagFilter && _baseFilter
|
||||
}
|
||||
|
||||
// 添加置顶模型分组(仅在无搜索文本时)
|
||||
@ -279,11 +179,10 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
name: getFancyProviderName(p),
|
||||
actions: (
|
||||
<Tooltip title={t('navigate.provider_settings')} mouseEnterDelay={0.5} mouseLeaveDelay={0}>
|
||||
<Button
|
||||
type="text"
|
||||
size="small"
|
||||
shape="circle"
|
||||
icon={<Settings2 size={12} color="var(--color-text-3)" style={{ pointerEvents: 'none' }} />}
|
||||
<Settings2
|
||||
size={12}
|
||||
color="var(--color-text)"
|
||||
className="action-icon"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
setOpen(false)
|
||||
@ -306,9 +205,9 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
pinnedModels,
|
||||
searchText.length,
|
||||
providers,
|
||||
userFilterDisabled,
|
||||
userFilter,
|
||||
modelFilter,
|
||||
showTagFilter,
|
||||
tagFilter,
|
||||
baseFilter,
|
||||
createModelItem,
|
||||
t,
|
||||
searchFilter,
|
||||
@ -319,7 +218,7 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
return Math.min(PAGE_SIZE, listItems.length) * ITEM_HEIGHT
|
||||
}, [listItems.length])
|
||||
|
||||
// 处理程序化滚动(加载、搜索开始、搜索清空)
|
||||
// 处理程序化滚动(加载、搜索开始、搜索清空、tag 筛选)
|
||||
useLayoutEffect(() => {
|
||||
if (loading) return
|
||||
|
||||
@ -330,8 +229,8 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
|
||||
let targetItemKey: string | undefined
|
||||
|
||||
// 启动搜索时,滚动到第一个 item
|
||||
if (searchText) {
|
||||
// 启动搜索或 tag 筛选时,滚动到第一个 item
|
||||
if (searchText || selectedTags.length > 0) {
|
||||
targetItemKey = modelItems[0]?.key
|
||||
}
|
||||
// 初始加载或清空搜索时,滚动到 selected item
|
||||
@ -352,7 +251,7 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
})
|
||||
}
|
||||
}
|
||||
}, [searchText, listItems, modelItems, loading, setFocusedItemKey, listHeight])
|
||||
}, [searchText, listItems, modelItems, loading, setFocusedItemKey, listHeight, selectedTags.length])
|
||||
|
||||
const handleItemClick = useCallback(
|
||||
(item: FlatListItem) => {
|
||||
@ -511,9 +410,7 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
borderRadius: 20,
|
||||
padding: 0,
|
||||
overflow: 'hidden',
|
||||
paddingBottom: 16,
|
||||
// 需要稳定高度避免布局偏移
|
||||
height: userFilterDisabled ? undefined : 530
|
||||
paddingBottom: 16
|
||||
},
|
||||
body: {
|
||||
maxHeight: 'inherit',
|
||||
@ -525,14 +422,9 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
{/* 搜索框 */}
|
||||
<SelectModelSearchBar onSearch={setSearchText} />
|
||||
<Divider style={{ margin: 0, marginTop: 4, borderBlockStartWidth: 0.5 }} />
|
||||
{!userFilterDisabled && (
|
||||
{showTagFilter && (
|
||||
<>
|
||||
<FilterContainer>
|
||||
<Flex wrap="wrap" gap={4}>
|
||||
<FilterText>{t('models.filter.by_tag')}</FilterText>
|
||||
{displayedTags.map((item) => item)}
|
||||
</Flex>
|
||||
</FilterContainer>
|
||||
<TagFilterSection availableTags={availableTags} tagSelection={tagSelection} onToggleTag={toggleTag} />
|
||||
<Divider style={{ margin: 0, borderBlockStartWidth: 0.5 }} />
|
||||
</>
|
||||
)}
|
||||
@ -561,16 +453,6 @@ const PopupContainer: React.FC<Props> = ({ model, resolve, modelFilter, userFilt
|
||||
)
|
||||
}
|
||||
|
||||
const FilterContainer = styled.div`
|
||||
padding: 8px;
|
||||
padding-left: 18px;
|
||||
`
|
||||
|
||||
const FilterText = styled.span`
|
||||
color: var(--color-text-3);
|
||||
font-size: 12px;
|
||||
`
|
||||
|
||||
const ListContainer = styled.div`
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
@ -579,25 +461,28 @@ const ListContainer = styled.div`
|
||||
const GroupItem = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 2px;
|
||||
gap: 8px;
|
||||
position: relative;
|
||||
line-height: 1;
|
||||
font-size: 12px;
|
||||
font-weight: normal;
|
||||
height: ${ITEM_HEIGHT}px;
|
||||
padding: 5px 12px 5px 18px;
|
||||
padding: 5px 18px;
|
||||
color: var(--color-text-3);
|
||||
z-index: 1;
|
||||
background: var(--modal-background);
|
||||
|
||||
&:hover {
|
||||
.ant-btn {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.ant-btn {
|
||||
.action-icon {
|
||||
cursor: pointer;
|
||||
opacity: 0;
|
||||
transition: opacity 0.2s;
|
||||
|
||||
&:hover {
|
||||
opacity: 1 !important;
|
||||
}
|
||||
}
|
||||
&:hover .action-icon {
|
||||
opacity: 0.3;
|
||||
}
|
||||
`
|
||||
|
||||
|
||||
@ -37,8 +37,12 @@ const CustomTag: FC<CustomTagProps> = ({
|
||||
$color={actualColor}
|
||||
$size={size}
|
||||
$closable={closable}
|
||||
$clickable={!disabled && !!onClick}
|
||||
onClick={disabled ? undefined : onClick}
|
||||
style={{ cursor: disabled ? 'not-allowed' : onClick ? 'pointer' : 'auto', ...style }}>
|
||||
style={{
|
||||
...(disabled && { cursor: 'not-allowed' }),
|
||||
...style
|
||||
}}>
|
||||
{icon && icon} {children}
|
||||
{closable && (
|
||||
<CloseIcon
|
||||
@ -66,7 +70,7 @@ const CustomTag: FC<CustomTagProps> = ({
|
||||
|
||||
export default memo(CustomTag)
|
||||
|
||||
const Tag = styled.div<{ $color: string; $size: number; $closable: boolean }>`
|
||||
const Tag = styled.div<{ $color: string; $size: number; $closable: boolean; $clickable: boolean }>`
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
@ -79,10 +83,16 @@ const Tag = styled.div<{ $color: string; $size: number; $closable: boolean }>`
|
||||
line-height: 1;
|
||||
white-space: nowrap;
|
||||
position: relative;
|
||||
cursor: ${({ $clickable }) => ($clickable ? 'pointer' : 'auto')};
|
||||
.iconfont {
|
||||
font-size: ${({ $size }) => $size}px;
|
||||
color: ${({ $color }) => $color};
|
||||
}
|
||||
|
||||
transition: opacity 0.2s ease;
|
||||
&:hover {
|
||||
opacity: ${({ $clickable }) => ($clickable ? 0.8 : 1)};
|
||||
}
|
||||
`
|
||||
|
||||
const CloseIcon = styled(CloseOutlined)<{ $size: number; $color: string }>`
|
||||
|
||||
@ -41,4 +41,15 @@ describe('CustomTag', () => {
|
||||
expect(document.querySelector('.ant-tooltip')).toBeNull()
|
||||
expect(screen.queryByRole('tooltip')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not allow click when disabled', async () => {
|
||||
render(
|
||||
<CustomTag color={COLOR} disabled>
|
||||
custom-tag
|
||||
</CustomTag>
|
||||
)
|
||||
const tag = screen.getByText('custom-tag')
|
||||
expect(tag).toBeInTheDocument()
|
||||
expect(tag).toHaveStyle({ cursor: 'not-allowed' })
|
||||
})
|
||||
})
|
||||
|
||||
@ -439,7 +439,7 @@ const MessageMenubar: FC<Props> = (props) => {
|
||||
async (e: React.MouseEvent) => {
|
||||
e.stopPropagation()
|
||||
if (loading) return
|
||||
const selectedModel = await SelectModelPopup.show({ model, modelFilter: mentionModelFilter })
|
||||
const selectedModel = await SelectModelPopup.show({ model, filter: mentionModelFilter })
|
||||
if (!selectedModel) return
|
||||
appendAssistantResponse(message, selectedModel, { ...assistant, model: selectedModel })
|
||||
},
|
||||
|
||||
@ -24,7 +24,7 @@ const SelectModelButton: FC<Props> = ({ assistant }) => {
|
||||
|
||||
const onSelectModel = async (event: React.MouseEvent<HTMLElement>) => {
|
||||
event.currentTarget.blur()
|
||||
const selectedModel = await SelectModelPopup.show({ model, modelFilter })
|
||||
const selectedModel = await SelectModelPopup.show({ model, filter: modelFilter })
|
||||
if (selectedModel) {
|
||||
// 避免更新数据造成关闭弹框的卡顿
|
||||
clearTimeout(timerRef.current)
|
||||
|
||||
@ -184,7 +184,7 @@ const AssistantModelSettings: FC<Props> = ({ assistant, updateAssistant, updateA
|
||||
|
||||
const onSelectModel = useCallback(async () => {
|
||||
const currentModel = defaultModel ? assistant?.model : undefined
|
||||
const selectedModel = await SelectModelPopup.show({ model: currentModel, modelFilter })
|
||||
const selectedModel = await SelectModelPopup.show({ model: currentModel, filter: modelFilter })
|
||||
if (selectedModel) {
|
||||
setDefaultModel(selectedModel)
|
||||
updateAssistant({
|
||||
|
||||
Loading…
Reference in New Issue
Block a user