diff --git a/src/renderer/src/components/ModelList/ModelEditContent.tsx b/src/renderer/src/components/ModelList/ModelEditContent.tsx index b1085908c7..b6195afb40 100644 --- a/src/renderer/src/components/ModelList/ModelEditContent.tsx +++ b/src/renderer/src/components/ModelList/ModelEditContent.tsx @@ -10,10 +10,11 @@ import { } from '@renderer/config/models' import { useDynamicLabelWidth } from '@renderer/hooks/useDynamicLabelWidth' import { Model, ModelCapability, ModelType, Provider } from '@renderer/types' -import { getDefaultGroupName, getDifference, getUnion } from '@renderer/utils' +import { getDefaultGroupName, getDifference, getUnion, uniqueObjectArray } from '@renderer/utils' import { Button, Checkbox, Divider, Flex, Form, Input, InputNumber, message, Modal, Select, Switch } from 'antd' +import { cloneDeep } from 'lodash' import { ChevronDown, ChevronUp } from 'lucide-react' -import { FC, useState } from 'react' +import { FC, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import styled from 'styled-components' @@ -33,7 +34,9 @@ const ModelEditContent: FC = ({ provider, model, onUpdate const [currencySymbol, setCurrencySymbol] = useState(model.pricing?.currencySymbol || '$') const [isCustomCurrency, setIsCustomCurrency] = useState(!symbols.includes(model.pricing?.currencySymbol || '$')) const [modelCapabilities, setModelCapabilities] = useState(model.capabilities || []) + const originalModelCapabilities = cloneDeep(model.capabilities || []) const [supportedTextDelta, setSupportedTextDelta] = useState(model.supported_text_delta) + const [hasUserModified, setHasUserModified] = useState(false) const labelWidth = useDynamicLabelWidth([t('settings.models.add.endpoint_type.label')]) @@ -69,6 +72,41 @@ const ModelEditContent: FC = ({ provider, model, onUpdate { label: t('models.price.custom'), value: 'custom' } ] + const defaultTypes = [ + ...(isVisionModel(model) ? ['vision'] : []), + ...(isReasoningModel(model) ? ['reasoning'] : []), + ...(isFunctionCallingModel(model) ? ['function_calling'] : []), + ...(isWebSearchModel(model) ? ['web_search'] : []), + ...(isEmbeddingModel(model) ? ['embedding'] : []), + ...(isRerankModel(model) ? ['rerank'] : []) + ] + + const selectedTypes: string[] = getUnion( + modelCapabilities?.filter((t) => t.isUserSelected).map((t) => t.type) || [], + getDifference(defaultTypes, modelCapabilities?.filter((t) => t.isUserSelected === false).map((t) => t.type) || []) + ) + + // 被rerank/embedding改变的类型 + const changedTypesRef = useRef([]) + + useEffect(() => { + if (showMoreSettings) { + const newModelCapabilities = getUnion( + selectedTypes.map((type) => { + const existingCapability = modelCapabilities?.find((m) => m.type === type) + return { + type: type as ModelType, + isUserSelected: existingCapability?.isUserSelected ?? undefined + } + }), + modelCapabilities?.filter((t) => t.isUserSelected === false), + (item) => item.type + ) + setModelCapabilities(newModelCapabilities) + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [showMoreSettings]) + return ( = ({ provider, model, onUpdate {t('models.type.select')}: {(() => { - const defaultTypes = [ - ...(isVisionModel(model) ? ['vision'] : []), - ...(isReasoningModel(model) ? ['reasoning'] : []), - ...(isFunctionCallingModel(model) ? ['function_calling'] : []), - ...(isWebSearchModel(model) ? ['web_search'] : []), - ...(isEmbeddingModel(model) ? ['embedding'] : []), - ...(isRerankModel(model) ? ['rerank'] : []) - ] - - // 合并现有选择和默认类型用于前端展示 - const selectedTypes = getUnion( - modelCapabilities?.filter((t) => t.isUserSelected).map((t) => t.type) || [], - getDifference( - defaultTypes, - modelCapabilities?.filter((t) => t.isUserSelected === false).map((t) => t.type) || [] - ) - ) - const isDisabled = selectedTypes.includes('rerank') || selectedTypes.includes('embedding') const isRerankDisabled = selectedTypes.includes('embedding') const isEmbeddingDisabled = selectedTypes.includes('rerank') - const showTypeConfirmModal = (newCapability: ModelCapability) => { const onUpdateType = selectedTypes?.find((t) => t === newCapability.type) window.modal.confirm({ @@ -216,30 +235,38 @@ const ModelEditContent: FC = ({ provider, model, onUpdate cancelButtonProps: { type: 'primary' }, onOk: () => { if (onUpdateType) { - const updatedTypes = selectedTypes?.map((t) => { - if (t === newCapability.type) { - return { type: t, isUserSelected: true } + const updatedModelCapabilities = modelCapabilities?.map((t) => { + if (t.type === newCapability.type) { + return { ...t, isUserSelected: true } } if ( - (onUpdateType !== t && onUpdateType === 'rerank') || - (onUpdateType === 'embedding' && onUpdateType !== t) + ((onUpdateType !== t.type && onUpdateType === 'rerank') || + (onUpdateType === 'embedding' && onUpdateType !== t.type)) && + t.isUserSelected !== false ) { - return { type: t, isUserSelected: false } + changedTypesRef.current.push(t.type) + return { ...t, isUserSelected: false } } - return { type: t } + return t }) - setModelCapabilities(updatedTypes as ModelCapability[]) + setModelCapabilities(uniqueObjectArray(updatedModelCapabilities as ModelCapability[])) } else { - const updatedTypes = selectedTypes?.map((t) => { + const updatedModelCapabilities = modelCapabilities?.map((t) => { if ( - (newCapability.type !== t && newCapability.type === 'rerank') || - (newCapability.type === 'embedding' && newCapability.type !== t) + ((newCapability.type !== t.type && newCapability.type === 'rerank') || + (newCapability.type === 'embedding' && newCapability.type !== t.type)) && + t.isUserSelected !== false ) { - return { type: t, isUserSelected: false } + changedTypesRef.current.push(t.type) + return { ...t, isUserSelected: false } } - return { type: t } + if (newCapability.type === t.type) { + return { ...t, isUserSelected: true } + } + return t }) - setModelCapabilities([...(updatedTypes as ModelCapability[]), newCapability]) + updatedModelCapabilities.push(newCapability as any) + setModelCapabilities(uniqueObjectArray(updatedModelCapabilities as ModelCapability[])) } }, onCancel: () => {}, @@ -248,6 +275,7 @@ const ModelEditContent: FC = ({ provider, model, onUpdate } const handleTypeChange = (types: string[]) => { + setHasUserModified(true) // 标记用户已进行修改 const diff = types.length > selectedTypes.length if (diff) { const newCapability = getDifference(types, selectedTypes) // checkbox的特性,确保了newCapability只有一个元素 @@ -264,16 +292,19 @@ const ModelEditContent: FC = ({ provider, model, onUpdate return { ...t, isUserSelected: false } } if ( - (onUpdateType !== t && onUpdateType.type === 'rerank') || - (onUpdateType.type === 'embedding' && onUpdateType !== t && t.isUserSelected === false) + ((onUpdateType !== t && onUpdateType.type === 'rerank') || + (onUpdateType.type === 'embedding' && onUpdateType !== t)) && + t.isUserSelected === false ) { - return { ...t, isUserSelected: true } + if (changedTypesRef.current.includes(t.type)) { + return { ...t, isUserSelected: true } + } } return t }) - setModelCapabilities(updatedTypes || []) + setModelCapabilities(uniqueObjectArray(updatedTypes as ModelCapability[])) } else { - const updatedTypes = modelCapabilities?.map((t) => { + const updatedModelCapabilities = modelCapabilities?.map((t) => { if ( (disabledTypes[0] === 'rerank' && t.type !== 'rerank') || (disabledTypes[0] === 'embedding' && t.type !== 'embedding' && t.isUserSelected === false) @@ -282,16 +313,16 @@ const ModelEditContent: FC = ({ provider, model, onUpdate } return t }) - setModelCapabilities([ - ...(updatedTypes ?? []), - { type: disabledTypes[0] as ModelType, isUserSelected: false } - ]) + updatedModelCapabilities.push({ type: disabledTypes[0] as ModelType, isUserSelected: false }) + setModelCapabilities(uniqueObjectArray(updatedModelCapabilities as ModelCapability[])) } + changedTypesRef.current.length = 0 } } const handleResetTypes = () => { - setModelCapabilities([]) + setModelCapabilities(originalModelCapabilities) + setHasUserModified(false) // 重置后清除修改标志 } return ( @@ -333,9 +364,11 @@ const ModelEditContent: FC = ({ provider, model, onUpdate } ]} /> - + {hasUserModified && ( + + )} ) diff --git a/src/renderer/src/utils/index.ts b/src/renderer/src/utils/index.ts index 7d2f0bfb0e..85275ad15d 100644 --- a/src/renderer/src/utils/index.ts +++ b/src/renderer/src/utils/index.ts @@ -1,6 +1,7 @@ import { loggerService } from '@logger' import { Language, Model, ModelType, Provider } from '@renderer/types' import { ModalFuncProps } from 'antd' +import { isEqual } from 'lodash' import { v4 as uuidv4 } from 'uuid' const logger = loggerService.withContext('Utils') @@ -248,6 +249,10 @@ export function mapLanguageToQwenMTModel(language: Language): string { return language.value } +export function uniqueObjectArray(array: T[]): T[] { + return array.filter((obj, index, self) => index === self.findIndex((t) => isEqual(t, obj))) +} + export * from './api' export * from './collection' export * from './file'