mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 22:10:21 +08:00
refactor(agent): update useUpdateAgent hook and related components
- Refactor useUpdateAgent to return both updateAgent and updateModel functions - Update all components using useUpdateAgent to use the new hook structure - Improve model selection by reusing SelectAgentModelButton component - Add pagination support to useApiModels hook
This commit is contained in:
parent
3111979bb4
commit
42435e8f76
@ -98,7 +98,7 @@ export const AgentModal: React.FC<Props> = ({ agent, trigger, isOpen: _isOpen, o
|
|||||||
const loadingRef = useRef(false)
|
const loadingRef = useRef(false)
|
||||||
// const { setTimeoutTimer } = useTimer()
|
// const { setTimeoutTimer } = useTimer()
|
||||||
const { addAgent } = useAgents()
|
const { addAgent } = useAgents()
|
||||||
const updateAgent = useUpdateAgent()
|
const { updateAgent } = useUpdateAgent()
|
||||||
// hard-coded. We only support anthropic for now.
|
// hard-coded. We only support anthropic for now.
|
||||||
const { models } = useApiModels({ providerType: 'anthropic' })
|
const { models } = useApiModels({ providerType: 'anthropic' })
|
||||||
const isEditing = (agent?: AgentWithTools) => agent !== undefined
|
const isEditing = (agent?: AgentWithTools) => agent !== undefined
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import { ApiModelsFilter } from '@renderer/types'
|
import { ApiModel, ApiModelsFilter } from '@renderer/types'
|
||||||
import { merge } from 'lodash'
|
import { merge } from 'lodash'
|
||||||
import { useCallback } from 'react'
|
import { useCallback } from 'react'
|
||||||
import useSWR from 'swr'
|
import useSWR from 'swr'
|
||||||
@ -11,8 +11,20 @@ export const useApiModels = (filter?: ApiModelsFilter) => {
|
|||||||
const defaultFilter = {} satisfies ApiModelsFilter
|
const defaultFilter = {} satisfies ApiModelsFilter
|
||||||
const finalFilter = merge(filter, defaultFilter)
|
const finalFilter = merge(filter, defaultFilter)
|
||||||
const path = client.getModelsPath(finalFilter)
|
const path = client.getModelsPath(finalFilter)
|
||||||
const fetcher = useCallback(() => {
|
const fetcher = useCallback(async () => {
|
||||||
return client.getModels(finalFilter)
|
const limit = finalFilter.limit || 100
|
||||||
|
let offset = finalFilter.offset || 0
|
||||||
|
const allModels: ApiModel[] = []
|
||||||
|
let total = Infinity
|
||||||
|
|
||||||
|
while (offset < total) {
|
||||||
|
const pageFilter = { ...finalFilter, limit, offset }
|
||||||
|
const res = await client.getModels(pageFilter)
|
||||||
|
allModels.push(...(res.data || []))
|
||||||
|
total = res.total ?? 0
|
||||||
|
offset += limit
|
||||||
|
}
|
||||||
|
return { data: allModels, total }
|
||||||
}, [client, finalFilter])
|
}, [client, finalFilter])
|
||||||
const { data, error, isLoading } = useSWR(path, fetcher)
|
const { data, error, isLoading } = useSWR(path, fetcher)
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -6,20 +6,27 @@ import { mutate } from 'swr'
|
|||||||
|
|
||||||
import { useAgentClient } from './useAgentClient'
|
import { useAgentClient } from './useAgentClient'
|
||||||
|
|
||||||
|
export type UpdateAgentOptions = {
|
||||||
|
/** Whether to show success toast after updating. Defaults to true. */
|
||||||
|
showSuccessToast?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
export const useUpdateAgent = () => {
|
export const useUpdateAgent = () => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const client = useAgentClient()
|
const client = useAgentClient()
|
||||||
const listKey = client.agentPaths.base
|
const listKey = client.agentPaths.base
|
||||||
|
|
||||||
const updateAgent = useCallback(
|
const updateAgent = useCallback(
|
||||||
async (form: UpdateAgentForm) => {
|
async (form: UpdateAgentForm, options?: UpdateAgentOptions) => {
|
||||||
try {
|
try {
|
||||||
const itemKey = client.agentPaths.withId(form.id)
|
const itemKey = client.agentPaths.withId(form.id)
|
||||||
// may change to optimistic update
|
// may change to optimistic update
|
||||||
const result = await client.updateAgent(form)
|
const result = await client.updateAgent(form)
|
||||||
mutate<ListAgentsResponse['data']>(listKey, (prev) => prev?.map((a) => (a.id === result.id ? result : a)) ?? [])
|
mutate<ListAgentsResponse['data']>(listKey, (prev) => prev?.map((a) => (a.id === result.id ? result : a)) ?? [])
|
||||||
mutate(itemKey, result)
|
mutate(itemKey, result)
|
||||||
window.toast.success(t('common.update_success'))
|
if (options?.showSuccessToast ?? true) {
|
||||||
|
window.toast.success(t('common.update_success'))
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
window.toast.error(formatErrorMessageWithPrefix(error, t('agent.update.error.failed')))
|
window.toast.error(formatErrorMessageWithPrefix(error, t('agent.update.error.failed')))
|
||||||
}
|
}
|
||||||
@ -27,5 +34,12 @@ export const useUpdateAgent = () => {
|
|||||||
[client, listKey, t]
|
[client, listKey, t]
|
||||||
)
|
)
|
||||||
|
|
||||||
return updateAgent
|
const updateModel = useCallback(
|
||||||
|
async (agentId: string, modelId: string, options?: UpdateAgentOptions) => {
|
||||||
|
updateAgent({ id: agentId, model: modelId }, options)
|
||||||
|
},
|
||||||
|
[updateAgent]
|
||||||
|
)
|
||||||
|
|
||||||
|
return { updateAgent, updateModel }
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import { NavbarHeader } from '@renderer/components/app/Navbar'
|
|||||||
import { HStack } from '@renderer/components/Layout'
|
import { HStack } from '@renderer/components/Layout'
|
||||||
import SearchPopup from '@renderer/components/Popups/SearchPopup'
|
import SearchPopup from '@renderer/components/Popups/SearchPopup'
|
||||||
import { useAgent } from '@renderer/hooks/agents/useAgent'
|
import { useAgent } from '@renderer/hooks/agents/useAgent'
|
||||||
import { useApiModel } from '@renderer/hooks/agents/useModel'
|
import { useUpdateAgent } from '@renderer/hooks/agents/useUpdateAgent'
|
||||||
import { useAssistant } from '@renderer/hooks/useAssistant'
|
import { useAssistant } from '@renderer/hooks/useAssistant'
|
||||||
import { modelGenerating, useRuntime } from '@renderer/hooks/useRuntime'
|
import { modelGenerating, useRuntime } from '@renderer/hooks/useRuntime'
|
||||||
import { useSettings } from '@renderer/hooks/useSettings'
|
import { useSettings } from '@renderer/hooks/useSettings'
|
||||||
@ -11,12 +11,12 @@ import { useShowAssistants, useShowTopics } from '@renderer/hooks/useStore'
|
|||||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
||||||
import { useAppDispatch } from '@renderer/store'
|
import { useAppDispatch } from '@renderer/store'
|
||||||
import { setNarrowMode } from '@renderer/store/settings'
|
import { setNarrowMode } from '@renderer/store/settings'
|
||||||
import { Assistant, Topic } from '@renderer/types'
|
import { ApiModel, Assistant, Topic } from '@renderer/types'
|
||||||
import { Tooltip } from 'antd'
|
import { Tooltip } from 'antd'
|
||||||
import { t } from 'i18next'
|
import { t } from 'i18next'
|
||||||
import { Menu, PanelLeftClose, PanelRightClose, Search } from 'lucide-react'
|
import { Menu, PanelLeftClose, PanelRightClose, Search } from 'lucide-react'
|
||||||
import { AnimatePresence, motion } from 'motion/react'
|
import { AnimatePresence, motion } from 'motion/react'
|
||||||
import { FC } from 'react'
|
import { FC, useCallback } from 'react'
|
||||||
import styled from 'styled-components'
|
import styled from 'styled-components'
|
||||||
|
|
||||||
import AssistantsDrawer from './components/AssistantsDrawer'
|
import AssistantsDrawer from './components/AssistantsDrawer'
|
||||||
@ -41,8 +41,7 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
|
|||||||
const { chat } = useRuntime()
|
const { chat } = useRuntime()
|
||||||
const { activeTopicOrSession, activeAgentId } = chat
|
const { activeTopicOrSession, activeAgentId } = chat
|
||||||
const { agent } = useAgent(activeAgentId)
|
const { agent } = useAgent(activeAgentId)
|
||||||
// TODO: filter is temporally for agent since it cannot get all models once
|
const { updateModel } = useUpdateAgent()
|
||||||
const agentModel = useApiModel({ id: agent?.model, filter: { providerType: 'anthropic' } })
|
|
||||||
|
|
||||||
useShortcut('toggle_show_assistants', toggleShowAssistants)
|
useShortcut('toggle_show_assistants', toggleShowAssistants)
|
||||||
|
|
||||||
@ -72,6 +71,14 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleUpdateModel = useCallback(
|
||||||
|
async (model: ApiModel) => {
|
||||||
|
if (!agent) return
|
||||||
|
return updateModel(agent.id, model.id, { showSuccessToast: false })
|
||||||
|
},
|
||||||
|
[agent, updateModel]
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<NavbarHeader className="home-navbar">
|
<NavbarHeader className="home-navbar">
|
||||||
<div className="flex flex-1 items-center">
|
<div className="flex flex-1 items-center">
|
||||||
@ -103,8 +110,8 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
|
|||||||
)}
|
)}
|
||||||
</AnimatePresence>
|
</AnimatePresence>
|
||||||
{activeTopicOrSession === 'topic' && <SelectModelButton assistant={assistant} />}
|
{activeTopicOrSession === 'topic' && <SelectModelButton assistant={assistant} />}
|
||||||
{activeTopicOrSession === 'session' && agent && agentModel && (
|
{activeTopicOrSession === 'session' && agent && (
|
||||||
<SelectAgentModelButton agent={agent} model={agentModel} />
|
<SelectAgentModelButton agent={agent} onSelect={handleUpdateModel} />
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<HStack alignItems="center" gap={8}>
|
<HStack alignItems="center" gap={8}>
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
import { Button } from '@heroui/react'
|
import { Button } from '@heroui/react'
|
||||||
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
||||||
import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
||||||
import { useUpdateAgent } from '@renderer/hooks/agents/useUpdateAgent'
|
import { useApiModel } from '@renderer/hooks/agents/useModel'
|
||||||
import { AgentEntity, ApiModel } from '@renderer/types'
|
import { getProviderNameById } from '@renderer/services/ProviderService'
|
||||||
|
import { AgentBaseWithId, ApiModel, isAgentEntity } from '@renderer/types'
|
||||||
import { getModelFilterByAgentType } from '@renderer/utils/agentSession'
|
import { getModelFilterByAgentType } from '@renderer/utils/agentSession'
|
||||||
import { apiModelAdapter } from '@renderer/utils/model'
|
import { apiModelAdapter } from '@renderer/utils/model'
|
||||||
import { ChevronsUpDown } from 'lucide-react'
|
import { ChevronsUpDown } from 'lucide-react'
|
||||||
@ -10,31 +11,37 @@ import { FC } from 'react'
|
|||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
agent: AgentEntity
|
agent: AgentBaseWithId
|
||||||
model: ApiModel
|
onSelect: (model: ApiModel) => Promise<void>
|
||||||
|
isDisabled?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
const SelectAgentModelButton: FC<Props> = ({ agent, model }) => {
|
const SelectAgentModelButton: FC<Props> = ({ agent, onSelect, isDisabled }) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const update = useUpdateAgent()
|
const model = useApiModel({ id: agent?.model })
|
||||||
|
|
||||||
const modelFilter = getModelFilterByAgentType(agent.type)
|
const modelFilter = isAgentEntity(agent) ? getModelFilterByAgentType(agent.type) : undefined
|
||||||
|
|
||||||
if (!agent) return null
|
if (!agent) return null
|
||||||
|
|
||||||
const onSelectModel = async () => {
|
const onSelectModel = async () => {
|
||||||
const selectedModel = await SelectApiModelPopup.show({ model, filter: modelFilter })
|
const selectedModel = await SelectApiModelPopup.show({ model, filter: modelFilter })
|
||||||
if (selectedModel && selectedModel.id !== agent.model) {
|
if (selectedModel && selectedModel.id !== agent.model) {
|
||||||
update({ id: agent.id, model: selectedModel.id })
|
onSelect(selectedModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const providerName = model.provider_name
|
const providerName = model?.provider ? getProviderNameById(model.provider) : model?.provider_name
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Button size="sm" variant="light" className="nodrag rounded-2xl px-1 py-3" onPress={onSelectModel}>
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="light"
|
||||||
|
className="nodrag rounded-2xl px-1 py-3"
|
||||||
|
onPress={onSelectModel}
|
||||||
|
isDisabled={isDisabled}>
|
||||||
<div className="flex items-center gap-1.5">
|
<div className="flex items-center gap-1.5">
|
||||||
<ModelAvatar model={apiModelAdapter(model)} size={20} />
|
<ModelAvatar model={model ? apiModelAdapter(model) : undefined} size={20} />
|
||||||
<span className="-mr-0.5 font-medium">
|
<span className="-mr-0.5 font-medium">
|
||||||
{model ? model.name : t('button.select_model')} {providerName ? ' | ' + providerName : ''}
|
{model ? model.name : t('button.select_model')} {providerName ? ' | ' + providerName : ''}
|
||||||
</span>
|
</span>
|
||||||
|
|||||||
@ -19,7 +19,7 @@ type AgentConfigurationState = AgentConfiguration & Record<string, unknown>
|
|||||||
type AdvancedSettingsProps =
|
type AdvancedSettingsProps =
|
||||||
| {
|
| {
|
||||||
agentBase: GetAgentResponse | undefined | null
|
agentBase: GetAgentResponse | undefined | null
|
||||||
update: ReturnType<typeof useUpdateAgent>
|
update: ReturnType<typeof useUpdateAgent>['updateAgent']
|
||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
agentBase: GetAgentSessionResponse | undefined | null
|
agentBase: GetAgentSessionResponse | undefined | null
|
||||||
|
|||||||
@ -13,7 +13,7 @@ import { AgentLabel, SettingsContainer, SettingsItem, SettingsTitle } from './sh
|
|||||||
|
|
||||||
interface AgentEssentialSettingsProps {
|
interface AgentEssentialSettingsProps {
|
||||||
agent: GetAgentResponse | undefined | null
|
agent: GetAgentResponse | undefined | null
|
||||||
update: ReturnType<typeof useUpdateAgent>
|
update: ReturnType<typeof useUpdateAgent>['updateAgent']
|
||||||
}
|
}
|
||||||
|
|
||||||
const AgentEssentialSettings: FC<AgentEssentialSettingsProps> = ({ agent, update }) => {
|
const AgentEssentialSettings: FC<AgentEssentialSettingsProps> = ({ agent, update }) => {
|
||||||
|
|||||||
@ -28,7 +28,7 @@ const AgentSettingPopupContainer: React.FC<AgentSettingPopupParams> = ({ tab, ag
|
|||||||
const [menu, setMenu] = useState<AgentSettingPopupTab>(tab || 'essential')
|
const [menu, setMenu] = useState<AgentSettingPopupTab>(tab || 'essential')
|
||||||
|
|
||||||
const { agent, isLoading, error } = useAgent(agentId)
|
const { agent, isLoading, error } = useAgent(agentId)
|
||||||
const updateAgent = useUpdateAgent()
|
const { updateAgent } = useUpdateAgent()
|
||||||
|
|
||||||
const onOk = () => {
|
const onOk = () => {
|
||||||
setOpen(false)
|
setOpen(false)
|
||||||
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
import { Select, SelectedItems, SelectItem } from '@heroui/react'
|
import SelectAgentModelButton from '@renderer/pages/home/components/SelectAgentModelButton'
|
||||||
import { ApiModelLabel } from '@renderer/components/ApiModelLabel'
|
import { AgentBaseWithId, ApiModel, UpdateAgentBaseForm } from '@renderer/types'
|
||||||
import { useApiModels } from '@renderer/hooks/agents/useModels'
|
|
||||||
import { AgentBaseWithId, ApiModel, UpdateAgentBaseForm, UpdateAgentForm } from '@renderer/types'
|
|
||||||
import { useCallback } from 'react'
|
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
import { SettingsItem, SettingsTitle } from './shared'
|
import { SettingsItem, SettingsTitle } from './shared'
|
||||||
@ -15,43 +12,18 @@ export interface ModelSettingProps {
|
|||||||
|
|
||||||
export const ModelSetting: React.FC<ModelSettingProps> = ({ base, update, isDisabled }) => {
|
export const ModelSetting: React.FC<ModelSettingProps> = ({ base, update, isDisabled }) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const { models } = useApiModels({ providerType: 'anthropic' })
|
|
||||||
|
|
||||||
const updateModel = (model: UpdateAgentForm['model']) => {
|
const updateModel = async (model: ApiModel) => {
|
||||||
if (!base) return
|
if (!base) return
|
||||||
update({ id: base.id, model })
|
return update({ id: base.id, model: model.id })
|
||||||
}
|
}
|
||||||
|
|
||||||
const renderModels = useCallback((items: SelectedItems<ApiModel>) => {
|
|
||||||
return items.map((item) => {
|
|
||||||
const model = item.data ?? undefined
|
|
||||||
return <ApiModelLabel key={model?.id} model={model} />
|
|
||||||
})
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
if (!base) return null
|
if (!base) return null
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<SettingsItem inline className="gap-8">
|
<SettingsItem inline className="gap-8">
|
||||||
<SettingsTitle id="model">{t('common.model')}</SettingsTitle>
|
<SettingsTitle id="model">{t('common.model')}</SettingsTitle>
|
||||||
<Select
|
<SelectAgentModelButton agent={base} onSelect={updateModel} isDisabled={isDisabled} />
|
||||||
isDisabled={isDisabled}
|
|
||||||
selectionMode="single"
|
|
||||||
aria-labelledby="model"
|
|
||||||
items={models}
|
|
||||||
selectedKeys={[base.model]}
|
|
||||||
onSelectionChange={(keys) => {
|
|
||||||
updateModel(keys.currentKey)
|
|
||||||
}}
|
|
||||||
className="max-w-80 flex-1"
|
|
||||||
placeholder={t('common.placeholders.select.model')}
|
|
||||||
renderValue={renderModels}>
|
|
||||||
{(model) => (
|
|
||||||
<SelectItem textValue={model.id}>
|
|
||||||
<ApiModelLabel model={model} />
|
|
||||||
</SelectItem>
|
|
||||||
)}
|
|
||||||
</Select>
|
|
||||||
</SettingsItem>
|
</SettingsItem>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,7 @@ import { SettingsContainer, SettingsItem, SettingsTitle } from './shared'
|
|||||||
type AgentPromptSettingsProps =
|
type AgentPromptSettingsProps =
|
||||||
| {
|
| {
|
||||||
agentBase: AgentEntity | undefined | null
|
agentBase: AgentEntity | undefined | null
|
||||||
update: ReturnType<typeof useUpdateAgent>
|
update: ReturnType<typeof useUpdateAgent>['updateAgent']
|
||||||
}
|
}
|
||||||
| {
|
| {
|
||||||
agentBase: AgentSessionEntity | undefined | null
|
agentBase: AgentSessionEntity | undefined | null
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import { SettingsContainer } from './shared'
|
|||||||
|
|
||||||
interface SessionEssentialSettingsProps {
|
interface SessionEssentialSettingsProps {
|
||||||
session: GetAgentSessionResponse | undefined | null
|
session: GetAgentSessionResponse | undefined | null
|
||||||
update: ReturnType<typeof useUpdateAgent>
|
update: ReturnType<typeof useUpdateAgent>['updateAgent']
|
||||||
}
|
}
|
||||||
|
|
||||||
const SessionEssentialSettings: FC<SessionEssentialSettingsProps> = ({ session, update }) => {
|
const SessionEssentialSettings: FC<SessionEssentialSettingsProps> = ({ session, update }) => {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user