diff --git a/src/renderer/src/components/ModelSelectButton.tsx b/src/renderer/src/components/ModelSelectButton.tsx index 40330253cd..fd227f1e76 100644 --- a/src/renderer/src/components/ModelSelectButton.tsx +++ b/src/renderer/src/components/ModelSelectButton.tsx @@ -15,7 +15,7 @@ type Props = { const ModelSelectButton = ({ model, onSelectModel, modelFilter, noTooltip, tooltipProps }: Props) => { const onClick = useCallback(async () => { - const selectedModel = await SelectModelPopup.show({ model, modelFilter }) + const selectedModel = await SelectModelPopup.show({ model, filter: modelFilter }) if (selectedModel) { onSelectModel?.(selectedModel) } diff --git a/src/renderer/src/components/Popups/SelectModelPopup/TagFilterSection.tsx b/src/renderer/src/components/Popups/SelectModelPopup/TagFilterSection.tsx new file mode 100644 index 0000000000..aec91bd803 --- /dev/null +++ b/src/renderer/src/components/Popups/SelectModelPopup/TagFilterSection.tsx @@ -0,0 +1,83 @@ +import { loggerService } from '@logger' +import { + EmbeddingTag, + FreeTag, + ReasoningTag, + RerankerTag, + ToolsCallingTag, + VisionTag, + WebSearchTag +} from '@renderer/components/Tags/Model' +import { ModelTag } from '@renderer/types' +import { Flex } from 'antd' +import React, { startTransition, useCallback, useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import styled from 'styled-components' + +const logger = loggerService.withContext('TagFilterSection') + +interface TagFilterSectionProps { + availableTags: ModelTag[] + tagSelection: Record + onToggleTag: (tag: ModelTag) => void +} + +const TagFilterSection: React.FC = ({ availableTags, tagSelection, onToggleTag }) => { + const { t } = useTranslation() + + const handleTagClick = useCallback( + (tag: ModelTag) => { + startTransition(() => onToggleTag(tag)) + }, + [onToggleTag] + ) + + // 标签组件 + const tagComponents = useMemo( + () => ({ + vision: VisionTag, + embedding: EmbeddingTag, + reasoning: ReasoningTag, + function_calling: ToolsCallingTag, + web_search: WebSearchTag, + rerank: RerankerTag, + free: FreeTag + }), + [] + ) + + return ( + + + {t('models.filter.by_tag')} + {availableTags.map((tag) => { + const TagElement = tagComponents[tag] + if (!TagElement) { + logger.error(`Tag element not found for tag: ${tag}`) + return null + } + return ( + handleTagClick(tag)} + inactive={!tagSelection[tag]} + showLabel + /> + ) + })} + + + ) +} + +const FilterContainer = styled.div` + padding: 8px; + padding-left: 18px; +` + +const FilterText = styled.span` + color: var(--color-text-3); + font-size: 12px; +` + +export default TagFilterSection diff --git a/src/renderer/src/components/Popups/SelectModelPopup/__tests__/TagFilterSection.test.tsx b/src/renderer/src/components/Popups/SelectModelPopup/__tests__/TagFilterSection.test.tsx new file mode 100644 index 0000000000..131e924e4e --- /dev/null +++ b/src/renderer/src/components/Popups/SelectModelPopup/__tests__/TagFilterSection.test.tsx @@ -0,0 +1,110 @@ +import type { ModelTag } from '@renderer/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import TagFilterSection from '../TagFilterSection' + +const mocks = vi.hoisted(() => ({ + t: vi.fn((key: string) => key), + createTagComponent: (name: string) => { + // Create a simple button component exposing props for assertions + return ({ onClick, inactive, showLabel }: { onClick?: () => void; inactive?: boolean; showLabel?: boolean }) => { + const React = require('react') + return React.createElement( + 'button', + { + type: 'button', + 'aria-label': `tag-${name}`, + 'data-inactive': String(Boolean(inactive)), + onClick + }, + showLabel ? name : '' + ) + } + } +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ t: mocks.t }) +})) + +vi.mock('@renderer/components/Tags/Model', () => ({ + VisionTag: mocks.createTagComponent('vision'), + EmbeddingTag: mocks.createTagComponent('embedding'), + ReasoningTag: mocks.createTagComponent('reasoning'), + ToolsCallingTag: mocks.createTagComponent('function_calling'), + WebSearchTag: mocks.createTagComponent('web_search'), + RerankerTag: mocks.createTagComponent('rerank'), + FreeTag: mocks.createTagComponent('free') +})) + +vi.mock('antd', () => ({ + Flex: ({ children }: { children: React.ReactNode }) => children +})) + +function createSelection(overrides: Partial> = {}): Record { + const base: Record = { + vision: true, + embedding: true, + reasoning: true, + function_calling: true, + web_search: true, + rerank: true, + free: true + } + return { ...base, ...overrides } +} + +const allTags: ModelTag[] = ['vision', 'embedding', 'reasoning', 'function_calling', 'web_search', 'rerank', 'free'] + +describe('TagFilterSection', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should match snapshot', () => { + const { container } = render( + + ) + expect(container).toMatchSnapshot() + }) + + it('should reflect inactive state based on tagSelection', () => { + render( + + ) + const visionBtn = screen.getByRole('button', { name: 'tag-vision' }) + expect(visionBtn).toHaveAttribute('data-inactive', 'true') + }) + + it('should skip unknown tags', () => { + render( + + ) + expect(screen.queryByRole('button', { name: 'tag-unknown' })).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'tag-vision' })).toBeInTheDocument() + }) + }) + + describe('functionality', () => { + it('should call onToggleTag when a tag is clicked', () => { + const handleToggle = vi.fn() + render() + + const visionBtn = screen.getByRole('button', { name: 'tag-vision' }) + fireEvent.click(visionBtn) + + expect(handleToggle).toHaveBeenCalledTimes(1) + expect(handleToggle).toHaveBeenCalledWith('vision') + }) + }) +}) diff --git a/src/renderer/src/components/Popups/SelectModelPopup/__tests__/__snapshots__/TagFilterSection.test.tsx.snap b/src/renderer/src/components/Popups/SelectModelPopup/__tests__/__snapshots__/TagFilterSection.test.tsx.snap new file mode 100644 index 0000000000..297987423c --- /dev/null +++ b/src/renderer/src/components/Popups/SelectModelPopup/__tests__/__snapshots__/TagFilterSection.test.tsx.snap @@ -0,0 +1,74 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`TagFilterSection > rendering > should match snapshot 1`] = ` +.c0 { + padding: 8px; + padding-left: 18px; +} + +.c1 { + color: var(--color-text-3); + font-size: 12px; +} + +
+
+ + models.filter.by_tag + + + + + + + + +
+
+`; diff --git a/src/renderer/src/components/Popups/SelectModelPopup/__tests__/filters.test.ts b/src/renderer/src/components/Popups/SelectModelPopup/__tests__/filters.test.ts new file mode 100644 index 0000000000..ebdfce59a3 --- /dev/null +++ b/src/renderer/src/components/Popups/SelectModelPopup/__tests__/filters.test.ts @@ -0,0 +1,122 @@ +import type { Model } from '@renderer/types' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { useModelTagFilter } from '../filters' + +const mocks = vi.hoisted(() => ({ + isVisionModel: vi.fn(), + isEmbeddingModel: vi.fn(), + isReasoningModel: vi.fn(), + isFunctionCallingModel: vi.fn(), + isWebSearchModel: vi.fn(), + isRerankModel: vi.fn(), + isFreeModel: vi.fn() +})) + +vi.mock('@renderer/config/models', () => ({ + isEmbeddingModel: mocks.isEmbeddingModel, + isFunctionCallingModel: mocks.isFunctionCallingModel, + isReasoningModel: mocks.isReasoningModel, + isRerankModel: mocks.isRerankModel, + isVisionModel: mocks.isVisionModel, + isWebSearchModel: mocks.isWebSearchModel +})) + +vi.mock('@renderer/utils/model', () => ({ + isFreeModel: mocks.isFreeModel +})) + +function createModel(overrides: Partial = {}): Model { + return { + id: 'm1', + provider: 'openai', + name: 'Model-1', + group: 'default', + ...overrides + } +} + +describe('useModelTagFilter', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should have all tags unselected initially', () => { + const { result } = renderHook(() => useModelTagFilter()) + + expect(result.current.tagSelection).toEqual({ + vision: false, + embedding: false, + reasoning: false, + function_calling: false, + web_search: false, + rerank: false, + free: false + }) + expect(result.current.selectedTags).toEqual([]) + }) + + it('should toggle a tag state', () => { + const { result } = renderHook(() => useModelTagFilter()) + + act(() => result.current.toggleTag('vision')) + expect(result.current.tagSelection.vision).toBe(true) + expect(result.current.selectedTags).toEqual(['vision']) + + act(() => result.current.toggleTag('vision')) + expect(result.current.tagSelection.vision).toBe(false) + expect(result.current.selectedTags).toEqual([]) + }) + + it('should reset all tags to false', () => { + const { result } = renderHook(() => useModelTagFilter()) + + act(() => result.current.toggleTag('vision')) + act(() => result.current.toggleTag('embedding')) + expect(result.current.selectedTags.sort()).toEqual(['embedding', 'vision']) + + act(() => result.current.resetTags()) + expect(result.current.selectedTags).toEqual([]) + expect(Object.values(result.current.tagSelection).every((v) => v === false)).toBe(true) + }) + + it('tagFilter returns true when no tags selected', () => { + const { result } = renderHook(() => useModelTagFilter()) + const model = createModel() + const passed = result.current.tagFilter(model) + expect(passed).toBe(true) + expect(mocks.isVisionModel).not.toHaveBeenCalled() + }) + + it('tagFilter uses single selected tag predicate', () => { + const { result } = renderHook(() => useModelTagFilter()) + const model = createModel() + + mocks.isVisionModel.mockReturnValueOnce(true) + act(() => result.current.toggleTag('vision')) + + const ok = result.current.tagFilter(model) + expect(ok).toBe(true) + expect(mocks.isVisionModel).toHaveBeenCalledTimes(1) + expect(mocks.isVisionModel).toHaveBeenCalledWith(model) + }) + + it('tagFilter requires all selected tags to match (AND logic)', () => { + const { result } = renderHook(() => useModelTagFilter()) + const model = createModel() + + act(() => result.current.toggleTag('vision')) + act(() => result.current.toggleTag('embedding')) + + // 第一次:vision=true, embedding=false => 应为 false + mocks.isVisionModel.mockReturnValueOnce(true) + mocks.isEmbeddingModel.mockReturnValueOnce(false) + expect(result.current.tagFilter(model)).toBe(false) + + // 第二次:vision=true, embedding=true => 应为 true + mocks.isVisionModel.mockReturnValueOnce(true) + mocks.isEmbeddingModel.mockReturnValueOnce(true) + expect(result.current.tagFilter(model)).toBe(true) + }) +}) diff --git a/src/renderer/src/components/Popups/SelectModelPopup/filters.ts b/src/renderer/src/components/Popups/SelectModelPopup/filters.ts new file mode 100644 index 0000000000..d2ee6c7742 --- /dev/null +++ b/src/renderer/src/components/Popups/SelectModelPopup/filters.ts @@ -0,0 +1,79 @@ +import { + isEmbeddingModel, + isFunctionCallingModel, + isReasoningModel, + isRerankModel, + isVisionModel, + isWebSearchModel +} from '@renderer/config/models' +import { Model, ModelTag, objectEntries } from '@renderer/types' +import { isFreeModel } from '@renderer/utils/model' +import { useCallback, useMemo, useState } from 'react' + +type ModelPredict = (m: Model) => boolean + +const initialTagSelection: Record = { + vision: false, + embedding: false, + reasoning: false, + function_calling: false, + web_search: false, + rerank: false, + free: false +} + +/** + * 标签筛选 hook,仅关注标签过滤逻辑 + */ +export function useModelTagFilter() { + const filterConfig: Record = useMemo( + () => ({ + vision: isVisionModel, + embedding: isEmbeddingModel, + reasoning: isReasoningModel, + function_calling: isFunctionCallingModel, + web_search: isWebSearchModel, + rerank: isRerankModel, + free: isFreeModel + }), + [] + ) + + const [tagSelection, setTagSelection] = useState>(initialTagSelection) + + // 已选中的标签 + const selectedTags = useMemo( + () => + objectEntries(tagSelection) + .filter(([, state]) => state) + .map(([tag]) => tag), + [tagSelection] + ) + + // 切换标签 + const toggleTag = useCallback((tag: ModelTag) => { + setTagSelection((prev) => ({ ...prev, [tag]: !prev[tag] })) + }, []) + + // 重置标签 + const resetTags = useCallback(() => { + setTagSelection(initialTagSelection) + }, []) + + // 根据标签过滤模型 + const tagFilter = useCallback( + (model: Model) => { + if (selectedTags.length === 0) return true + return selectedTags.map((tag) => filterConfig[tag]).every((predict) => predict(model)) + }, + [filterConfig, selectedTags] + ) + + return { + tagSelection, + selectedTags, + tagFilter, + toggleTag, + resetTags + } +} diff --git a/src/renderer/src/components/Popups/SelectModelPopup/popup.tsx b/src/renderer/src/components/Popups/SelectModelPopup/popup.tsx index 6b910b2d78..1cd7926145 100644 --- a/src/renderer/src/components/Popups/SelectModelPopup/popup.tsx +++ b/src/renderer/src/components/Popups/SelectModelPopup/popup.tsx @@ -1,37 +1,19 @@ import { PushpinOutlined } from '@ant-design/icons' import { FreeTrialModelTag } from '@renderer/components/FreeTrialModelTag' import ModelTagsWithLabel from '@renderer/components/ModelTagsWithLabel' -import { - EmbeddingTag, - FreeTag, - ReasoningTag, - RerankerTag, - ToolsCallingTag, - VisionTag, - WebSearchTag -} from '@renderer/components/Tags/Model' import { TopView } from '@renderer/components/TopView' import { DynamicVirtualList, type DynamicVirtualListRef } from '@renderer/components/VirtualList' -import { - getModelLogo, - isEmbeddingModel, - isFunctionCallingModel, - isReasoningModel, - isRerankModel, - isVisionModel, - isWebSearchModel -} from '@renderer/config/models' +import { getModelLogo } from '@renderer/config/models' import { usePinnedModels } from '@renderer/hooks/usePinnedModels' import { useProviders } from '@renderer/hooks/useProvider' import { getModelUniqId } from '@renderer/services/ModelService' -import { Model, ModelTag, ModelType, objectEntries, Provider } from '@renderer/types' +import { Model, ModelType, objectEntries, Provider } from '@renderer/types' import { classNames, filterModelsByKeywords, getFancyProviderName } from '@renderer/utils' -import { getModelTags, isFreeModel } from '@renderer/utils/model' -import { Avatar, Button, Divider, Empty, Flex, Modal, Tooltip } from 'antd' +import { getModelTags } from '@renderer/utils/model' +import { Avatar, Divider, Empty, Modal, Tooltip } from 'antd' import { first, sortBy } from 'lodash' import { Settings2 } from 'lucide-react' import React, { - ReactNode, startTransition, useCallback, useDeferredValue, @@ -44,18 +26,20 @@ import React, { import { useTranslation } from 'react-i18next' import styled from 'styled-components' +import { useModelTagFilter } from './filters' import SelectModelSearchBar from './searchbar' +import TagFilterSection from './TagFilterSection' import { FlatListItem, FlatListModel } from './types' const PAGE_SIZE = 12 const ITEM_HEIGHT = 36 -type ModelPredict = (m: Model) => boolean - interface PopupParams { model?: Model - modelFilter?: (model: Model) => boolean - userFilterDisabled?: boolean + /** Basic model filter */ + filter?: (model: Model) => boolean + /** Show tag filter section */ + showTagFilter?: boolean } interface Props extends PopupParams { @@ -66,7 +50,7 @@ export type FilterType = Exclude | 'free' // const logger = loggerService.withContext('SelectModelPopup') -const PopupContainer: React.FC = ({ model, resolve, modelFilter, userFilterDisabled }) => { +const PopupContainer: React.FC = ({ model, filter: baseFilter, showTagFilter = true, resolve }) => { const { t } = useTranslation() const { providers } = useProviders() const { pinnedModels, togglePinnedModel, loading } = usePinnedModels() @@ -75,11 +59,6 @@ const PopupContainer: React.FC = ({ model, resolve, modelFilter, userFilt const [_searchText, setSearchText] = useState('') const searchText = useDeferredValue(_searchText) - const allModels: Model[] = useMemo( - () => providers.flatMap((p) => p.models).filter(modelFilter ?? (() => true)), - [modelFilter, providers] - ) - // 当前选中的模型ID const currentModelId = model ? getModelUniqId(model) : '' @@ -94,95 +73,16 @@ const PopupContainer: React.FC = ({ model, resolve, modelFilter, userFilt }) }, []) - // 管理用户筛选状态 - /** 从模型列表获取的需要显示的标签 */ - const availableTags = useMemo( - () => - objectEntries(getModelTags(allModels)) - .filter(([, state]) => state) - .map(([tag]) => tag), - [allModels] - ) + const { tagSelection, selectedTags, tagFilter, toggleTag } = useModelTagFilter() - const filterConfig: Record = useMemo( - () => ({ - vision: isVisionModel, - embedding: isEmbeddingModel, - reasoning: isReasoningModel, - function_calling: isFunctionCallingModel, - web_search: isWebSearchModel, - rerank: isRerankModel, - free: isFreeModel - }), - [] - ) + // 计算要显示的可用标签列表 + const availableTags = useMemo(() => { + const models = providers.flatMap((p) => p.models).filter(baseFilter ?? (() => true)) + return objectEntries(getModelTags(models)) + .filter(([, state]) => state) + .map(([tag]) => tag) + }, [providers, baseFilter]) - /** 当前选择的标签,表示是否启用特定tag的筛选 */ - const [filterTags, setFilterTags] = useState>({ - vision: false, - embedding: false, - reasoning: false, - function_calling: false, - web_search: false, - rerank: false, - free: false - }) - const selectedFilterTags = useMemo( - () => - objectEntries(filterTags) - .filter(([, state]) => state) - .map(([tag]) => tag), - [filterTags] - ) - - const userFilter = useCallback( - (model: Model) => { - return selectedFilterTags - .map((tag) => [tag, filterConfig[tag]] as const) - .reduce((prev, [tag, predict]) => { - return prev && (!filterTags[tag] || predict(model)) - }, true) - }, - [filterConfig, filterTags, selectedFilterTags] - ) - - const onClickTag = useCallback((type: ModelTag) => { - startTransition(() => { - setFilterTags((prev) => ({ ...prev, [type]: !prev[type] })) - }) - }, []) - - // 筛选项列表 - const tagsItems: Record = useMemo( - () => ({ - vision: onClickTag('vision')} />, - embedding: onClickTag('embedding')} />, - reasoning: onClickTag('reasoning')} />, - function_calling: ( - onClickTag('function_calling')} - /> - ), - web_search: onClickTag('web_search')} />, - rerank: onClickTag('rerank')} />, - free: onClickTag('free')} /> - }), - [ - filterTags.embedding, - filterTags.free, - filterTags.function_calling, - filterTags.reasoning, - filterTags.rerank, - filterTags.vision, - filterTags.web_search, - onClickTag - ] - ) - - // 要显示的筛选项 - const displayedTags = useMemo(() => availableTags.map((tag) => tagsItems[tag]), [availableTags, tagsItems]) // 根据输入的文本筛选模型 const searchFilter = useCallback( (provider: Provider) => { @@ -237,9 +137,9 @@ const PopupContainer: React.FC = ({ model, resolve, modelFilter, userFilt const items: FlatListItem[] = [] const pinnedModelIds = new Set(pinnedModels) const finalModelFilter = (model: Model) => { - const _userFilter = userFilterDisabled || userFilter(model) - const _modelFilter = modelFilter === undefined || modelFilter(model) - return _userFilter && _modelFilter + const _tagFilter = !showTagFilter || tagFilter(model) + const _baseFilter = baseFilter === undefined || baseFilter(model) + return _tagFilter && _baseFilter } // 添加置顶模型分组(仅在无搜索文本时) @@ -279,11 +179,10 @@ const PopupContainer: React.FC = ({ model, resolve, modelFilter, userFilt name: getFancyProviderName(p), actions: ( -