fix: SelectModelPopup scrolling behaviour (#5812)

* fix: focus the selected or the first item on searching

* refactor: remove unnecessary deferred values

* refactor: add a hook usePinnedModels for pinned models

* refactor: make the definition more consistent with other popups

* refactor: improve state management, improve scrolling behaviour

* fix: avoid potential modulo-by-zero

* fix: type error

* fix: async loading pinned models
This commit is contained in:
one 2025-05-10 13:48:25 +08:00 committed by GitHub
parent f9f0b857ae
commit 5ae6562f5b
6 changed files with 374 additions and 162 deletions

View File

@ -0,0 +1,41 @@
import { useMemo, useReducer } from 'react'
import { initialScrollState, scrollReducer } from './reducer'
import { FlatListItem, ScrollTrigger } from './types'
/**
* hook
*/
export function useScrollState() {
const [state, dispatch] = useReducer(scrollReducer, initialScrollState)
const actions = useMemo(
() => ({
setFocusedItemKey: (key: string) => dispatch({ type: 'SET_FOCUSED_ITEM_KEY', payload: key }),
setScrollTrigger: (trigger: ScrollTrigger) => dispatch({ type: 'SET_SCROLL_TRIGGER', payload: trigger }),
setLastScrollOffset: (offset: number) => dispatch({ type: 'SET_LAST_SCROLL_OFFSET', payload: offset }),
setStickyGroup: (group: FlatListItem | null) => dispatch({ type: 'SET_STICKY_GROUP', payload: group }),
setIsMouseOver: (isMouseOver: boolean) => dispatch({ type: 'SET_IS_MOUSE_OVER', payload: isMouseOver }),
focusNextItem: (modelItems: FlatListItem[], step: number) =>
dispatch({ type: 'FOCUS_NEXT_ITEM', payload: { modelItems, step } }),
focusPage: (modelItems: FlatListItem[], currentIndex: number, step: number) =>
dispatch({ type: 'FOCUS_PAGE', payload: { modelItems, currentIndex, step } }),
searchChanged: (searchText: string) => dispatch({ type: 'SEARCH_CHANGED', payload: { searchText } }),
updateOnListChange: (modelItems: FlatListItem[]) =>
dispatch({ type: 'UPDATE_ON_LIST_CHANGE', payload: { modelItems } }),
initScroll: () => dispatch({ type: 'INIT_SCROLL' })
}),
[]
)
return {
// 状态
focusedItemKey: state.focusedItemKey,
scrollTrigger: state.scrollTrigger,
lastScrollOffset: state.lastScrollOffset,
stickyGroup: state.stickyGroup,
isMouseOver: state.isMouseOver,
// 操作
...actions
}
}

View File

@ -0,0 +1,3 @@
import { SelectModelPopup } from './popup'
export default SelectModelPopup

View File

@ -1,7 +1,9 @@
import { PushpinOutlined } from '@ant-design/icons' import { PushpinOutlined } from '@ant-design/icons'
import { HStack } from '@renderer/components/Layout'
import ModelTagsWithLabel from '@renderer/components/ModelTagsWithLabel'
import { TopView } from '@renderer/components/TopView' import { TopView } from '@renderer/components/TopView'
import { getModelLogo, isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { getModelLogo, isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import db from '@renderer/databases' import { usePinnedModels } from '@renderer/hooks/usePinnedModels'
import { useProviders } from '@renderer/hooks/useProvider' import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService' import { getModelUniqId } from '@renderer/services/ModelService'
import { Model } from '@renderer/types' import { Model } from '@renderer/types'
@ -15,101 +17,61 @@ import { useTranslation } from 'react-i18next'
import { FixedSizeList } from 'react-window' import { FixedSizeList } from 'react-window'
import styled from 'styled-components' import styled from 'styled-components'
import { HStack } from '../Layout' import { useScrollState } from './hook'
import ModelTagsWithLabel from '../ModelTagsWithLabel' import { FlatListItem } from './types'
const PAGE_SIZE = 9 const PAGE_SIZE = 9
const ITEM_HEIGHT = 36 const ITEM_HEIGHT = 36
// 列表项类型,组名也作为列表项 interface PopupParams {
type ListItemType = 'group' | 'model'
// 滚动触发来源类型
type ScrollTrigger = 'initial' | 'search' | 'keyboard' | 'none'
// 扁平化列表项接口
interface FlatListItem {
key: string
type: ListItemType
icon?: React.ReactNode
name: React.ReactNode
tags?: React.ReactNode
model?: Model
isPinned?: boolean
isSelected?: boolean
}
interface Props {
model?: Model model?: Model
} }
interface PopupContainerProps extends Props { interface Props extends PopupParams {
resolve: (value: Model | undefined) => void resolve: (value: Model | undefined) => void
} }
const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => { const PopupContainer: React.FC<Props> = ({ model, resolve }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { providers } = useProviders() const { providers } = useProviders()
const { pinnedModels, togglePinnedModel, loading: loadingPinnedModels } = usePinnedModels()
const [open, setOpen] = useState(true) const [open, setOpen] = useState(true)
const inputRef = useRef<InputRef>(null) const inputRef = useRef<InputRef>(null)
const listRef = useRef<FixedSizeList>(null) const listRef = useRef<FixedSizeList>(null)
const [_searchText, setSearchText] = useState('') const [_searchText, setSearchText] = useState('')
const searchText = useDeferredValue(_searchText) const searchText = useDeferredValue(_searchText)
const [isMouseOver, setIsMouseOver] = useState(false)
const [pinnedModels, setPinnedModels] = useState<string[]>([])
const [_focusedItemKey, setFocusedItemKey] = useState<string>('')
const focusedItemKey = useDeferredValue(_focusedItemKey)
const [_stickyGroup, setStickyGroup] = useState<FlatListItem | null>(null)
const stickyGroup = useDeferredValue(_stickyGroup)
const firstGroupRef = useRef<FlatListItem | null>(null)
const scrollTriggerRef = useRef<ScrollTrigger>('initial')
const lastScrollOffsetRef = useRef(0)
// 当前选中的模型ID // 当前选中的模型ID
const currentModelId = model ? getModelUniqId(model) : '' const currentModelId = model ? getModelUniqId(model) : ''
// 加载置顶模型列表 // 管理滚动和焦点状态
useEffect(() => { const {
const loadPinnedModels = async () => { focusedItemKey,
const setting = await db.settings.get('pinned:models') scrollTrigger,
const savedPinnedModels = setting?.value || [] lastScrollOffset,
stickyGroup: _stickyGroup,
isMouseOver,
setFocusedItemKey,
setScrollTrigger,
setLastScrollOffset,
setStickyGroup,
setIsMouseOver,
focusNextItem,
focusPage,
searchChanged,
updateOnListChange,
initScroll
} = useScrollState()
// Filter out invalid pinned models const stickyGroup = useDeferredValue(_stickyGroup)
const allModelIds = providers.flatMap((p) => p.models || []).map((m) => getModelUniqId(m)) const firstGroupRef = useRef<FlatListItem | null>(null)
const validPinnedModels = savedPinnedModels.filter((id) => allModelIds.includes(id))
// Update storage if there were invalid models
if (validPinnedModels.length !== savedPinnedModels.length) {
await db.settings.put({ id: 'pinned:models', value: validPinnedModels })
}
setPinnedModels(sortBy(validPinnedModels))
}
try {
loadPinnedModels()
} catch (error) {
console.error('Failed to load pinned models', error)
setPinnedModels([])
}
}, [providers])
const togglePin = useCallback( const togglePin = useCallback(
async (modelId: string) => { async (modelId: string) => {
const newPinnedModels = pinnedModels.includes(modelId) await togglePinnedModel(modelId)
? pinnedModels.filter((id) => id !== modelId) setScrollTrigger('none') // pin操作不触发滚动
: [...pinnedModels, modelId]
try {
await db.settings.put({ id: 'pinned:models', value: newPinnedModels })
setPinnedModels(sortBy(newPinnedModels))
// Pin操作不触发滚动
scrollTriggerRef.current = 'none'
} catch (error) {
console.error('Failed to update pinned models', error)
}
}, },
[pinnedModels] [togglePinnedModel, setScrollTrigger]
) )
// 根据输入的文本筛选模型 // 根据输入的文本筛选模型
@ -222,6 +184,16 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
return items return items
}, [providers, getFilteredModels, pinnedModels, searchText, t, createModelItem]) }, [providers, getFilteredModels, pinnedModels, searchText, t, createModelItem])
// 获取可选择的模型项(过滤掉分组标题)
const modelItems = useMemo(() => {
return listItems.filter((item) => item.type === 'model')
}, [listItems])
// 当搜索文本变化时更新滚动触发器
useEffect(() => {
searchChanged(searchText)
}, [searchText, searchChanged])
// 基于滚动位置更新sticky分组标题 // 基于滚动位置更新sticky分组标题
const updateStickyGroup = useCallback( const updateStickyGroup = useCallback(
(scrollOffset?: number) => { (scrollOffset?: number) => {
@ -231,7 +203,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
} }
// 基于滚动位置计算当前可见的第一个项的索引 // 基于滚动位置计算当前可见的第一个项的索引
const estimatedIndex = Math.floor((scrollOffset ?? lastScrollOffsetRef.current) / ITEM_HEIGHT) const estimatedIndex = Math.floor((scrollOffset ?? lastScrollOffset) / ITEM_HEIGHT)
// 从该索引向前查找最近的分组标题 // 从该索引向前查找最近的分组标题
for (let i = estimatedIndex - 1; i >= 0; i--) { for (let i = estimatedIndex - 1; i >= 0; i--) {
@ -242,9 +214,9 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
} }
// 找不到则使用第一个分组标题 // 找不到则使用第一个分组标题
setStickyGroup(firstGroupRef.current ?? null) setStickyGroup(firstGroupRef.current)
}, },
[listItems] [listItems, lastScrollOffset, setStickyGroup]
) )
// 在listItems变化时更新sticky group // 在listItems变化时更新sticky group
@ -255,67 +227,46 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
// 处理列表滚动事件更新lastScrollOffset并更新sticky分组 // 处理列表滚动事件更新lastScrollOffset并更新sticky分组
const handleScroll = useCallback( const handleScroll = useCallback(
({ scrollOffset }) => { ({ scrollOffset }) => {
lastScrollOffsetRef.current = scrollOffset setLastScrollOffset(scrollOffset)
updateStickyGroup(scrollOffset) updateStickyGroup(scrollOffset)
}, },
[updateStickyGroup] [updateStickyGroup, setLastScrollOffset]
) )
// 获取可选择的模型项(过滤掉分组标题) // 在列表项更新时,更新焦点项
const modelItems = useMemo(() => {
return listItems.filter((item) => item.type === 'model')
}, [listItems])
// 搜索文本变化时设置滚动来源
useEffect(() => { useEffect(() => {
if (searchText.trim() !== '') { updateOnListChange(modelItems)
scrollTriggerRef.current = 'search' }, [modelItems, updateOnListChange])
setFocusedItemKey('')
}
}, [searchText])
// 设置初始聚焦项以触发滚动
useEffect(() => {
if (scrollTriggerRef.current === 'initial' || scrollTriggerRef.current === 'search') {
const selectedItem = modelItems.find((item) => item.isSelected)
if (selectedItem) {
setFocusedItemKey(selectedItem.key)
} else if (scrollTriggerRef.current === 'initial' && modelItems.length > 0) {
setFocusedItemKey(modelItems[0].key)
}
// 其余情况不设置focusedItemKey
}
}, [modelItems])
// 滚动到聚焦项 // 滚动到聚焦项
useEffect(() => { useEffect(() => {
if (scrollTriggerRef.current === 'none' || !focusedItemKey) return if (scrollTrigger === 'none' || !focusedItemKey) return
const index = listItems.findIndex((item) => item.key === focusedItemKey) const index = listItems.findIndex((item) => item.key === focusedItemKey)
if (index < 0) return if (index < 0) return
// 根据触发源决定滚动对齐方式 // 根据触发源决定滚动对齐方式
const alignment = scrollTriggerRef.current === 'keyboard' ? 'auto' : 'center' const alignment = scrollTrigger === 'keyboard' ? 'auto' : 'center'
listRef.current?.scrollToItem(index, alignment) listRef.current?.scrollToItem(index, alignment)
// 滚动后重置触发器 // 滚动后重置触发器
scrollTriggerRef.current = 'none' setScrollTrigger('none')
}, [focusedItemKey, listItems]) }, [focusedItemKey, scrollTrigger, listItems, setScrollTrigger])
const handleItemClick = useCallback( const handleItemClick = useCallback(
(item: FlatListItem) => { (item: FlatListItem) => {
if (item.type === 'model') { if (item.type === 'model') {
scrollTriggerRef.current = 'none' setScrollTrigger('initial')
resolve(item.model) resolve(item.model)
setOpen(false) setOpen(false)
} }
}, },
[resolve] [resolve, setScrollTrigger]
) )
// 处理键盘导航 // 处理键盘导航
useEffect(() => { const handleKeyDown = useCallback(
const handleKeyDown = (e: KeyboardEvent) => { (e: KeyboardEvent) => {
if (!open) return if (!open) return
if (modelItems.length === 0) { if (modelItems.length === 0) {
@ -329,43 +280,21 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
setIsMouseOver(false) setIsMouseOver(false)
} }
const getCurrentIndex = (currentKey: string) => { const currentIndex = modelItems.findIndex((item) => item.key === focusedItemKey)
const currentIndex = modelItems.findIndex((item) => item.key === currentKey) const normalizedIndex = currentIndex < 0 ? 0 : currentIndex
return currentIndex < 0 ? 0 : currentIndex
}
switch (e.key) { switch (e.key) {
case 'ArrowUp': case 'ArrowUp':
scrollTriggerRef.current = 'keyboard' focusNextItem(modelItems, -1)
setFocusedItemKey((prev) => {
const currentIndex = getCurrentIndex(prev)
const nextIndex = (currentIndex - 1 + modelItems.length) % modelItems.length
return modelItems[nextIndex].key
})
break break
case 'ArrowDown': case 'ArrowDown':
scrollTriggerRef.current = 'keyboard' focusNextItem(modelItems, 1)
setFocusedItemKey((prev) => {
const currentIndex = getCurrentIndex(prev)
const nextIndex = (currentIndex + 1) % modelItems.length
return modelItems[nextIndex].key
})
break break
case 'PageUp': case 'PageUp':
scrollTriggerRef.current = 'keyboard' focusPage(modelItems, normalizedIndex, -PAGE_SIZE)
setFocusedItemKey((prev) => {
const currentIndex = getCurrentIndex(prev)
const nextIndex = Math.max(currentIndex - PAGE_SIZE, 0)
return modelItems[nextIndex].key
})
break break
case 'PageDown': case 'PageDown':
scrollTriggerRef.current = 'keyboard' focusPage(modelItems, normalizedIndex, PAGE_SIZE)
setFocusedItemKey((prev) => {
const currentIndex = getCurrentIndex(prev)
const nextIndex = Math.min(currentIndex + PAGE_SIZE, modelItems.length - 1)
return modelItems[nextIndex].key
})
break break
case 'Enter': case 'Enter':
if (focusedItemKey) { if (focusedItemKey) {
@ -377,34 +306,47 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
break break
case 'Escape': case 'Escape':
e.preventDefault() e.preventDefault()
scrollTriggerRef.current = 'none' setScrollTrigger('none')
setOpen(false) setOpen(false)
resolve(undefined) resolve(undefined)
break break
} }
} },
[
window.addEventListener('keydown', handleKeyDown) focusedItemKey,
return () => window.removeEventListener('keydown', handleKeyDown) modelItems,
}, [focusedItemKey, modelItems, handleItemClick, open, resolve]) handleItemClick,
open,
const onCancel = useCallback(() => { resolve,
scrollTriggerRef.current = 'none' setIsMouseOver,
setOpen(false) focusNextItem,
}, []) focusPage,
setScrollTrigger
const onClose = useCallback(async () => { ]
scrollTriggerRef.current = 'none' )
resolve(undefined)
SelectModelPopup.hide()
}, [resolve])
useEffect(() => { useEffect(() => {
if (!open) return window.addEventListener('keydown', handleKeyDown)
return () => window.removeEventListener('keydown', handleKeyDown)
}, [handleKeyDown])
const onCancel = useCallback(() => {
setScrollTrigger('initial')
setOpen(false)
}, [setScrollTrigger])
const onClose = useCallback(async () => {
setScrollTrigger('initial')
resolve(undefined)
SelectModelPopup.hide()
}, [resolve, setScrollTrigger])
// 初始化焦点和滚动位置
useEffect(() => {
if (!open || loadingPinnedModels) return
setTimeout(() => inputRef.current?.focus(), 0) setTimeout(() => inputRef.current?.focus(), 0)
scrollTriggerRef.current = 'initial' initScroll()
lastScrollOffsetRef.current = 0 }, [open, initScroll, loadingPinnedModels])
}, [open])
const RowData = useMemo( const RowData = useMemo(
(): VirtualizedRowData => ({ (): VirtualizedRowData => ({
@ -415,7 +357,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
handleItemClick, handleItemClick,
togglePin togglePin
}), }),
[stickyGroup, focusedItemKey, handleItemClick, listItems, togglePin] [stickyGroup, focusedItemKey, handleItemClick, listItems, togglePin, setFocusedItemKey]
) )
const listHeight = useMemo(() => { const listHeight = useMemo(() => {
@ -470,7 +412,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
<Divider style={{ margin: 0, marginTop: 4, borderBlockStartWidth: 0.5 }} /> <Divider style={{ margin: 0, marginTop: 4, borderBlockStartWidth: 0.5 }} />
{listItems.length > 0 ? ( {listItems.length > 0 ? (
<ListContainer onMouseMove={() => setIsMouseOver(true)}> <ListContainer onMouseMove={() => !isMouseOver && setIsMouseOver(true)}>
{/* Sticky Group Banner它会替换第一个分组名称 */} {/* Sticky Group Banner它会替换第一个分组名称 */}
<StickyGroupBanner>{stickyGroup?.name}</StickyGroupBanner> <StickyGroupBanner>{stickyGroup?.name}</StickyGroupBanner>
<FixedSizeList <FixedSizeList
@ -685,14 +627,26 @@ const PinIconWrapper = styled.div.attrs({ className: 'pin-icon' })<{ $isPinned?:
} }
` `
export default class SelectModelPopup { const TopViewKey = 'SelectModelPopup'
export class SelectModelPopup {
static topviewId = 0
static hide() { static hide() {
TopView.hide('SelectModelPopup') TopView.hide(TopViewKey)
} }
static show(params: Props) { static show(params: PopupParams) {
return new Promise<Model | undefined>((resolve) => { return new Promise<Model | undefined>((resolve) => {
TopView.show(<PopupContainer {...params} resolve={resolve} />, 'SelectModelPopup') TopView.show(
<PopupContainer
{...params}
resolve={(v) => {
resolve(v)
TopView.hide(TopViewKey)
}}
/>,
TopViewKey
)
}) })
} }
} }

View File

@ -0,0 +1,109 @@
import { ScrollAction, ScrollState } from './types'
/**
*
*/
export const initialScrollState: ScrollState = {
focusedItemKey: '',
scrollTrigger: 'initial',
lastScrollOffset: 0,
stickyGroup: null,
isMouseOver: false
}
/**
* reducer
* @param state
* @param action
* @returns
*/
export const scrollReducer = (state: ScrollState, action: ScrollAction): ScrollState => {
switch (action.type) {
case 'SET_FOCUSED_ITEM_KEY':
return { ...state, focusedItemKey: action.payload }
case 'SET_SCROLL_TRIGGER':
return { ...state, scrollTrigger: action.payload }
case 'SET_LAST_SCROLL_OFFSET':
return { ...state, lastScrollOffset: action.payload }
case 'SET_STICKY_GROUP':
return { ...state, stickyGroup: action.payload }
case 'SET_IS_MOUSE_OVER':
return { ...state, isMouseOver: action.payload }
case 'FOCUS_NEXT_ITEM': {
const { modelItems, step } = action.payload
if (modelItems.length === 0) {
return {
...state,
focusedItemKey: '',
scrollTrigger: 'keyboard'
}
}
const currentIndex = modelItems.findIndex((item) => item.key === state.focusedItemKey)
const nextIndex = (currentIndex < 0 ? 0 : currentIndex + step + modelItems.length) % modelItems.length
return {
...state,
focusedItemKey: modelItems[nextIndex].key,
scrollTrigger: 'keyboard'
}
}
case 'FOCUS_PAGE': {
const { modelItems, currentIndex, step } = action.payload
const nextIndex = Math.max(0, Math.min(currentIndex + step, modelItems.length - 1))
return {
...state,
focusedItemKey: modelItems.length > 0 ? modelItems[nextIndex].key : '',
scrollTrigger: 'keyboard'
}
}
case 'SEARCH_CHANGED':
return {
...state,
scrollTrigger: action.payload.searchText ? 'search' : 'initial'
}
case 'UPDATE_ON_LIST_CHANGE': {
const { modelItems } = action.payload
// 在列表变化时尝试聚焦一个模型:
// - 如果是 initial 状态,先尝试聚焦当前选中的模型
// - 如果是 search 状态,尝试聚焦第一个模型
let newFocusedKey = ''
if (state.scrollTrigger === 'initial' || state.scrollTrigger === 'search') {
const selectedItem = modelItems.find((item) => item.isSelected)
if (selectedItem && state.scrollTrigger === 'initial') {
newFocusedKey = selectedItem.key
} else if (modelItems.length > 0) {
newFocusedKey = modelItems[0].key
}
} else {
newFocusedKey = state.focusedItemKey
}
return {
...state,
focusedItemKey: newFocusedKey
}
}
case 'INIT_SCROLL':
return {
...state,
scrollTrigger: 'initial',
lastScrollOffset: 0
}
default:
return state
}
}

View File

@ -0,0 +1,42 @@
import { Model } from '@renderer/types'
import { ReactNode } from 'react'
// 列表项类型,组名也作为列表项
export type ListItemType = 'group' | 'model'
// 滚动触发来源类型
export type ScrollTrigger = 'initial' | 'search' | 'keyboard' | 'none'
// 扁平化列表项接口
export interface FlatListItem {
key: string
type: ListItemType
icon?: ReactNode
name: ReactNode
tags?: ReactNode
model?: Model
isPinned?: boolean
isSelected?: boolean
}
// 滚动和焦点相关的状态类型
export interface ScrollState {
focusedItemKey: string
scrollTrigger: ScrollTrigger
lastScrollOffset: number
stickyGroup: FlatListItem | null
isMouseOver: boolean
}
// 滚动和焦点相关的 action 类型
export type ScrollAction =
| { type: 'SET_FOCUSED_ITEM_KEY'; payload: string }
| { type: 'SET_SCROLL_TRIGGER'; payload: ScrollTrigger }
| { type: 'SET_LAST_SCROLL_OFFSET'; payload: number }
| { type: 'SET_STICKY_GROUP'; payload: FlatListItem | null }
| { type: 'SET_IS_MOUSE_OVER'; payload: boolean }
| { type: 'FOCUS_NEXT_ITEM'; payload: { modelItems: FlatListItem[]; step: number } }
| { type: 'FOCUS_PAGE'; payload: { modelItems: FlatListItem[]; currentIndex: number; step: number } }
| { type: 'SEARCH_CHANGED'; payload: { searchText: string } }
| { type: 'UPDATE_ON_LIST_CHANGE'; payload: { modelItems: FlatListItem[] } }
| { type: 'INIT_SCROLL'; payload?: void }

View File

@ -0,0 +1,63 @@
import db from '@renderer/databases'
import { getModelUniqId } from '@renderer/services/ModelService'
import { sortBy } from 'lodash'
import { useCallback, useEffect, useState } from 'react'
import { useProviders } from './useProvider'
export const usePinnedModels = () => {
const [pinnedModels, setPinnedModels] = useState<string[]>([])
const [loading, setLoading] = useState(true)
const { providers } = useProviders()
useEffect(() => {
const loadPinnedModels = async () => {
setLoading(true)
const setting = await db.settings.get('pinned:models')
const savedPinnedModels = setting?.value || []
// Filter out invalid pinned models
const allModelIds = providers.flatMap((p) => p.models || []).map((m) => getModelUniqId(m))
const validPinnedModels = savedPinnedModels.filter((id) => allModelIds.includes(id))
// Update storage if there were invalid models
if (validPinnedModels.length !== savedPinnedModels.length) {
await db.settings.put({ id: 'pinned:models', value: validPinnedModels })
}
setPinnedModels(sortBy(validPinnedModels))
setLoading(false)
}
loadPinnedModels().catch((error) => {
console.error('Failed to load pinned models', error)
setPinnedModels([])
setLoading(false)
})
}, [providers])
const updatePinnedModels = useCallback(async (models: string[]) => {
await db.settings.put({ id: 'pinned:models', value: models })
setPinnedModels(sortBy(models))
}, [])
/**
* Toggle a single pinned model
* @param modelId - The ID string of the model to toggle
*/
const togglePinnedModel = useCallback(
async (modelId: string) => {
try {
const newPinnedModels = pinnedModels.includes(modelId)
? pinnedModels.filter((id) => id !== modelId)
: [...pinnedModels, modelId]
await updatePinnedModels(newPinnedModels)
} catch (error) {
console.error('Failed to toggle pinned model', error)
}
},
[pinnedModels, updatePinnedModels]
)
return { pinnedModels, updatePinnedModels, togglePinnedModel, loading }
}