mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
feat(model-selection): add model type filtering to exclude embedding/rerank/image models
Add modelFilter parameter to SelectApiModelPopup to exclude embedding, rerank and text-to-image models from selection. This ensures only appropriate models are shown based on agent type requirements.
This commit is contained in:
parent
42435e8f76
commit
1b705edb06
@ -7,7 +7,7 @@ import { getModelLogo } from '@renderer/config/models'
|
|||||||
import { useApiModels } from '@renderer/hooks/agents/useModels'
|
import { useApiModels } from '@renderer/hooks/agents/useModels'
|
||||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||||
import { getProviderNameById } from '@renderer/services/ProviderService'
|
import { getProviderNameById } from '@renderer/services/ProviderService'
|
||||||
import { AdaptedApiModel, ApiModel, ApiModelsFilter, ModelType, objectEntries } from '@renderer/types'
|
import { AdaptedApiModel, ApiModel, ApiModelsFilter, Model, ModelType, objectEntries } from '@renderer/types'
|
||||||
import { classNames, filterModelsByKeywords } from '@renderer/utils'
|
import { classNames, filterModelsByKeywords } from '@renderer/utils'
|
||||||
import { apiModelAdapter, getModelTags } from '@renderer/utils/model'
|
import { apiModelAdapter, getModelTags } from '@renderer/utils/model'
|
||||||
import { Avatar, Divider, Empty, Modal } from 'antd'
|
import { Avatar, Divider, Empty, Modal } from 'antd'
|
||||||
@ -34,8 +34,10 @@ const ITEM_HEIGHT = 36
|
|||||||
|
|
||||||
interface PopupParams {
|
interface PopupParams {
|
||||||
model?: ApiModel
|
model?: ApiModel
|
||||||
/** Api model filter */
|
/** Api models filter */
|
||||||
filter?: ApiModelsFilter
|
apiFilter?: ApiModelsFilter
|
||||||
|
/** model filter */
|
||||||
|
modelFilter?: (model: Model) => boolean
|
||||||
/** Show tag filter section */
|
/** Show tag filter section */
|
||||||
showTagFilter?: boolean
|
showTagFilter?: boolean
|
||||||
}
|
}
|
||||||
@ -48,12 +50,12 @@ export type FilterType = Exclude<ModelType, 'text'> | 'free'
|
|||||||
|
|
||||||
// const logger = loggerService.withContext('SelectModelPopup')
|
// const logger = loggerService.withContext('SelectModelPopup')
|
||||||
|
|
||||||
const PopupContainer: React.FC<Props> = ({ model, filter: baseFilter, showTagFilter = true, resolve }) => {
|
const PopupContainer: React.FC<Props> = ({ model, apiFilter, modelFilter, showTagFilter = true, resolve }) => {
|
||||||
const [open, setOpen] = useState(true)
|
const [open, setOpen] = useState(true)
|
||||||
const listRef = useRef<DynamicVirtualListRef>(null)
|
const listRef = useRef<DynamicVirtualListRef>(null)
|
||||||
const [_searchText, setSearchText] = useState('')
|
const [_searchText, setSearchText] = useState('')
|
||||||
const searchText = useDeferredValue(_searchText)
|
const searchText = useDeferredValue(_searchText)
|
||||||
const { models, isLoading } = useApiModels(baseFilter)
|
const { models, isLoading } = useApiModels(apiFilter)
|
||||||
const adaptedModels = models.map((model) => apiModelAdapter(model))
|
const adaptedModels = models.map((model) => apiModelAdapter(model))
|
||||||
|
|
||||||
// 当前选中的模型ID
|
// 当前选中的模型ID
|
||||||
@ -128,7 +130,8 @@ const PopupContainer: React.FC<Props> = ({ model, filter: baseFilter, showTagFil
|
|||||||
const items: FlatListApiItem[] = []
|
const items: FlatListApiItem[] = []
|
||||||
const finalModelFilter = (model: AdaptedApiModel) => {
|
const finalModelFilter = (model: AdaptedApiModel) => {
|
||||||
const _tagFilter = !showTagFilter || tagFilter(model)
|
const _tagFilter = !showTagFilter || tagFilter(model)
|
||||||
return _tagFilter
|
const _modelFilter = modelFilter === undefined || modelFilter(model)
|
||||||
|
return _tagFilter && _modelFilter
|
||||||
}
|
}
|
||||||
|
|
||||||
// 筛选模型
|
// 筛选模型
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
import { Button } from '@heroui/react'
|
import { Button } from '@heroui/react'
|
||||||
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
||||||
import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
||||||
|
import { isEmbeddingModel, isRerankModel, isTextToImageModel } from '@renderer/config/models'
|
||||||
import { useApiModel } from '@renderer/hooks/agents/useModel'
|
import { useApiModel } from '@renderer/hooks/agents/useModel'
|
||||||
import { getProviderNameById } from '@renderer/services/ProviderService'
|
import { getProviderNameById } from '@renderer/services/ProviderService'
|
||||||
import { AgentBaseWithId, ApiModel, isAgentEntity } from '@renderer/types'
|
import { AgentBaseWithId, ApiModel, isAgentEntity, Model } from '@renderer/types'
|
||||||
import { getModelFilterByAgentType } from '@renderer/utils/agentSession'
|
import { getModelFilterByAgentType } from '@renderer/utils/agentSession'
|
||||||
import { apiModelAdapter } from '@renderer/utils/model'
|
import { apiModelAdapter } from '@renderer/utils/model'
|
||||||
import { ChevronsUpDown } from 'lucide-react'
|
import { ChevronsUpDown } from 'lucide-react'
|
||||||
@ -20,12 +21,13 @@ const SelectAgentModelButton: FC<Props> = ({ agent, onSelect, isDisabled }) => {
|
|||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const model = useApiModel({ id: agent?.model })
|
const model = useApiModel({ id: agent?.model })
|
||||||
|
|
||||||
const modelFilter = isAgentEntity(agent) ? getModelFilterByAgentType(agent.type) : undefined
|
const apiFilter = isAgentEntity(agent) ? getModelFilterByAgentType(agent.type) : undefined
|
||||||
|
const modelFilter = (model: Model) => !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
|
||||||
|
|
||||||
if (!agent) return null
|
if (!agent) return null
|
||||||
|
|
||||||
const onSelectModel = async () => {
|
const onSelectModel = async () => {
|
||||||
const selectedModel = await SelectApiModelPopup.show({ model, filter: modelFilter })
|
const selectedModel = await SelectApiModelPopup.show({ model, apiFilter: apiFilter, modelFilter })
|
||||||
if (selectedModel && selectedModel.id !== agent.model) {
|
if (selectedModel && selectedModel.id !== agent.model) {
|
||||||
onSelect(selectedModel)
|
onSelect(selectedModel)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user