mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 12:51:26 +08:00
refactor(agent): update useUpdateAgent hook and related components
- Refactor useUpdateAgent to return both updateAgent and updateModel functions - Update all components using useUpdateAgent to use the new hook structure - Improve model selection by reusing SelectAgentModelButton component - Add pagination support to useApiModels hook
This commit is contained in:
parent
3111979bb4
commit
42435e8f76
@ -98,7 +98,7 @@ export const AgentModal: React.FC<Props> = ({ agent, trigger, isOpen: _isOpen, o
|
||||
const loadingRef = useRef(false)
|
||||
// const { setTimeoutTimer } = useTimer()
|
||||
const { addAgent } = useAgents()
|
||||
const updateAgent = useUpdateAgent()
|
||||
const { updateAgent } = useUpdateAgent()
|
||||
// hard-coded. We only support anthropic for now.
|
||||
const { models } = useApiModels({ providerType: 'anthropic' })
|
||||
const isEditing = (agent?: AgentWithTools) => agent !== undefined
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { ApiModelsFilter } from '@renderer/types'
|
||||
import { ApiModel, ApiModelsFilter } from '@renderer/types'
|
||||
import { merge } from 'lodash'
|
||||
import { useCallback } from 'react'
|
||||
import useSWR from 'swr'
|
||||
@ -11,8 +11,20 @@ export const useApiModels = (filter?: ApiModelsFilter) => {
|
||||
const defaultFilter = {} satisfies ApiModelsFilter
|
||||
const finalFilter = merge(filter, defaultFilter)
|
||||
const path = client.getModelsPath(finalFilter)
|
||||
const fetcher = useCallback(() => {
|
||||
return client.getModels(finalFilter)
|
||||
const fetcher = useCallback(async () => {
|
||||
const limit = finalFilter.limit || 100
|
||||
let offset = finalFilter.offset || 0
|
||||
const allModels: ApiModel[] = []
|
||||
let total = Infinity
|
||||
|
||||
while (offset < total) {
|
||||
const pageFilter = { ...finalFilter, limit, offset }
|
||||
const res = await client.getModels(pageFilter)
|
||||
allModels.push(...(res.data || []))
|
||||
total = res.total ?? 0
|
||||
offset += limit
|
||||
}
|
||||
return { data: allModels, total }
|
||||
}, [client, finalFilter])
|
||||
const { data, error, isLoading } = useSWR(path, fetcher)
|
||||
return {
|
||||
|
||||
@ -6,20 +6,27 @@ import { mutate } from 'swr'
|
||||
|
||||
import { useAgentClient } from './useAgentClient'
|
||||
|
||||
export type UpdateAgentOptions = {
|
||||
/** Whether to show success toast after updating. Defaults to true. */
|
||||
showSuccessToast?: boolean
|
||||
}
|
||||
|
||||
export const useUpdateAgent = () => {
|
||||
const { t } = useTranslation()
|
||||
const client = useAgentClient()
|
||||
const listKey = client.agentPaths.base
|
||||
|
||||
const updateAgent = useCallback(
|
||||
async (form: UpdateAgentForm) => {
|
||||
async (form: UpdateAgentForm, options?: UpdateAgentOptions) => {
|
||||
try {
|
||||
const itemKey = client.agentPaths.withId(form.id)
|
||||
// may change to optimistic update
|
||||
const result = await client.updateAgent(form)
|
||||
mutate<ListAgentsResponse['data']>(listKey, (prev) => prev?.map((a) => (a.id === result.id ? result : a)) ?? [])
|
||||
mutate(itemKey, result)
|
||||
window.toast.success(t('common.update_success'))
|
||||
if (options?.showSuccessToast ?? true) {
|
||||
window.toast.success(t('common.update_success'))
|
||||
}
|
||||
} catch (error) {
|
||||
window.toast.error(formatErrorMessageWithPrefix(error, t('agent.update.error.failed')))
|
||||
}
|
||||
@ -27,5 +34,12 @@ export const useUpdateAgent = () => {
|
||||
[client, listKey, t]
|
||||
)
|
||||
|
||||
return updateAgent
|
||||
const updateModel = useCallback(
|
||||
async (agentId: string, modelId: string, options?: UpdateAgentOptions) => {
|
||||
updateAgent({ id: agentId, model: modelId }, options)
|
||||
},
|
||||
[updateAgent]
|
||||
)
|
||||
|
||||
return { updateAgent, updateModel }
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ import { NavbarHeader } from '@renderer/components/app/Navbar'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import SearchPopup from '@renderer/components/Popups/SearchPopup'
|
||||
import { useAgent } from '@renderer/hooks/agents/useAgent'
|
||||
import { useApiModel } from '@renderer/hooks/agents/useModel'
|
||||
import { useUpdateAgent } from '@renderer/hooks/agents/useUpdateAgent'
|
||||
import { useAssistant } from '@renderer/hooks/useAssistant'
|
||||
import { modelGenerating, useRuntime } from '@renderer/hooks/useRuntime'
|
||||
import { useSettings } from '@renderer/hooks/useSettings'
|
||||
@ -11,12 +11,12 @@ import { useShowAssistants, useShowTopics } from '@renderer/hooks/useStore'
|
||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
||||
import { useAppDispatch } from '@renderer/store'
|
||||
import { setNarrowMode } from '@renderer/store/settings'
|
||||
import { Assistant, Topic } from '@renderer/types'
|
||||
import { ApiModel, Assistant, Topic } from '@renderer/types'
|
||||
import { Tooltip } from 'antd'
|
||||
import { t } from 'i18next'
|
||||
import { Menu, PanelLeftClose, PanelRightClose, Search } from 'lucide-react'
|
||||
import { AnimatePresence, motion } from 'motion/react'
|
||||
import { FC } from 'react'
|
||||
import { FC, useCallback } from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import AssistantsDrawer from './components/AssistantsDrawer'
|
||||
@ -41,8 +41,7 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
|
||||
const { chat } = useRuntime()
|
||||
const { activeTopicOrSession, activeAgentId } = chat
|
||||
const { agent } = useAgent(activeAgentId)
|
||||
// TODO: filter is temporally for agent since it cannot get all models once
|
||||
const agentModel = useApiModel({ id: agent?.model, filter: { providerType: 'anthropic' } })
|
||||
const { updateModel } = useUpdateAgent()
|
||||
|
||||
useShortcut('toggle_show_assistants', toggleShowAssistants)
|
||||
|
||||
@ -72,6 +71,14 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
|
||||
})
|
||||
}
|
||||
|
||||
const handleUpdateModel = useCallback(
|
||||
async (model: ApiModel) => {
|
||||
if (!agent) return
|
||||
return updateModel(agent.id, model.id, { showSuccessToast: false })
|
||||
},
|
||||
[agent, updateModel]
|
||||
)
|
||||
|
||||
return (
|
||||
<NavbarHeader className="home-navbar">
|
||||
<div className="flex flex-1 items-center">
|
||||
@ -103,8 +110,8 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTo
|
||||
)}
|
||||
</AnimatePresence>
|
||||
{activeTopicOrSession === 'topic' && <SelectModelButton assistant={assistant} />}
|
||||
{activeTopicOrSession === 'session' && agent && agentModel && (
|
||||
<SelectAgentModelButton agent={agent} model={agentModel} />
|
||||
{activeTopicOrSession === 'session' && agent && (
|
||||
<SelectAgentModelButton agent={agent} onSelect={handleUpdateModel} />
|
||||
)}
|
||||
</div>
|
||||
<HStack alignItems="center" gap={8}>
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import { Button } from '@heroui/react'
|
||||
import ModelAvatar from '@renderer/components/Avatar/ModelAvatar'
|
||||
import { SelectApiModelPopup } from '@renderer/components/Popups/SelectModelPopup'
|
||||
import { useUpdateAgent } from '@renderer/hooks/agents/useUpdateAgent'
|
||||
import { AgentEntity, ApiModel } from '@renderer/types'
|
||||
import { useApiModel } from '@renderer/hooks/agents/useModel'
|
||||
import { getProviderNameById } from '@renderer/services/ProviderService'
|
||||
import { AgentBaseWithId, ApiModel, isAgentEntity } from '@renderer/types'
|
||||
import { getModelFilterByAgentType } from '@renderer/utils/agentSession'
|
||||
import { apiModelAdapter } from '@renderer/utils/model'
|
||||
import { ChevronsUpDown } from 'lucide-react'
|
||||
@ -10,31 +11,37 @@ import { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
interface Props {
|
||||
agent: AgentEntity
|
||||
model: ApiModel
|
||||
agent: AgentBaseWithId
|
||||
onSelect: (model: ApiModel) => Promise<void>
|
||||
isDisabled?: boolean
|
||||
}
|
||||
|
||||
const SelectAgentModelButton: FC<Props> = ({ agent, model }) => {
|
||||
const SelectAgentModelButton: FC<Props> = ({ agent, onSelect, isDisabled }) => {
|
||||
const { t } = useTranslation()
|
||||
const update = useUpdateAgent()
|
||||
const model = useApiModel({ id: agent?.model })
|
||||
|
||||
const modelFilter = getModelFilterByAgentType(agent.type)
|
||||
const modelFilter = isAgentEntity(agent) ? getModelFilterByAgentType(agent.type) : undefined
|
||||
|
||||
if (!agent) return null
|
||||
|
||||
const onSelectModel = async () => {
|
||||
const selectedModel = await SelectApiModelPopup.show({ model, filter: modelFilter })
|
||||
if (selectedModel && selectedModel.id !== agent.model) {
|
||||
update({ id: agent.id, model: selectedModel.id })
|
||||
onSelect(selectedModel)
|
||||
}
|
||||
}
|
||||
|
||||
const providerName = model.provider_name
|
||||
const providerName = model?.provider ? getProviderNameById(model.provider) : model?.provider_name
|
||||
|
||||
return (
|
||||
<Button size="sm" variant="light" className="nodrag rounded-2xl px-1 py-3" onPress={onSelectModel}>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="light"
|
||||
className="nodrag rounded-2xl px-1 py-3"
|
||||
onPress={onSelectModel}
|
||||
isDisabled={isDisabled}>
|
||||
<div className="flex items-center gap-1.5">
|
||||
<ModelAvatar model={apiModelAdapter(model)} size={20} />
|
||||
<ModelAvatar model={model ? apiModelAdapter(model) : undefined} size={20} />
|
||||
<span className="-mr-0.5 font-medium">
|
||||
{model ? model.name : t('button.select_model')} {providerName ? ' | ' + providerName : ''}
|
||||
</span>
|
||||
|
||||
@ -19,7 +19,7 @@ type AgentConfigurationState = AgentConfiguration & Record<string, unknown>
|
||||
type AdvancedSettingsProps =
|
||||
| {
|
||||
agentBase: GetAgentResponse | undefined | null
|
||||
update: ReturnType<typeof useUpdateAgent>
|
||||
update: ReturnType<typeof useUpdateAgent>['updateAgent']
|
||||
}
|
||||
| {
|
||||
agentBase: GetAgentSessionResponse | undefined | null
|
||||
|
||||
@ -13,7 +13,7 @@ import { AgentLabel, SettingsContainer, SettingsItem, SettingsTitle } from './sh
|
||||
|
||||
interface AgentEssentialSettingsProps {
|
||||
agent: GetAgentResponse | undefined | null
|
||||
update: ReturnType<typeof useUpdateAgent>
|
||||
update: ReturnType<typeof useUpdateAgent>['updateAgent']
|
||||
}
|
||||
|
||||
const AgentEssentialSettings: FC<AgentEssentialSettingsProps> = ({ agent, update }) => {
|
||||
|
||||
@ -28,7 +28,7 @@ const AgentSettingPopupContainer: React.FC<AgentSettingPopupParams> = ({ tab, ag
|
||||
const [menu, setMenu] = useState<AgentSettingPopupTab>(tab || 'essential')
|
||||
|
||||
const { agent, isLoading, error } = useAgent(agentId)
|
||||
const updateAgent = useUpdateAgent()
|
||||
const { updateAgent } = useUpdateAgent()
|
||||
|
||||
const onOk = () => {
|
||||
setOpen(false)
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
import { Select, SelectedItems, SelectItem } from '@heroui/react'
|
||||
import { ApiModelLabel } from '@renderer/components/ApiModelLabel'
|
||||
import { useApiModels } from '@renderer/hooks/agents/useModels'
|
||||
import { AgentBaseWithId, ApiModel, UpdateAgentBaseForm, UpdateAgentForm } from '@renderer/types'
|
||||
import { useCallback } from 'react'
|
||||
import SelectAgentModelButton from '@renderer/pages/home/components/SelectAgentModelButton'
|
||||
import { AgentBaseWithId, ApiModel, UpdateAgentBaseForm } from '@renderer/types'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { SettingsItem, SettingsTitle } from './shared'
|
||||
@ -15,43 +12,18 @@ export interface ModelSettingProps {
|
||||
|
||||
export const ModelSetting: React.FC<ModelSettingProps> = ({ base, update, isDisabled }) => {
|
||||
const { t } = useTranslation()
|
||||
const { models } = useApiModels({ providerType: 'anthropic' })
|
||||
|
||||
const updateModel = (model: UpdateAgentForm['model']) => {
|
||||
const updateModel = async (model: ApiModel) => {
|
||||
if (!base) return
|
||||
update({ id: base.id, model })
|
||||
return update({ id: base.id, model: model.id })
|
||||
}
|
||||
|
||||
const renderModels = useCallback((items: SelectedItems<ApiModel>) => {
|
||||
return items.map((item) => {
|
||||
const model = item.data ?? undefined
|
||||
return <ApiModelLabel key={model?.id} model={model} />
|
||||
})
|
||||
}, [])
|
||||
|
||||
if (!base) return null
|
||||
|
||||
return (
|
||||
<SettingsItem inline className="gap-8">
|
||||
<SettingsTitle id="model">{t('common.model')}</SettingsTitle>
|
||||
<Select
|
||||
isDisabled={isDisabled}
|
||||
selectionMode="single"
|
||||
aria-labelledby="model"
|
||||
items={models}
|
||||
selectedKeys={[base.model]}
|
||||
onSelectionChange={(keys) => {
|
||||
updateModel(keys.currentKey)
|
||||
}}
|
||||
className="max-w-80 flex-1"
|
||||
placeholder={t('common.placeholders.select.model')}
|
||||
renderValue={renderModels}>
|
||||
{(model) => (
|
||||
<SelectItem textValue={model.id}>
|
||||
<ApiModelLabel model={model} />
|
||||
</SelectItem>
|
||||
)}
|
||||
</Select>
|
||||
<SelectAgentModelButton agent={base} onSelect={updateModel} isDisabled={isDisabled} />
|
||||
</SettingsItem>
|
||||
)
|
||||
}
|
||||
|
||||
@ -18,7 +18,7 @@ import { SettingsContainer, SettingsItem, SettingsTitle } from './shared'
|
||||
type AgentPromptSettingsProps =
|
||||
| {
|
||||
agentBase: AgentEntity | undefined | null
|
||||
update: ReturnType<typeof useUpdateAgent>
|
||||
update: ReturnType<typeof useUpdateAgent>['updateAgent']
|
||||
}
|
||||
| {
|
||||
agentBase: AgentSessionEntity | undefined | null
|
||||
|
||||
@ -12,7 +12,7 @@ import { SettingsContainer } from './shared'
|
||||
|
||||
interface SessionEssentialSettingsProps {
|
||||
session: GetAgentSessionResponse | undefined | null
|
||||
update: ReturnType<typeof useUpdateAgent>
|
||||
update: ReturnType<typeof useUpdateAgent>['updateAgent']
|
||||
}
|
||||
|
||||
const SessionEssentialSettings: FC<SessionEssentialSettingsProps> = ({ session, update }) => {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user