mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-26 03:31:24 +08:00
refactor(ModelEditContent): enhance model capabilities management and… (#8562)
* refactor(ModelEditContent): enhance model capabilities management and introduce uniqueObjectArray utility - Updated ModelEditContent to improve handling of model capabilities, ensuring user selections are accurately reflected. - Introduced a new utility function, uniqueObjectArray, to filter out duplicate objects in arrays, enhancing data integrity. - Refactored logic for updating model capabilities to utilize the new utility, streamlining the process and improving code clarity. * refactor(ModelEditContent): enhance model capabilities management with useEffect and improved type handling - Added useEffect to manage model capabilities based on user selections and showMoreSettings state. - Refactored logic to streamline the handling of default and selected model types, improving clarity and maintainability. - Utilized useRef to track changed types, ensuring accurate updates to model capabilities during user interactions. * refactor(ModelEditContent): optimize model capabilities update logic with getUnion utility - Enhanced the model capabilities management by integrating the getUnion utility to streamline the merging of selected types and unselected capabilities. - Improved clarity in the useEffect hook for managing model capabilities based on user selections and the showMoreSettings state. - Refactored condition checks for updating user selections to ensure accurate handling of model capabilities during interactions. * refactor(ModelEditContent): improve model capabilities reset logic and enhance debugging - Introduced a cloneDeep utility to preserve original model capabilities for reset functionality. - Updated the handleResetTypes function to restore original capabilities instead of clearing them. - Added console logs for better debugging and tracking of model capabilities during updates. * feat(ModelEditContent): track user modifications for model capabilities - Added a state variable to track if the user has modified model capabilities. - Updated the handleTypeChange function to set the modification flag when types are changed. - Modified the reset button to only display when there are user modifications, enhancing the user interface and interaction clarity.
This commit is contained in:
parent
f599bc80a1
commit
2b750b6d29
@ -10,10 +10,11 @@ import {
|
||||
} from '@renderer/config/models'
|
||||
import { useDynamicLabelWidth } from '@renderer/hooks/useDynamicLabelWidth'
|
||||
import { Model, ModelCapability, ModelType, Provider } from '@renderer/types'
|
||||
import { getDefaultGroupName, getDifference, getUnion } from '@renderer/utils'
|
||||
import { getDefaultGroupName, getDifference, getUnion, uniqueObjectArray } from '@renderer/utils'
|
||||
import { Button, Checkbox, Divider, Flex, Form, Input, InputNumber, message, Modal, Select, Switch } from 'antd'
|
||||
import { cloneDeep } from 'lodash'
|
||||
import { ChevronDown, ChevronUp } from 'lucide-react'
|
||||
import { FC, useState } from 'react'
|
||||
import { FC, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
@ -33,7 +34,9 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
const [currencySymbol, setCurrencySymbol] = useState(model.pricing?.currencySymbol || '$')
|
||||
const [isCustomCurrency, setIsCustomCurrency] = useState(!symbols.includes(model.pricing?.currencySymbol || '$'))
|
||||
const [modelCapabilities, setModelCapabilities] = useState(model.capabilities || [])
|
||||
const originalModelCapabilities = cloneDeep(model.capabilities || [])
|
||||
const [supportedTextDelta, setSupportedTextDelta] = useState(model.supported_text_delta)
|
||||
const [hasUserModified, setHasUserModified] = useState(false)
|
||||
|
||||
const labelWidth = useDynamicLabelWidth([t('settings.models.add.endpoint_type.label')])
|
||||
|
||||
@ -69,6 +72,41 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
{ label: t('models.price.custom'), value: 'custom' }
|
||||
]
|
||||
|
||||
const defaultTypes = [
|
||||
...(isVisionModel(model) ? ['vision'] : []),
|
||||
...(isReasoningModel(model) ? ['reasoning'] : []),
|
||||
...(isFunctionCallingModel(model) ? ['function_calling'] : []),
|
||||
...(isWebSearchModel(model) ? ['web_search'] : []),
|
||||
...(isEmbeddingModel(model) ? ['embedding'] : []),
|
||||
...(isRerankModel(model) ? ['rerank'] : [])
|
||||
]
|
||||
|
||||
const selectedTypes: string[] = getUnion(
|
||||
modelCapabilities?.filter((t) => t.isUserSelected).map((t) => t.type) || [],
|
||||
getDifference(defaultTypes, modelCapabilities?.filter((t) => t.isUserSelected === false).map((t) => t.type) || [])
|
||||
)
|
||||
|
||||
// 被rerank/embedding改变的类型
|
||||
const changedTypesRef = useRef<string[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
if (showMoreSettings) {
|
||||
const newModelCapabilities = getUnion(
|
||||
selectedTypes.map((type) => {
|
||||
const existingCapability = modelCapabilities?.find((m) => m.type === type)
|
||||
return {
|
||||
type: type as ModelType,
|
||||
isUserSelected: existingCapability?.isUserSelected ?? undefined
|
||||
}
|
||||
}),
|
||||
modelCapabilities?.filter((t) => t.isUserSelected === false),
|
||||
(item) => item.type
|
||||
)
|
||||
setModelCapabilities(newModelCapabilities)
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [showMoreSettings])
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('models.edit')}
|
||||
@ -182,29 +220,10 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
<Divider style={{ margin: '16px 0 16px 0' }} />
|
||||
<TypeTitle>{t('models.type.select')}:</TypeTitle>
|
||||
{(() => {
|
||||
const defaultTypes = [
|
||||
...(isVisionModel(model) ? ['vision'] : []),
|
||||
...(isReasoningModel(model) ? ['reasoning'] : []),
|
||||
...(isFunctionCallingModel(model) ? ['function_calling'] : []),
|
||||
...(isWebSearchModel(model) ? ['web_search'] : []),
|
||||
...(isEmbeddingModel(model) ? ['embedding'] : []),
|
||||
...(isRerankModel(model) ? ['rerank'] : [])
|
||||
]
|
||||
|
||||
// 合并现有选择和默认类型用于前端展示
|
||||
const selectedTypes = getUnion(
|
||||
modelCapabilities?.filter((t) => t.isUserSelected).map((t) => t.type) || [],
|
||||
getDifference(
|
||||
defaultTypes,
|
||||
modelCapabilities?.filter((t) => t.isUserSelected === false).map((t) => t.type) || []
|
||||
)
|
||||
)
|
||||
|
||||
const isDisabled = selectedTypes.includes('rerank') || selectedTypes.includes('embedding')
|
||||
|
||||
const isRerankDisabled = selectedTypes.includes('embedding')
|
||||
const isEmbeddingDisabled = selectedTypes.includes('rerank')
|
||||
|
||||
const showTypeConfirmModal = (newCapability: ModelCapability) => {
|
||||
const onUpdateType = selectedTypes?.find((t) => t === newCapability.type)
|
||||
window.modal.confirm({
|
||||
@ -216,30 +235,38 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
cancelButtonProps: { type: 'primary' },
|
||||
onOk: () => {
|
||||
if (onUpdateType) {
|
||||
const updatedTypes = selectedTypes?.map((t) => {
|
||||
if (t === newCapability.type) {
|
||||
return { type: t, isUserSelected: true }
|
||||
const updatedModelCapabilities = modelCapabilities?.map((t) => {
|
||||
if (t.type === newCapability.type) {
|
||||
return { ...t, isUserSelected: true }
|
||||
}
|
||||
if (
|
||||
(onUpdateType !== t && onUpdateType === 'rerank') ||
|
||||
(onUpdateType === 'embedding' && onUpdateType !== t)
|
||||
((onUpdateType !== t.type && onUpdateType === 'rerank') ||
|
||||
(onUpdateType === 'embedding' && onUpdateType !== t.type)) &&
|
||||
t.isUserSelected !== false
|
||||
) {
|
||||
return { type: t, isUserSelected: false }
|
||||
changedTypesRef.current.push(t.type)
|
||||
return { ...t, isUserSelected: false }
|
||||
}
|
||||
return { type: t }
|
||||
return t
|
||||
})
|
||||
setModelCapabilities(updatedTypes as ModelCapability[])
|
||||
setModelCapabilities(uniqueObjectArray(updatedModelCapabilities as ModelCapability[]))
|
||||
} else {
|
||||
const updatedTypes = selectedTypes?.map((t) => {
|
||||
const updatedModelCapabilities = modelCapabilities?.map((t) => {
|
||||
if (
|
||||
(newCapability.type !== t && newCapability.type === 'rerank') ||
|
||||
(newCapability.type === 'embedding' && newCapability.type !== t)
|
||||
((newCapability.type !== t.type && newCapability.type === 'rerank') ||
|
||||
(newCapability.type === 'embedding' && newCapability.type !== t.type)) &&
|
||||
t.isUserSelected !== false
|
||||
) {
|
||||
return { type: t, isUserSelected: false }
|
||||
changedTypesRef.current.push(t.type)
|
||||
return { ...t, isUserSelected: false }
|
||||
}
|
||||
return { type: t }
|
||||
if (newCapability.type === t.type) {
|
||||
return { ...t, isUserSelected: true }
|
||||
}
|
||||
return t
|
||||
})
|
||||
setModelCapabilities([...(updatedTypes as ModelCapability[]), newCapability])
|
||||
updatedModelCapabilities.push(newCapability as any)
|
||||
setModelCapabilities(uniqueObjectArray(updatedModelCapabilities as ModelCapability[]))
|
||||
}
|
||||
},
|
||||
onCancel: () => {},
|
||||
@ -248,6 +275,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
}
|
||||
|
||||
const handleTypeChange = (types: string[]) => {
|
||||
setHasUserModified(true) // 标记用户已进行修改
|
||||
const diff = types.length > selectedTypes.length
|
||||
if (diff) {
|
||||
const newCapability = getDifference(types, selectedTypes) // checkbox的特性,确保了newCapability只有一个元素
|
||||
@ -264,16 +292,19 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
return { ...t, isUserSelected: false }
|
||||
}
|
||||
if (
|
||||
(onUpdateType !== t && onUpdateType.type === 'rerank') ||
|
||||
(onUpdateType.type === 'embedding' && onUpdateType !== t && t.isUserSelected === false)
|
||||
((onUpdateType !== t && onUpdateType.type === 'rerank') ||
|
||||
(onUpdateType.type === 'embedding' && onUpdateType !== t)) &&
|
||||
t.isUserSelected === false
|
||||
) {
|
||||
return { ...t, isUserSelected: true }
|
||||
if (changedTypesRef.current.includes(t.type)) {
|
||||
return { ...t, isUserSelected: true }
|
||||
}
|
||||
}
|
||||
return t
|
||||
})
|
||||
setModelCapabilities(updatedTypes || [])
|
||||
setModelCapabilities(uniqueObjectArray(updatedTypes as ModelCapability[]))
|
||||
} else {
|
||||
const updatedTypes = modelCapabilities?.map((t) => {
|
||||
const updatedModelCapabilities = modelCapabilities?.map((t) => {
|
||||
if (
|
||||
(disabledTypes[0] === 'rerank' && t.type !== 'rerank') ||
|
||||
(disabledTypes[0] === 'embedding' && t.type !== 'embedding' && t.isUserSelected === false)
|
||||
@ -282,16 +313,16 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
}
|
||||
return t
|
||||
})
|
||||
setModelCapabilities([
|
||||
...(updatedTypes ?? []),
|
||||
{ type: disabledTypes[0] as ModelType, isUserSelected: false }
|
||||
])
|
||||
updatedModelCapabilities.push({ type: disabledTypes[0] as ModelType, isUserSelected: false })
|
||||
setModelCapabilities(uniqueObjectArray(updatedModelCapabilities as ModelCapability[]))
|
||||
}
|
||||
changedTypesRef.current.length = 0
|
||||
}
|
||||
}
|
||||
|
||||
const handleResetTypes = () => {
|
||||
setModelCapabilities([])
|
||||
setModelCapabilities(originalModelCapabilities)
|
||||
setHasUserModified(false) // 重置后清除修改标志
|
||||
}
|
||||
|
||||
return (
|
||||
@ -333,9 +364,11 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
}
|
||||
]}
|
||||
/>
|
||||
<Button size="small" onClick={handleResetTypes}>
|
||||
{t('common.reset')}
|
||||
</Button>
|
||||
{hasUserModified && (
|
||||
<Button size="small" onClick={handleResetTypes}>
|
||||
{t('common.reset')}
|
||||
</Button>
|
||||
)}
|
||||
</Flex>
|
||||
</div>
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { Language, Model, ModelType, Provider } from '@renderer/types'
|
||||
import { ModalFuncProps } from 'antd'
|
||||
import { isEqual } from 'lodash'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
const logger = loggerService.withContext('Utils')
|
||||
@ -248,6 +249,10 @@ export function mapLanguageToQwenMTModel(language: Language): string {
|
||||
return language.value
|
||||
}
|
||||
|
||||
export function uniqueObjectArray<T>(array: T[]): T[] {
|
||||
return array.filter((obj, index, self) => index === self.findIndex((t) => isEqual(t, obj)))
|
||||
}
|
||||
|
||||
export * from './api'
|
||||
export * from './collection'
|
||||
export * from './file'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user