mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-24 18:50:56 +08:00
feat(model-selection): add api model selection support for agents
- Introduce SelectAgentModelButton component for agent model selection - Add SelectApiModelPopup for displaying and selecting API models - Implement apiModelAdapter to convert API models to adapted format - Add model filtering by agent type in agentSession utils - Update model select components to use new API model selection
This commit is contained in:
parent
a419aed404
commit
de9cb2fbdb
@ -6,7 +6,7 @@ import { FC, MouseEvent } from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import IndicatorLight from './IndicatorLight'
|
||||
import SelectModelPopup from './Popups/SelectModelPopup'
|
||||
import { SelectModelPopup } from './Popups/SelectModelPopup'
|
||||
import CustomTag from './Tags/CustomTag'
|
||||
|
||||
interface Props {
|
||||
|
||||
@ -3,7 +3,7 @@ import { Button, Tooltip, TooltipProps } from 'antd'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
|
||||
import ModelAvatar from './Avatar/ModelAvatar'
|
||||
import SelectModelPopup from './Popups/SelectModelPopup'
|
||||
import { SelectModelPopup } from './Popups/SelectModelPopup'
|
||||
|
||||
type Props = {
|
||||
model: Model
|
||||
|
||||
@ -0,0 +1,503 @@
|
||||
import { FreeTrialModelTag } from '@renderer/components/FreeTrialModelTag'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import ModelTagsWithLabel from '@renderer/components/ModelTagsWithLabel'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DynamicVirtualList, type DynamicVirtualListRef } from '@renderer/components/VirtualList'
|
||||
import { getModelLogo } from '@renderer/config/models'
|
||||
import { useApiModels } from '@renderer/hooks/agents/useModels'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { getProviderNameById } from '@renderer/services/ProviderService'
|
||||
import { AdaptedApiModel, ApiModel, ApiModelsFilter, ModelType, objectEntries } from '@renderer/types'
|
||||
import { classNames, filterModelsByKeywords } from '@renderer/utils'
|
||||
import { apiModelAdapter, getModelTags } from '@renderer/utils/model'
|
||||
import { Avatar, Divider, Empty, Modal } from 'antd'
|
||||
import { first, groupBy, sortBy } from 'lodash'
|
||||
import React, {
|
||||
startTransition,
|
||||
useCallback,
|
||||
useDeferredValue,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState
|
||||
} from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import { useModelTagFilter } from './filters'
|
||||
import SelectModelSearchBar from './searchbar'
|
||||
import TagFilterSection from './TagFilterSection'
|
||||
import { FlatListApiItem, FlatListApiModel } from './types'
|
||||
|
||||
const PAGE_SIZE = 12
|
||||
const ITEM_HEIGHT = 36
|
||||
|
||||
interface PopupParams {
|
||||
model?: ApiModel
|
||||
/** Api model filter */
|
||||
filter?: ApiModelsFilter
|
||||
/** Show tag filter section */
|
||||
showTagFilter?: boolean
|
||||
}
|
||||
|
||||
interface Props extends PopupParams {
|
||||
resolve: (value: ApiModel | undefined) => void
|
||||
}
|
||||
|
||||
export type FilterType = Exclude<ModelType, 'text'> | 'free'
|
||||
|
||||
// const logger = loggerService.withContext('SelectModelPopup')
|
||||
|
||||
const PopupContainer: React.FC<Props> = ({ model, filter: baseFilter, showTagFilter = true, resolve }) => {
|
||||
const [open, setOpen] = useState(true)
|
||||
const listRef = useRef<DynamicVirtualListRef>(null)
|
||||
const [_searchText, setSearchText] = useState('')
|
||||
const searchText = useDeferredValue(_searchText)
|
||||
const { models, isLoading } = useApiModels(baseFilter)
|
||||
const adaptedModels = models.map((model) => apiModelAdapter(model))
|
||||
|
||||
// 当前选中的模型ID
|
||||
const currentModelId = model ? model.id : ''
|
||||
|
||||
// 管理滚动和焦点状态
|
||||
const [focusedItemKey, _setFocusedItemKey] = useState('')
|
||||
const [isMouseOver, setIsMouseOver] = useState(false)
|
||||
const preventScrollToIndex = useRef(false)
|
||||
|
||||
const setFocusedItemKey = useCallback((key: string) => {
|
||||
startTransition(() => {
|
||||
_setFocusedItemKey(key)
|
||||
})
|
||||
}, [])
|
||||
|
||||
const { tagSelection, selectedTags, tagFilter, toggleTag } = useModelTagFilter()
|
||||
|
||||
// 计算要显示的可用标签列表
|
||||
const availableTags = useMemo(() => {
|
||||
return objectEntries(getModelTags(adaptedModels))
|
||||
.filter(([, state]) => state)
|
||||
.map(([tag]) => tag)
|
||||
}, [adaptedModels])
|
||||
|
||||
// 根据输入的文本筛选模型
|
||||
const searchFilter = useCallback(
|
||||
(models: AdaptedApiModel[]) => {
|
||||
if (searchText.trim()) {
|
||||
models = filterModelsByKeywords(searchText, models)
|
||||
}
|
||||
|
||||
return sortBy(models, ['group', 'name'])
|
||||
},
|
||||
[searchText]
|
||||
)
|
||||
|
||||
// 创建模型列表项
|
||||
const createModelItem = useCallback(
|
||||
(model: AdaptedApiModel): FlatListApiModel => {
|
||||
const modelId = getModelUniqId(model)
|
||||
const isCherryAi = model.provider === 'cherryai'
|
||||
|
||||
return {
|
||||
key: modelId,
|
||||
type: 'model',
|
||||
name: (
|
||||
<ModelName>
|
||||
<HStack alignItems="center">{model.name}</HStack>
|
||||
{isCherryAi && <FreeTrialModelTag model={model} showLabel={false} />}
|
||||
</ModelName>
|
||||
),
|
||||
tags: (
|
||||
<TagsContainer>
|
||||
<ModelTagsWithLabel model={model} size={11} showLabel={true} />
|
||||
</TagsContainer>
|
||||
),
|
||||
icon: (
|
||||
<Avatar src={getModelLogo(model.id || '')} size={24}>
|
||||
{first(model.name) || 'M'}
|
||||
</Avatar>
|
||||
),
|
||||
model,
|
||||
isSelected: modelId === currentModelId
|
||||
}
|
||||
},
|
||||
[currentModelId]
|
||||
)
|
||||
|
||||
// 构建扁平化列表数据,并派生出可选择的模型项
|
||||
const { listItems, modelItems } = useMemo(() => {
|
||||
const items: FlatListApiItem[] = []
|
||||
const finalModelFilter = (model: AdaptedApiModel) => {
|
||||
const _tagFilter = !showTagFilter || tagFilter(model)
|
||||
return _tagFilter
|
||||
}
|
||||
|
||||
// 筛选模型
|
||||
const filteredModels = searchFilter(adaptedModels).filter(finalModelFilter)
|
||||
|
||||
// 按 provider 分组
|
||||
const groups = groupBy(filteredModels, (model) => model.provider) as Record<string, AdaptedApiModel[]>
|
||||
|
||||
objectEntries(groups).forEach(([key, models]) => {
|
||||
items.push({
|
||||
key: key ?? 'Unknown',
|
||||
type: 'group',
|
||||
name: getProviderNameById(key ?? 'Unknown'),
|
||||
isSelected: false
|
||||
})
|
||||
items.push(...models.map((m) => createModelItem(m)))
|
||||
})
|
||||
|
||||
// 获取可选择的模型项(过滤掉分组标题)
|
||||
const modelItems = items.filter((item) => item.type === 'model')
|
||||
return { listItems: items, modelItems }
|
||||
}, [searchFilter, adaptedModels, showTagFilter, tagFilter, createModelItem])
|
||||
|
||||
const listHeight = useMemo(() => {
|
||||
return Math.min(PAGE_SIZE, listItems.length) * ITEM_HEIGHT
|
||||
}, [listItems.length])
|
||||
|
||||
// 处理程序化滚动(加载、搜索开始、搜索清空、tag 筛选)
|
||||
useLayoutEffect(() => {
|
||||
if (isLoading) return
|
||||
|
||||
if (preventScrollToIndex.current) {
|
||||
preventScrollToIndex.current = false
|
||||
return
|
||||
}
|
||||
|
||||
let targetItemKey: string | undefined
|
||||
|
||||
// 启动搜索或 tag 筛选时,滚动到第一个 item
|
||||
if (searchText || selectedTags.length > 0) {
|
||||
targetItemKey = modelItems[0]?.key
|
||||
}
|
||||
// 初始加载或清空搜索时,滚动到 selected item
|
||||
else {
|
||||
targetItemKey = modelItems.find((item) => item.isSelected)?.key
|
||||
}
|
||||
|
||||
if (targetItemKey) {
|
||||
setFocusedItemKey(targetItemKey)
|
||||
const index = listItems.findIndex((item) => item.key === targetItemKey)
|
||||
if (index >= 0) {
|
||||
// FIXME: 手动计算偏移量,给 scroller 增加了 scrollPaddingStart 之后,
|
||||
// scrollToIndex 不能准确滚动到 item 中心,但是又需要 padding 来改善体验。
|
||||
const targetScrollTop = index * ITEM_HEIGHT - listHeight / 2
|
||||
listRef.current?.scrollToOffset(targetScrollTop, {
|
||||
align: 'start',
|
||||
behavior: 'auto'
|
||||
})
|
||||
}
|
||||
}
|
||||
}, [searchText, listItems, modelItems, setFocusedItemKey, listHeight, selectedTags.length, isLoading])
|
||||
|
||||
const handleItemClick = useCallback(
|
||||
(item: FlatListApiItem) => {
|
||||
if (item.type === 'model') {
|
||||
resolve(item.model.origin)
|
||||
setOpen(false)
|
||||
}
|
||||
},
|
||||
[resolve]
|
||||
)
|
||||
|
||||
// 处理键盘导航
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent) => {
|
||||
const modelCount = modelItems.length
|
||||
|
||||
if (!open || modelCount === 0 || e.isComposing) return
|
||||
|
||||
// 键盘操作时禁用鼠标 hover
|
||||
if (['ArrowUp', 'ArrowDown', 'PageUp', 'PageDown', 'Enter', 'Escape'].includes(e.key)) {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
setIsMouseOver(false)
|
||||
}
|
||||
|
||||
// 当前聚焦的模型 index
|
||||
const currentIndex = modelItems.findIndex((item) => item.key === focusedItemKey)
|
||||
|
||||
let nextIndex = -1
|
||||
|
||||
switch (e.key) {
|
||||
case 'ArrowUp': {
|
||||
nextIndex = (currentIndex < 0 ? 0 : currentIndex - 1 + modelCount) % modelCount
|
||||
break
|
||||
}
|
||||
case 'ArrowDown': {
|
||||
nextIndex = (currentIndex < 0 ? 0 : currentIndex + 1) % modelCount
|
||||
break
|
||||
}
|
||||
case 'PageUp': {
|
||||
nextIndex = Math.max(0, (currentIndex < 0 ? 0 : currentIndex) - PAGE_SIZE)
|
||||
break
|
||||
}
|
||||
case 'PageDown': {
|
||||
nextIndex = Math.min(modelCount - 1, (currentIndex < 0 ? 0 : currentIndex) + PAGE_SIZE)
|
||||
break
|
||||
}
|
||||
case 'Enter':
|
||||
if (currentIndex >= 0) {
|
||||
const selectedItem = modelItems[currentIndex]
|
||||
if (selectedItem) {
|
||||
handleItemClick(selectedItem)
|
||||
}
|
||||
}
|
||||
break
|
||||
case 'Escape':
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
setOpen(false)
|
||||
resolve(undefined)
|
||||
break
|
||||
}
|
||||
|
||||
// 没有键盘导航,直接返回
|
||||
if (nextIndex < 0) return
|
||||
|
||||
const nextKey = modelItems[nextIndex]?.key || ''
|
||||
if (nextKey) {
|
||||
setFocusedItemKey(nextKey)
|
||||
const index = listItems.findIndex((item) => item.key === nextKey)
|
||||
if (index >= 0) {
|
||||
listRef.current?.scrollToIndex(index, { align: 'auto' })
|
||||
}
|
||||
}
|
||||
},
|
||||
[modelItems, open, focusedItemKey, resolve, handleItemClick, setFocusedItemKey, listItems]
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
window.addEventListener('keydown', handleKeyDown)
|
||||
return () => window.removeEventListener('keydown', handleKeyDown)
|
||||
}, [handleKeyDown])
|
||||
|
||||
const onCancel = useCallback(() => {
|
||||
setOpen(false)
|
||||
}, [])
|
||||
|
||||
const onAfterClose = useCallback(async () => {
|
||||
resolve(undefined)
|
||||
SelectApiModelPopup.hide()
|
||||
}, [resolve])
|
||||
|
||||
const getItemKey = useCallback((index: number) => listItems[index].key, [listItems])
|
||||
const estimateSize = useCallback(() => ITEM_HEIGHT, [])
|
||||
const isSticky = useCallback((index: number) => listItems[index].type === 'group', [listItems])
|
||||
|
||||
const rowRenderer = useCallback(
|
||||
(item: FlatListApiItem) => {
|
||||
const isFocused = item.key === focusedItemKey
|
||||
if (item.type === 'group') {
|
||||
return (
|
||||
<GroupItem>
|
||||
{item.name}
|
||||
{item.actions}
|
||||
</GroupItem>
|
||||
)
|
||||
}
|
||||
return (
|
||||
<ModelItem
|
||||
className={classNames({
|
||||
focused: isFocused,
|
||||
selected: item.isSelected
|
||||
})}
|
||||
onClick={() => handleItemClick(item)}
|
||||
onMouseOver={() => !isFocused && setFocusedItemKey(item.key)}>
|
||||
<ModelItemLeft>
|
||||
{item.icon}
|
||||
{item.name}
|
||||
{item.tags}
|
||||
</ModelItemLeft>
|
||||
</ModelItem>
|
||||
)
|
||||
},
|
||||
[focusedItemKey, handleItemClick, setFocusedItemKey]
|
||||
)
|
||||
|
||||
return (
|
||||
<Modal
|
||||
centered
|
||||
open={open}
|
||||
onCancel={onCancel}
|
||||
afterClose={onAfterClose}
|
||||
width={600}
|
||||
transitionName="animation-move-down"
|
||||
styles={{
|
||||
content: {
|
||||
borderRadius: 20,
|
||||
padding: 0,
|
||||
overflow: 'hidden',
|
||||
paddingBottom: 16
|
||||
},
|
||||
body: {
|
||||
maxHeight: 'inherit',
|
||||
padding: 0
|
||||
}
|
||||
}}
|
||||
closeIcon={null}
|
||||
footer={null}>
|
||||
{/* 搜索框 */}
|
||||
<SelectModelSearchBar onSearch={setSearchText} />
|
||||
<Divider style={{ margin: 0, marginTop: 4, borderBlockStartWidth: 0.5 }} />
|
||||
{showTagFilter && (
|
||||
<>
|
||||
<TagFilterSection availableTags={availableTags} tagSelection={tagSelection} onToggleTag={toggleTag} />
|
||||
<Divider style={{ margin: 0, borderBlockStartWidth: 0.5 }} />
|
||||
</>
|
||||
)}
|
||||
|
||||
{listItems.length > 0 ? (
|
||||
<ListContainer onMouseMove={() => !isMouseOver && setIsMouseOver(true)}>
|
||||
<DynamicVirtualList
|
||||
ref={listRef}
|
||||
list={listItems}
|
||||
size={listHeight}
|
||||
getItemKey={getItemKey}
|
||||
estimateSize={estimateSize}
|
||||
isSticky={isSticky}
|
||||
scrollPaddingStart={ITEM_HEIGHT} // 留出 sticky header 高度
|
||||
overscan={5}
|
||||
scrollerStyle={{ pointerEvents: isMouseOver ? 'auto' : 'none' }}>
|
||||
{rowRenderer}
|
||||
</DynamicVirtualList>
|
||||
</ListContainer>
|
||||
) : (
|
||||
<EmptyState>
|
||||
<Empty image={Empty.PRESENTED_IMAGE_SIMPLE} />
|
||||
</EmptyState>
|
||||
)}
|
||||
</Modal>
|
||||
)
|
||||
}
|
||||
|
||||
const ListContainer = styled.div`
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
`
|
||||
|
||||
const GroupItem = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
position: relative;
|
||||
line-height: 1;
|
||||
font-size: 12px;
|
||||
font-weight: normal;
|
||||
height: ${ITEM_HEIGHT}px;
|
||||
padding: 5px 18px;
|
||||
color: var(--color-text-3);
|
||||
z-index: 1;
|
||||
background: var(--modal-background);
|
||||
|
||||
.action-icon {
|
||||
cursor: pointer;
|
||||
opacity: 0;
|
||||
transition: opacity 0.2s;
|
||||
|
||||
&:hover {
|
||||
opacity: 1 !important;
|
||||
}
|
||||
}
|
||||
&:hover .action-icon {
|
||||
opacity: 0.3;
|
||||
}
|
||||
`
|
||||
|
||||
const ModelItem = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
position: relative;
|
||||
font-size: 14px;
|
||||
padding: 0 8px;
|
||||
margin: 1px 8px;
|
||||
height: ${ITEM_HEIGHT - 2}px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.1s ease;
|
||||
|
||||
&.focused {
|
||||
background-color: var(--color-background-mute);
|
||||
}
|
||||
|
||||
&.selected {
|
||||
&::before {
|
||||
content: '';
|
||||
display: block;
|
||||
position: absolute;
|
||||
left: -1px;
|
||||
top: 13%;
|
||||
width: 3px;
|
||||
height: 74%;
|
||||
background: var(--color-primary-soft);
|
||||
border-radius: 8px;
|
||||
}
|
||||
}
|
||||
|
||||
.pin-icon {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
&:hover .pin-icon {
|
||||
opacity: 0.3;
|
||||
}
|
||||
`
|
||||
|
||||
const ModelItemLeft = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
overflow: hidden;
|
||||
padding-right: 26px;
|
||||
|
||||
.anticon {
|
||||
min-width: auto;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
`
|
||||
|
||||
const ModelName = styled.div`
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
justify-content: space-between;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
flex: 1;
|
||||
margin: 0 8px;
|
||||
min-width: 0;
|
||||
gap: 5px;
|
||||
`
|
||||
|
||||
const TagsContainer = styled.div`
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
min-width: 80px;
|
||||
max-width: 180px;
|
||||
overflow: hidden;
|
||||
flex-shrink: 0;
|
||||
`
|
||||
|
||||
const EmptyState = styled.div`
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 200px;
|
||||
`
|
||||
|
||||
const TopViewKey = 'SelectModelPopup'
|
||||
|
||||
export class SelectApiModelPopup {
|
||||
static topviewId = 0
|
||||
static hide() {
|
||||
TopView.hide(TopViewKey)
|
||||
}
|
||||
|
||||
static show(params: PopupParams) {
|
||||
return new Promise<ApiModel | undefined>((resolve) => {
|
||||
TopView.show(<PopupContainer {...params} resolve={(v) => resolve(v)} />, TopViewKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,2 @@
|
||||
import { SelectModelPopup } from './popup'
|
||||
|
||||
export default SelectModelPopup
|
||||
export { SelectApiModelPopup } from './api-model-popup'
|
||||
export { SelectModelPopup } from './popup'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { Model } from '@renderer/types'
|
||||
import { AdaptedApiModel, Model } from '@renderer/types'
|
||||
import { ReactNode } from 'react'
|
||||
|
||||
/**
|
||||
@ -44,3 +44,17 @@ export type FlatListModel = FlatListBaseItem & {
|
||||
* 扁平化列表项
|
||||
*/
|
||||
export type FlatListItem = FlatListGroup | FlatListModel
|
||||
|
||||
/**
|
||||
* 模型列表项
|
||||
*/
|
||||
export type FlatListApiModel = FlatListBaseItem & {
|
||||
type: 'model'
|
||||
model: AdaptedApiModel
|
||||
tags?: ReactNode
|
||||
}
|
||||
|
||||
/**
|
||||
* 扁平化列表项
|
||||
*/
|
||||
export type FlatListApiItem = FlatListGroup | FlatListApiModel
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import { ApiModelLabel } from '@renderer/components/ApiModelLabel'
|
||||
import { NavbarHeader } from '@renderer/components/app/Navbar'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import SearchPopup from '@renderer/components/Popups/SearchPopup'
|
||||
@ -21,6 +20,7 @@ import { FC } from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import AssistantsDrawer from './components/AssistantsDrawer'
|
||||
import SelectAgentModelButton from './components/SelectAgentModelButton'
|
||||
import SelectModelButton from './components/SelectModelButton'
|
||||
import UpdateAppButton from './components/UpdateAppButton'
|
||||
|
||||
@ -74,7 +74,7 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
|
||||
|
||||
return (
|
||||
<NavbarHeader className="home-navbar">
|
||||
<HStack alignItems="center">
|
||||
<div className="flex flex-1 items-center">
|
||||
{showAssistants && (
|
||||
<Tooltip title={t('navbar.hide_sidebar')} mouseEnterDelay={0.8}>
|
||||
<NavbarIcon onClick={toggleShowAssistants}>
|
||||
@ -103,12 +103,10 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
|
||||
)}
|
||||
</AnimatePresence>
|
||||
{activeTopicOrSession === 'topic' && <SelectModelButton assistant={assistant} />}
|
||||
{/* TODO: Show a select model button for agent. */}
|
||||
{/* FIXME: models endpoint doesn't return all models, so cannot found. */}
|
||||
{activeTopicOrSession === 'session' && (
|
||||
<ApiModelLabel classNames={{ container: 'text-xs' }} model={agentModel} />
|
||||
{activeTopicOrSession === 'session' && agent && agentModel && (
|
||||
<SelectAgentModelButton agent={agent} model={agentModel} />
|
||||
)}
|
||||
</HStack>
|
||||
</div>
|
||||
<HStack alignItems="center" gap={8}>
|
||||
<UpdateAppButton />
|
||||
<Tooltip title={t('navbar.expand')} mouseEnterDelay={0.8}>
|
||||
|
||||
@ -3,7 +3,7 @@ import { loggerService } from '@logger'
|
||||
import { CopyIcon, DeleteIcon, EditIcon, RefreshIcon } from '@renderer/components/Icons'
|
||||
import ObsidianExportPopup from '@renderer/components/Popups/ObsidianExportPopup'
|
||||
import SaveToKnowledgePopup from '@renderer/components/Popups/SaveToKnowledgePopup'
|
||||
import SelectModelPopup from '@renderer/components/Popups/SelectModelPopup'
|
||||
import { SelectModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
||||
import { isEmbeddingModel, isRerankModel, isVisionModel } from '@renderer/config/models'
|
||||
import {
|
||||
DEFAULT_MESSAGE_MENUBAR_SCOPE,
|
||||
|
||||
@ -0,0 +1,70 @@
|
||||
import { Button } from '@heroui/react'
|
||||
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
||||
import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
||||
import { useUpdateAgent } from '@renderer/hooks/agents/useUpdateAgent'
|
||||
import { AgentEntity, ApiModel } from '@renderer/types'
|
||||
import { getModelFilterByAgentType } from '@renderer/utils/agentSession'
|
||||
import { apiModelAdapter } from '@renderer/utils/model'
|
||||
import { ChevronsUpDown } from 'lucide-react'
|
||||
import { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
interface Props {
|
||||
agent: AgentEntity
|
||||
model: ApiModel
|
||||
}
|
||||
|
||||
const SelectAgentModelButton: FC<Props> = ({ agent, model }) => {
|
||||
const { t } = useTranslation()
|
||||
const update = useUpdateAgent()
|
||||
|
||||
const modelFilter = getModelFilterByAgentType(agent.type)
|
||||
|
||||
if (!agent) return null
|
||||
|
||||
const onSelectModel = async () => {
|
||||
const selectedModel = await SelectApiModelPopup.show({ model, filter: modelFilter })
|
||||
if (selectedModel) {
|
||||
update({ id: agent.id, model: selectedModel.id })
|
||||
}
|
||||
}
|
||||
|
||||
const providerName = model.provider_name
|
||||
|
||||
return (
|
||||
<DropdownButton size="sm" onPress={onSelectModel}>
|
||||
<ButtonContent>
|
||||
<ModelAvatar model={apiModelAdapter(model)} size={20} />
|
||||
<ModelName>
|
||||
{model ? model.name : t('button.select_model')} {providerName ? ' | ' + providerName : ''}
|
||||
</ModelName>
|
||||
</ButtonContent>
|
||||
<ChevronsUpDown size={14} color="var(--color-icon)" />
|
||||
</DropdownButton>
|
||||
)
|
||||
}
|
||||
|
||||
const DropdownButton = styled(Button)`
|
||||
font-size: 11px;
|
||||
border-radius: 15px;
|
||||
padding: 13px 5px;
|
||||
-webkit-app-region: none;
|
||||
box-shadow: none;
|
||||
background-color: transparent;
|
||||
border: 1px solid transparent;
|
||||
margin-top: 1px;
|
||||
`
|
||||
|
||||
const ButtonContent = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
`
|
||||
|
||||
const ModelName = styled.span`
|
||||
font-weight: 500;
|
||||
margin-right: -2px;
|
||||
`
|
||||
|
||||
export default SelectAgentModelButton
|
||||
@ -1,5 +1,5 @@
|
||||
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
||||
import SelectModelPopup from '@renderer/components/Popups/SelectModelPopup'
|
||||
import { SelectModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
||||
import { isLocalAi } from '@renderer/config/env'
|
||||
import { isEmbeddingModel, isRerankModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { useAssistant } from '@renderer/hooks/useAssistant'
|
||||
|
||||
@ -3,7 +3,7 @@ import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
||||
import EditableNumber from '@renderer/components/EditableNumber'
|
||||
import { DeleteIcon, ResetIcon } from '@renderer/components/Icons'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import SelectModelPopup from '@renderer/components/Popups/SelectModelPopup'
|
||||
import { SelectModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
||||
import Selector from '@renderer/components/Selector'
|
||||
import { DEFAULT_CONTEXTCOUNT, DEFAULT_TEMPERATURE, MAX_CONTEXT_COUNT } from '@renderer/config/constant'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { Model } from '@types'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { ProviderTypeSchema } from './provider'
|
||||
@ -35,3 +36,8 @@ export const ApiModelsResponseSchema = z.object({
|
||||
export type ApiModel = z.infer<typeof ApiModelSchema>
|
||||
export type ApiModelsFilter = z.infer<typeof ApiModelsFilterSchema>
|
||||
export type ApiModelsResponse = z.infer<typeof ApiModelsResponseSchema>
|
||||
|
||||
// Adapted
|
||||
export type AdaptedApiModel = Model & {
|
||||
origin: ApiModel
|
||||
}
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import { AgentType, ApiModelsFilter } from '@renderer/types'
|
||||
|
||||
const SESSION_TOPIC_PREFIX = 'agent-session:'
|
||||
|
||||
export const buildAgentSessionTopicId = (sessionId: string): string => {
|
||||
@ -11,3 +13,14 @@ export const isAgentSessionTopicId = (topicId: string): boolean => {
|
||||
export const extractAgentSessionIdFromTopicId = (topicId: string): string => {
|
||||
return topicId.replace(SESSION_TOPIC_PREFIX, '')
|
||||
}
|
||||
|
||||
export const getModelFilterByAgentType = (type: AgentType): ApiModelsFilter => {
|
||||
switch (type) {
|
||||
case 'claude-code':
|
||||
return {
|
||||
providerType: 'anthropic'
|
||||
}
|
||||
default:
|
||||
return {}
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,7 +74,7 @@ function getProviderSearchString(provider: Provider) {
|
||||
* @param provider 可选的 Provider 对象,用于生成完整模型名称
|
||||
* @returns 过滤后的模型数组
|
||||
*/
|
||||
export function filterModelsByKeywords(keywords: string, models: Model[], provider?: Provider): Model[] {
|
||||
export function filterModelsByKeywords<T extends Model>(keywords: string, models: T[], provider?: Provider): T[] {
|
||||
const keywordsArray = keywords.toLowerCase().split(/\s+/).filter(Boolean)
|
||||
return models.filter((model) => matchKeywordsInModel(keywordsArray, model, provider))
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ import {
|
||||
isVisionModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { Model, ModelTag, objectKeys } from '@renderer/types'
|
||||
import { AdaptedApiModel, ApiModel, Model, ModelTag, objectKeys } from '@renderer/types'
|
||||
|
||||
/**
|
||||
* 获取模型标签的状态
|
||||
@ -70,3 +70,13 @@ export function isFreeModel(model: Model) {
|
||||
|
||||
return (model.id + model.name).toLocaleLowerCase().includes('free')
|
||||
}
|
||||
|
||||
export const apiModelAdapter = (model: ApiModel): AdaptedApiModel => {
|
||||
return {
|
||||
id: model.provider_model_id ?? model.id,
|
||||
provider: model.provider ?? '',
|
||||
name: model.name,
|
||||
group: '',
|
||||
origin: model
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user