feat(model-selection): add api model selection support for agents

- Introduce SelectAgentModelButton component for agent model selection
- Add SelectApiModelPopup for displaying and selecting API models
- Implement apiModelAdapter to convert API models to adapted format
- Add model filtering by agent type in agentSession utils
- Update model select components to use new API model selection
This commit is contained in:
icarus 2025-09-26 04:04:23 +08:00
parent a419aed404
commit de9cb2fbdb
14 changed files with 631 additions and 18 deletions

View File

@ -6,7 +6,7 @@ import { FC, MouseEvent } from 'react'
import styled from 'styled-components'
import IndicatorLight from './IndicatorLight'
import SelectModelPopup from './Popups/SelectModelPopup'
import { SelectModelPopup } from './Popups/SelectModelPopup'
import CustomTag from './Tags/CustomTag'
interface Props {

View File

@ -3,7 +3,7 @@ import { Button, Tooltip, TooltipProps } from 'antd'
import { useCallback, useMemo } from 'react'
import ModelAvatar from './Avatar/ModelAvatar'
import SelectModelPopup from './Popups/SelectModelPopup'
import { SelectModelPopup } from './Popups/SelectModelPopup'
type Props = {
model: Model

View File

@ -0,0 +1,503 @@
import { FreeTrialModelTag } from '@renderer/components/FreeTrialModelTag'
import { HStack } from '@renderer/components/Layout'
import ModelTagsWithLabel from '@renderer/components/ModelTagsWithLabel'
import { TopView } from '@renderer/components/TopView'
import { DynamicVirtualList, type DynamicVirtualListRef } from '@renderer/components/VirtualList'
import { getModelLogo } from '@renderer/config/models'
import { useApiModels } from '@renderer/hooks/agents/useModels'
import { getModelUniqId } from '@renderer/services/ModelService'
import { getProviderNameById } from '@renderer/services/ProviderService'
import { AdaptedApiModel, ApiModel, ApiModelsFilter, ModelType, objectEntries } from '@renderer/types'
import { classNames, filterModelsByKeywords } from '@renderer/utils'
import { apiModelAdapter, getModelTags } from '@renderer/utils/model'
import { Avatar, Divider, Empty, Modal } from 'antd'
import { first, groupBy, sortBy } from 'lodash'
import React, {
startTransition,
useCallback,
useDeferredValue,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState
} from 'react'
import styled from 'styled-components'
import { useModelTagFilter } from './filters'
import SelectModelSearchBar from './searchbar'
import TagFilterSection from './TagFilterSection'
import { FlatListApiItem, FlatListApiModel } from './types'
const PAGE_SIZE = 12
const ITEM_HEIGHT = 36
interface PopupParams {
model?: ApiModel
/** Api model filter */
filter?: ApiModelsFilter
/** Show tag filter section */
showTagFilter?: boolean
}
interface Props extends PopupParams {
resolve: (value: ApiModel | undefined) => void
}
export type FilterType = Exclude<ModelType, 'text'> | 'free'
// const logger = loggerService.withContext('SelectModelPopup')
const PopupContainer: React.FC<Props> = ({ model, filter: baseFilter, showTagFilter = true, resolve }) => {
const [open, setOpen] = useState(true)
const listRef = useRef<DynamicVirtualListRef>(null)
const [_searchText, setSearchText] = useState('')
const searchText = useDeferredValue(_searchText)
const { models, isLoading } = useApiModels(baseFilter)
const adaptedModels = models.map((model) => apiModelAdapter(model))
// 当前选中的模型ID
const currentModelId = model ? model.id : ''
// 管理滚动和焦点状态
const [focusedItemKey, _setFocusedItemKey] = useState('')
const [isMouseOver, setIsMouseOver] = useState(false)
const preventScrollToIndex = useRef(false)
const setFocusedItemKey = useCallback((key: string) => {
startTransition(() => {
_setFocusedItemKey(key)
})
}, [])
const { tagSelection, selectedTags, tagFilter, toggleTag } = useModelTagFilter()
// 计算要显示的可用标签列表
const availableTags = useMemo(() => {
return objectEntries(getModelTags(adaptedModels))
.filter(([, state]) => state)
.map(([tag]) => tag)
}, [adaptedModels])
// 根据输入的文本筛选模型
const searchFilter = useCallback(
(models: AdaptedApiModel[]) => {
if (searchText.trim()) {
models = filterModelsByKeywords(searchText, models)
}
return sortBy(models, ['group', 'name'])
},
[searchText]
)
// 创建模型列表项
const createModelItem = useCallback(
(model: AdaptedApiModel): FlatListApiModel => {
const modelId = getModelUniqId(model)
const isCherryAi = model.provider === 'cherryai'
return {
key: modelId,
type: 'model',
name: (
<ModelName>
<HStack alignItems="center">{model.name}</HStack>
{isCherryAi && <FreeTrialModelTag model={model} showLabel={false} />}
</ModelName>
),
tags: (
<TagsContainer>
<ModelTagsWithLabel model={model} size={11} showLabel={true} />
</TagsContainer>
),
icon: (
<Avatar src={getModelLogo(model.id || '')} size={24}>
{first(model.name) || 'M'}
</Avatar>
),
model,
isSelected: modelId === currentModelId
}
},
[currentModelId]
)
// 构建扁平化列表数据,并派生出可选择的模型项
const { listItems, modelItems } = useMemo(() => {
const items: FlatListApiItem[] = []
const finalModelFilter = (model: AdaptedApiModel) => {
const _tagFilter = !showTagFilter || tagFilter(model)
return _tagFilter
}
// 筛选模型
const filteredModels = searchFilter(adaptedModels).filter(finalModelFilter)
// 按 provider 分组
const groups = groupBy(filteredModels, (model) => model.provider) as Record<string, AdaptedApiModel[]>
objectEntries(groups).forEach(([key, models]) => {
items.push({
key: key ?? 'Unknown',
type: 'group',
name: getProviderNameById(key ?? 'Unknown'),
isSelected: false
})
items.push(...models.map((m) => createModelItem(m)))
})
// 获取可选择的模型项(过滤掉分组标题)
const modelItems = items.filter((item) => item.type === 'model')
return { listItems: items, modelItems }
}, [searchFilter, adaptedModels, showTagFilter, tagFilter, createModelItem])
const listHeight = useMemo(() => {
return Math.min(PAGE_SIZE, listItems.length) * ITEM_HEIGHT
}, [listItems.length])
// 处理程序化滚动加载、搜索开始、搜索清空、tag 筛选)
useLayoutEffect(() => {
if (isLoading) return
if (preventScrollToIndex.current) {
preventScrollToIndex.current = false
return
}
let targetItemKey: string | undefined
// 启动搜索或 tag 筛选时,滚动到第一个 item
if (searchText || selectedTags.length > 0) {
targetItemKey = modelItems[0]?.key
}
// 初始加载或清空搜索时,滚动到 selected item
else {
targetItemKey = modelItems.find((item) => item.isSelected)?.key
}
if (targetItemKey) {
setFocusedItemKey(targetItemKey)
const index = listItems.findIndex((item) => item.key === targetItemKey)
if (index >= 0) {
// FIXME: 手动计算偏移量,给 scroller 增加了 scrollPaddingStart 之后,
// scrollToIndex 不能准确滚动到 item 中心,但是又需要 padding 来改善体验。
const targetScrollTop = index * ITEM_HEIGHT - listHeight / 2
listRef.current?.scrollToOffset(targetScrollTop, {
align: 'start',
behavior: 'auto'
})
}
}
}, [searchText, listItems, modelItems, setFocusedItemKey, listHeight, selectedTags.length, isLoading])
const handleItemClick = useCallback(
(item: FlatListApiItem) => {
if (item.type === 'model') {
resolve(item.model.origin)
setOpen(false)
}
},
[resolve]
)
// 处理键盘导航
const handleKeyDown = useCallback(
(e: KeyboardEvent) => {
const modelCount = modelItems.length
if (!open || modelCount === 0 || e.isComposing) return
// 键盘操作时禁用鼠标 hover
if (['ArrowUp', 'ArrowDown', 'PageUp', 'PageDown', 'Enter', 'Escape'].includes(e.key)) {
e.preventDefault()
e.stopPropagation()
setIsMouseOver(false)
}
// 当前聚焦的模型 index
const currentIndex = modelItems.findIndex((item) => item.key === focusedItemKey)
let nextIndex = -1
switch (e.key) {
case 'ArrowUp': {
nextIndex = (currentIndex < 0 ? 0 : currentIndex - 1 + modelCount) % modelCount
break
}
case 'ArrowDown': {
nextIndex = (currentIndex < 0 ? 0 : currentIndex + 1) % modelCount
break
}
case 'PageUp': {
nextIndex = Math.max(0, (currentIndex < 0 ? 0 : currentIndex) - PAGE_SIZE)
break
}
case 'PageDown': {
nextIndex = Math.min(modelCount - 1, (currentIndex < 0 ? 0 : currentIndex) + PAGE_SIZE)
break
}
case 'Enter':
if (currentIndex >= 0) {
const selectedItem = modelItems[currentIndex]
if (selectedItem) {
handleItemClick(selectedItem)
}
}
break
case 'Escape':
e.preventDefault()
e.stopPropagation()
setOpen(false)
resolve(undefined)
break
}
// 没有键盘导航,直接返回
if (nextIndex < 0) return
const nextKey = modelItems[nextIndex]?.key || ''
if (nextKey) {
setFocusedItemKey(nextKey)
const index = listItems.findIndex((item) => item.key === nextKey)
if (index >= 0) {
listRef.current?.scrollToIndex(index, { align: 'auto' })
}
}
},
[modelItems, open, focusedItemKey, resolve, handleItemClick, setFocusedItemKey, listItems]
)
useEffect(() => {
window.addEventListener('keydown', handleKeyDown)
return () => window.removeEventListener('keydown', handleKeyDown)
}, [handleKeyDown])
const onCancel = useCallback(() => {
setOpen(false)
}, [])
const onAfterClose = useCallback(async () => {
resolve(undefined)
SelectApiModelPopup.hide()
}, [resolve])
const getItemKey = useCallback((index: number) => listItems[index].key, [listItems])
const estimateSize = useCallback(() => ITEM_HEIGHT, [])
const isSticky = useCallback((index: number) => listItems[index].type === 'group', [listItems])
const rowRenderer = useCallback(
(item: FlatListApiItem) => {
const isFocused = item.key === focusedItemKey
if (item.type === 'group') {
return (
<GroupItem>
{item.name}
{item.actions}
</GroupItem>
)
}
return (
<ModelItem
className={classNames({
focused: isFocused,
selected: item.isSelected
})}
onClick={() => handleItemClick(item)}
onMouseOver={() => !isFocused && setFocusedItemKey(item.key)}>
<ModelItemLeft>
{item.icon}
{item.name}
{item.tags}
</ModelItemLeft>
</ModelItem>
)
},
[focusedItemKey, handleItemClick, setFocusedItemKey]
)
return (
<Modal
centered
open={open}
onCancel={onCancel}
afterClose={onAfterClose}
width={600}
transitionName="animation-move-down"
styles={{
content: {
borderRadius: 20,
padding: 0,
overflow: 'hidden',
paddingBottom: 16
},
body: {
maxHeight: 'inherit',
padding: 0
}
}}
closeIcon={null}
footer={null}>
{/* 搜索框 */}
<SelectModelSearchBar onSearch={setSearchText} />
<Divider style={{ margin: 0, marginTop: 4, borderBlockStartWidth: 0.5 }} />
{showTagFilter && (
<>
<TagFilterSection availableTags={availableTags} tagSelection={tagSelection} onToggleTag={toggleTag} />
<Divider style={{ margin: 0, borderBlockStartWidth: 0.5 }} />
</>
)}
{listItems.length > 0 ? (
<ListContainer onMouseMove={() => !isMouseOver && setIsMouseOver(true)}>
<DynamicVirtualList
ref={listRef}
list={listItems}
size={listHeight}
getItemKey={getItemKey}
estimateSize={estimateSize}
isSticky={isSticky}
scrollPaddingStart={ITEM_HEIGHT} // 留出 sticky header 高度
overscan={5}
scrollerStyle={{ pointerEvents: isMouseOver ? 'auto' : 'none' }}>
{rowRenderer}
</DynamicVirtualList>
</ListContainer>
) : (
<EmptyState>
<Empty image={Empty.PRESENTED_IMAGE_SIMPLE} />
</EmptyState>
)}
</Modal>
)
}
const ListContainer = styled.div`
position: relative;
overflow: hidden;
`
const GroupItem = styled.div`
display: flex;
align-items: center;
gap: 8px;
position: relative;
line-height: 1;
font-size: 12px;
font-weight: normal;
height: ${ITEM_HEIGHT}px;
padding: 5px 18px;
color: var(--color-text-3);
z-index: 1;
background: var(--modal-background);
.action-icon {
cursor: pointer;
opacity: 0;
transition: opacity 0.2s;
&:hover {
opacity: 1 !important;
}
}
&:hover .action-icon {
opacity: 0.3;
}
`
const ModelItem = styled.div`
display: flex;
align-items: center;
justify-content: space-between;
position: relative;
font-size: 14px;
padding: 0 8px;
margin: 1px 8px;
height: ${ITEM_HEIGHT - 2}px;
border-radius: 8px;
cursor: pointer;
transition: background-color 0.1s ease;
&.focused {
background-color: var(--color-background-mute);
}
&.selected {
&::before {
content: '';
display: block;
position: absolute;
left: -1px;
top: 13%;
width: 3px;
height: 74%;
background: var(--color-primary-soft);
border-radius: 8px;
}
}
.pin-icon {
opacity: 0;
}
&:hover .pin-icon {
opacity: 0.3;
}
`
const ModelItemLeft = styled.div`
display: flex;
align-items: center;
width: 100%;
overflow: hidden;
padding-right: 26px;
.anticon {
min-width: auto;
flex-shrink: 0;
}
`
const ModelName = styled.div`
display: flex;
flex-direction: row;
justify-content: space-between;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
flex: 1;
margin: 0 8px;
min-width: 0;
gap: 5px;
`
const TagsContainer = styled.div`
display: flex;
justify-content: flex-end;
min-width: 80px;
max-width: 180px;
overflow: hidden;
flex-shrink: 0;
`
const EmptyState = styled.div`
display: flex;
justify-content: center;
align-items: center;
height: 200px;
`
const TopViewKey = 'SelectModelPopup'
export class SelectApiModelPopup {
static topviewId = 0
static hide() {
TopView.hide(TopViewKey)
}
static show(params: PopupParams) {
return new Promise<ApiModel | undefined>((resolve) => {
TopView.show(<PopupContainer {...params} resolve={(v) => resolve(v)} />, TopViewKey)
})
}
}

View File

@ -1,3 +1,2 @@
import { SelectModelPopup } from './popup'
export default SelectModelPopup
export { SelectApiModelPopup } from './api-model-popup'
export { SelectModelPopup } from './popup'

View File

@ -1,4 +1,4 @@
import { Model } from '@renderer/types'
import { AdaptedApiModel, Model } from '@renderer/types'
import { ReactNode } from 'react'
/**
@ -44,3 +44,17 @@ export type FlatListModel = FlatListBaseItem & {
*
*/
export type FlatListItem = FlatListGroup | FlatListModel
/**
*
*/
export type FlatListApiModel = FlatListBaseItem & {
type: 'model'
model: AdaptedApiModel
tags?: ReactNode
}
/**
*
*/
export type FlatListApiItem = FlatListGroup | FlatListApiModel

View File

@ -1,4 +1,3 @@
import { ApiModelLabel } from '@renderer/components/ApiModelLabel'
import { NavbarHeader } from '@renderer/components/app/Navbar'
import { HStack } from '@renderer/components/Layout'
import SearchPopup from '@renderer/components/Popups/SearchPopup'
@ -21,6 +20,7 @@ import { FC } from 'react'
import styled from 'styled-components'
import AssistantsDrawer from './components/AssistantsDrawer'
import SelectAgentModelButton from './components/SelectAgentModelButton'
import SelectModelButton from './components/SelectModelButton'
import UpdateAppButton from './components/UpdateAppButton'
@ -74,7 +74,7 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
return (
<NavbarHeader className="home-navbar">
<HStack alignItems="center">
<div className="flex flex-1 items-center">
{showAssistants && (
<Tooltip title={t('navbar.hide_sidebar')} mouseEnterDelay={0.8}>
<NavbarIcon onClick={toggleShowAssistants}>
@ -103,12 +103,10 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
)}
</AnimatePresence>
{activeTopicOrSession === 'topic' && <SelectModelButton assistant={assistant} />}
{/* TODO: Show a select model button for agent. */}
{/* FIXME: models endpoint doesn't return all models, so cannot found. */}
{activeTopicOrSession === 'session' && (
<ApiModelLabel classNames={{ container: 'text-xs' }} model={agentModel} />
{activeTopicOrSession === 'session' && agent && agentModel && (
<SelectAgentModelButton agent={agent} model={agentModel} />
)}
</HStack>
</div>
<HStack alignItems="center" gap={8}>
<UpdateAppButton />
<Tooltip title={t('navbar.expand')} mouseEnterDelay={0.8}>

View File

@ -3,7 +3,7 @@ import { loggerService } from '@logger'
import { CopyIcon, DeleteIcon, EditIcon, RefreshIcon } from '@renderer/components/Icons'
import ObsidianExportPopup from '@renderer/components/Popups/ObsidianExportPopup'
import SaveToKnowledgePopup from '@renderer/components/Popups/SaveToKnowledgePopup'
import SelectModelPopup from '@renderer/components/Popups/SelectModelPopup'
import { SelectModelPopup } from '@renderer/components/Popups/SelectModelPopup'
import { isEmbeddingModel, isRerankModel, isVisionModel } from '@renderer/config/models'
import {
DEFAULT_MESSAGE_MENUBAR_SCOPE,

View File

@ -0,0 +1,70 @@
import { Button } from '@heroui/react'
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup'
import { useUpdateAgent } from '@renderer/hooks/agents/useUpdateAgent'
import { AgentEntity, ApiModel } from '@renderer/types'
import { getModelFilterByAgentType } from '@renderer/utils/agentSession'
import { apiModelAdapter } from '@renderer/utils/model'
import { ChevronsUpDown } from 'lucide-react'
import { FC } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
interface Props {
agent: AgentEntity
model: ApiModel
}
const SelectAgentModelButton: FC<Props> = ({ agent, model }) => {
const { t } = useTranslation()
const update = useUpdateAgent()
const modelFilter = getModelFilterByAgentType(agent.type)
if (!agent) return null
const onSelectModel = async () => {
const selectedModel = await SelectApiModelPopup.show({ model, filter: modelFilter })
if (selectedModel) {
update({ id: agent.id, model: selectedModel.id })
}
}
const providerName = model.provider_name
return (
<DropdownButton size="sm" onPress={onSelectModel}>
<ButtonContent>
<ModelAvatar model={apiModelAdapter(model)} size={20} />
<ModelName>
{model ? model.name : t('button.select_model')} {providerName ? ' | ' + providerName : ''}
</ModelName>
</ButtonContent>
<ChevronsUpDown size={14} color="var(--color-icon)" />
</DropdownButton>
)
}
const DropdownButton = styled(Button)`
font-size: 11px;
border-radius: 15px;
padding: 13px 5px;
-webkit-app-region: none;
box-shadow: none;
background-color: transparent;
border: 1px solid transparent;
margin-top: 1px;
`
const ButtonContent = styled.div`
display: flex;
align-items: center;
gap: 6px;
`
const ModelName = styled.span`
font-weight: 500;
margin-right: -2px;
`
export default SelectAgentModelButton

View File

@ -1,5 +1,5 @@
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
import SelectModelPopup from '@renderer/components/Popups/SelectModelPopup'
import { SelectModelPopup } from '@renderer/components/Popups/SelectModelPopup'
import { isLocalAi } from '@renderer/config/env'
import { isEmbeddingModel, isRerankModel, isWebSearchModel } from '@renderer/config/models'
import { useAssistant } from '@renderer/hooks/useAssistant'

View File

@ -3,7 +3,7 @@ import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
import EditableNumber from '@renderer/components/EditableNumber'
import { DeleteIcon, ResetIcon } from '@renderer/components/Icons'
import { HStack } from '@renderer/components/Layout'
import SelectModelPopup from '@renderer/components/Popups/SelectModelPopup'
import { SelectModelPopup } from '@renderer/components/Popups/SelectModelPopup'
import Selector from '@renderer/components/Selector'
import { DEFAULT_CONTEXTCOUNT, DEFAULT_TEMPERATURE, MAX_CONTEXT_COUNT } from '@renderer/config/constant'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'

View File

@ -1,3 +1,4 @@
import { Model } from '@types'
import { z } from 'zod'
import { ProviderTypeSchema } from './provider'
@ -35,3 +36,8 @@ export const ApiModelsResponseSchema = z.object({
export type ApiModel = z.infer<typeof ApiModelSchema>
export type ApiModelsFilter = z.infer<typeof ApiModelsFilterSchema>
export type ApiModelsResponse = z.infer<typeof ApiModelsResponseSchema>
// Adapted
export type AdaptedApiModel = Model & {
origin: ApiModel
}

View File

@ -1,3 +1,5 @@
import { AgentType, ApiModelsFilter } from '@renderer/types'
const SESSION_TOPIC_PREFIX = 'agent-session:'
export const buildAgentSessionTopicId = (sessionId: string): string => {
@ -11,3 +13,14 @@ export const isAgentSessionTopicId = (topicId: string): boolean => {
export const extractAgentSessionIdFromTopicId = (topicId: string): string => {
return topicId.replace(SESSION_TOPIC_PREFIX, '')
}
export const getModelFilterByAgentType = (type: AgentType): ApiModelsFilter => {
switch (type) {
case 'claude-code':
return {
providerType: 'anthropic'
}
default:
return {}
}
}

View File

@ -74,7 +74,7 @@ function getProviderSearchString(provider: Provider) {
* @param provider Provider
* @returns
*/
export function filterModelsByKeywords(keywords: string, models: Model[], provider?: Provider): Model[] {
export function filterModelsByKeywords<T extends Model>(keywords: string, models: T[], provider?: Provider): T[] {
const keywordsArray = keywords.toLowerCase().split(/\s+/).filter(Boolean)
return models.filter((model) => matchKeywordsInModel(keywordsArray, model, provider))
}

View File

@ -6,7 +6,7 @@ import {
isVisionModel,
isWebSearchModel
} from '@renderer/config/models'
import { Model, ModelTag, objectKeys } from '@renderer/types'
import { AdaptedApiModel, ApiModel, Model, ModelTag, objectKeys } from '@renderer/types'
/**
*
@ -70,3 +70,13 @@ export function isFreeModel(model: Model) {
return (model.id + model.name).toLocaleLowerCase().includes('free')
}
export const apiModelAdapter = (model: ApiModel): AdaptedApiModel => {
return {
id: model.provider_model_id ?? model.id,
provider: model.provider ?? '',
name: model.name,
group: '',
origin: model
}
}