diff --git a/src/renderer/src/components/Popups/SelectModelPopup/api-model-popup.tsx b/src/renderer/src/components/Popups/SelectModelPopup/api-model-popup.tsx index 168ab563b..17a644bdc 100644 --- a/src/renderer/src/components/Popups/SelectModelPopup/api-model-popup.tsx +++ b/src/renderer/src/components/Popups/SelectModelPopup/api-model-popup.tsx @@ -7,7 +7,7 @@ import { getModelLogo } from '@renderer/config/models' import { useApiModels } from '@renderer/hooks/agents/useModels' import { getModelUniqId } from '@renderer/services/ModelService' import { getProviderNameById } from '@renderer/services/ProviderService' -import { AdaptedApiModel, ApiModel, ApiModelsFilter, ModelType, objectEntries } from '@renderer/types' +import { AdaptedApiModel, ApiModel, ApiModelsFilter, Model, ModelType, objectEntries } from '@renderer/types' import { classNames, filterModelsByKeywords } from '@renderer/utils' import { apiModelAdapter, getModelTags } from '@renderer/utils/model' import { Avatar, Divider, Empty, Modal } from 'antd' @@ -34,8 +34,10 @@ const ITEM_HEIGHT = 36 interface PopupParams { model?: ApiModel - /** Api model filter */ - filter?: ApiModelsFilter + /** Api models filter */ + apiFilter?: ApiModelsFilter + /** model filter */ + modelFilter?: (model: Model) => boolean /** Show tag filter section */ showTagFilter?: boolean } @@ -48,12 +50,12 @@ export type FilterType = Exclude | 'free' // const logger = loggerService.withContext('SelectModelPopup') -const PopupContainer: React.FC = ({ model, filter: baseFilter, showTagFilter = true, resolve }) => { +const PopupContainer: React.FC = ({ model, apiFilter, modelFilter, showTagFilter = true, resolve }) => { const [open, setOpen] = useState(true) const listRef = useRef(null) const [_searchText, setSearchText] = useState('') const searchText = useDeferredValue(_searchText) - const { models, isLoading } = useApiModels(baseFilter) + const { models, isLoading } = useApiModels(apiFilter) const adaptedModels = models.map((model) => apiModelAdapter(model)) // 当前选中的模型ID @@ -128,7 +130,8 @@ const PopupContainer: React.FC = ({ model, filter: baseFilter, showTagFil const items: FlatListApiItem[] = [] const finalModelFilter = (model: AdaptedApiModel) => { const _tagFilter = !showTagFilter || tagFilter(model) - return _tagFilter + const _modelFilter = modelFilter === undefined || modelFilter(model) + return _tagFilter && _modelFilter } // 筛选模型 diff --git a/src/renderer/src/pages/home/components/SelectAgentModelButton.tsx b/src/renderer/src/pages/home/components/SelectAgentModelButton.tsx index bc6ca3ed3..53607cd1d 100644 --- a/src/renderer/src/pages/home/components/SelectAgentModelButton.tsx +++ b/src/renderer/src/pages/home/components/SelectAgentModelButton.tsx @@ -1,9 +1,10 @@ import { Button } from '@heroui/react' import ModelAvatar from '@renderer/components/Avatar/ModelAvatar' import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup' +import { isEmbeddingModel, isRerankModel, isTextToImageModel } from '@renderer/config/models' import { useApiModel } from '@renderer/hooks/agents/useModel' import { getProviderNameById } from '@renderer/services/ProviderService' -import { AgentBaseWithId, ApiModel, isAgentEntity } from '@renderer/types' +import { AgentBaseWithId, ApiModel, isAgentEntity, Model } from '@renderer/types' import { getModelFilterByAgentType } from '@renderer/utils/agentSession' import { apiModelAdapter } from '@renderer/utils/model' import { ChevronsUpDown } from 'lucide-react' @@ -20,12 +21,13 @@ const SelectAgentModelButton: FC = ({ agent, onSelect, isDisabled }) => { const { t } = useTranslation() const model = useApiModel({ id: agent?.model }) - const modelFilter = isAgentEntity(agent) ? getModelFilterByAgentType(agent.type) : undefined + const apiFilter = isAgentEntity(agent) ? getModelFilterByAgentType(agent.type) : undefined + const modelFilter = (model: Model) => !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model) if (!agent) return null const onSelectModel = async () => { - const selectedModel = await SelectApiModelPopup.show({ model, filter: modelFilter }) + const selectedModel = await SelectApiModelPopup.show({ model, apiFilter: apiFilter, modelFilter }) if (selectedModel && selectedModel.id !== agent.model) { onSelect(selectedModel) }