mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-24 10:33:05 +08:00
fix(ModelEdit): enhance model type management and introduce new selection logic (#8420)
* fix(ModelEdit): enhance model type management and introduce new selection logic - Added support for 'rerank' model type in the ModelEditContent component. - Refactored type selection logic to utilize new utility functions for finding differences and unions in model types. - Updated model type handling to include user selection status, improving user experience in type management. - Adjusted migration logic to initialize newType for existing models, ensuring backward compatibility. - Introduced isUserSelectedModelType utility to streamline model type checks across the application. * refactor(isFunctionCallingModel): simplify model type check logic - Replaced the inline check for 'function_calling' model type with a call to the new utility function isUserSelectedModelType, enhancing code clarity and maintainability. * feat(collection): add utility functions for array operations - Introduced `findIntersection`, `findDifference`, and `findUnion` functions to handle array operations with support for custom key selectors and comparison functions. - Removed previous implementations from `index.ts` to streamline utility exports. - Added comprehensive tests for new functions covering basic types and object types with various edge cases. * refactor(collection): rename utility functions for clarity - Renamed `findIntersection`, `findDifference`, and `findUnion` to `getIntersection`, `getDifference`, and `getUnion` respectively for improved clarity and consistency in naming. - Updated corresponding tests to reflect the new function names, ensuring all tests pass with the updated utility functions. * refactor(ModelEditContent): update model type management and improve selection logic - Replaced utility function calls to `findDifference` and `findUnion` with `getDifference` and `getUnion` for consistency. - Introduced temporary state management for model types to enhance user selection handling. - Added a reset functionality for model type selections, improving user experience. - Updated the rendering logic to conditionally disable certain model types based on user selections. * fix(ModelEditContent): enhance model type selection logic with conditional disabling - Introduced logic to conditionally disable 'rerank' and 'embedding' model types based on user selections. - Updated the state management for model types to ensure correct user selection handling. - Improved the confirmation modal to reflect the updated selection logic for better user experience. * fix(ModelEditContent): refine model type selection and update confirmation logic - Enhanced the logic for model type selection to ensure accurate user selections for 'rerank' and 'embedding'. - Updated the confirmation modal to reflect changes in selection handling, improving user experience. - Adjusted state management to correctly handle updates based on selected model types. * fix(models): update model support logic to include 'qwen3-235b-a22b-instruct' * refactor(models): rename 'newType' to 'capabilities' and update related logic in ModelEditContent and migration scripts
This commit is contained in:
parent
0453402242
commit
6c44f7fe24
@ -4,12 +4,13 @@ import {
|
||||
isEmbeddingModel,
|
||||
isFunctionCallingModel,
|
||||
isReasoningModel,
|
||||
isRerankModel,
|
||||
isVisionModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { useDynamicLabelWidth } from '@renderer/hooks/useDynamicLabelWidth'
|
||||
import { Model, ModelType, Provider } from '@renderer/types'
|
||||
import { getDefaultGroupName } from '@renderer/utils'
|
||||
import { Model, ModelCapability, ModelType, Provider } from '@renderer/types'
|
||||
import { getDefaultGroupName, getDifference, getUnion } from '@renderer/utils'
|
||||
import { Button, Checkbox, Divider, Flex, Form, Input, InputNumber, message, Modal, Select } from 'antd'
|
||||
import { ChevronDown, ChevronUp } from 'lucide-react'
|
||||
import { FC, useState } from 'react'
|
||||
@ -31,6 +32,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
const [showMoreSettings, setShowMoreSettings] = useState(false)
|
||||
const [currencySymbol, setCurrencySymbol] = useState(model.pricing?.currencySymbol || '$')
|
||||
const [isCustomCurrency, setIsCustomCurrency] = useState(!symbols.includes(model.pricing?.currencySymbol || '$'))
|
||||
const [tempModelTypes, setTempModelTypes] = useState(model.capabilities || [])
|
||||
|
||||
const labelWidth = useDynamicLabelWidth([t('settings.models.add.endpoint_type')])
|
||||
|
||||
@ -42,6 +44,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
name: values.name || model.name,
|
||||
group: values.group || model.group,
|
||||
endpoint_type: provider.id === 'new-api' ? values.endpointType : model.endpoint_type,
|
||||
newCapability: tempModelTypes,
|
||||
pricing: {
|
||||
input_per_million_tokens: Number(values.input_per_million_tokens) || 0,
|
||||
output_per_million_tokens: Number(values.output_per_million_tokens) || 0,
|
||||
@ -55,6 +58,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
|
||||
const handleClose = () => {
|
||||
setShowMoreSettings(false)
|
||||
setTempModelTypes(model.capabilities || [])
|
||||
onClose()
|
||||
}
|
||||
|
||||
@ -179,16 +183,29 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
{(() => {
|
||||
const defaultTypes = [
|
||||
...(isVisionModel(model) ? ['vision'] : []),
|
||||
...(isEmbeddingModel(model) ? ['embedding'] : []),
|
||||
...(isReasoningModel(model) ? ['reasoning'] : []),
|
||||
...(isFunctionCallingModel(model) ? ['function_calling'] : []),
|
||||
...(isWebSearchModel(model) ? ['web_search'] : [])
|
||||
] as ModelType[]
|
||||
...(isWebSearchModel(model) ? ['web_search'] : []),
|
||||
...(isEmbeddingModel(model) ? ['embedding'] : []),
|
||||
...(isRerankModel(model) ? ['rerank'] : [])
|
||||
]
|
||||
|
||||
// 合并现有选择和默认类型
|
||||
const selectedTypes = [...new Set([...(model.type || []), ...defaultTypes])]
|
||||
// 合并现有选择和默认类型用于前端展示
|
||||
const selectedTypes = getUnion(
|
||||
tempModelTypes?.filter((t) => t.isUserSelected).map((t) => t.type) || [],
|
||||
getDifference(
|
||||
defaultTypes,
|
||||
tempModelTypes?.filter((t) => t.isUserSelected === false).map((t) => t.type) || []
|
||||
)
|
||||
)
|
||||
|
||||
const showTypeConfirmModal = (type: string) => {
|
||||
const isDisabled = selectedTypes.includes('rerank') || selectedTypes.includes('embedding')
|
||||
|
||||
const isRerankDisabled = selectedTypes.includes('embedding')
|
||||
const isEmbeddingDisabled = selectedTypes.includes('rerank')
|
||||
|
||||
const showTypeConfirmModal = (newCapability: ModelCapability) => {
|
||||
const onUpdateType = selectedTypes?.find((t) => t === newCapability.type)
|
||||
window.modal.confirm({
|
||||
title: t('settings.moresetting.warn'),
|
||||
content: t('settings.moresetting.check.warn'),
|
||||
@ -196,54 +213,130 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
|
||||
cancelText: t('common.cancel'),
|
||||
okButtonProps: { danger: true },
|
||||
cancelButtonProps: { type: 'primary' },
|
||||
onOk: () => onUpdateModel({ ...model, type: [...selectedTypes, type] as ModelType[] }),
|
||||
onOk: () => {
|
||||
if (onUpdateType) {
|
||||
const updatedTypes = selectedTypes?.map((t) => {
|
||||
if (t === newCapability.type) {
|
||||
return { type: t, isUserSelected: true }
|
||||
}
|
||||
if (
|
||||
(onUpdateType !== t && onUpdateType === 'rerank') ||
|
||||
(onUpdateType === 'embedding' && onUpdateType !== t)
|
||||
) {
|
||||
return { type: t, isUserSelected: false }
|
||||
}
|
||||
return { type: t }
|
||||
})
|
||||
setTempModelTypes(updatedTypes as ModelCapability[])
|
||||
} else {
|
||||
const updatedTypes = selectedTypes?.map((t) => {
|
||||
if (
|
||||
(newCapability.type !== t && newCapability.type === 'rerank') ||
|
||||
(newCapability.type === 'embedding' && newCapability.type !== t)
|
||||
) {
|
||||
return { type: t, isUserSelected: false }
|
||||
}
|
||||
return { type: t }
|
||||
})
|
||||
setTempModelTypes([...(updatedTypes as ModelCapability[]), newCapability])
|
||||
}
|
||||
},
|
||||
onCancel: () => {},
|
||||
centered: true
|
||||
})
|
||||
}
|
||||
|
||||
const handleTypeChange = (types: string[]) => {
|
||||
const newType = types.find((type) => !selectedTypes.includes(type as ModelType))
|
||||
|
||||
if (newType) {
|
||||
showTypeConfirmModal(newType)
|
||||
const diff = types.length > selectedTypes.length
|
||||
if (diff) {
|
||||
const newCapability = getDifference(types, selectedTypes) // checkbox的特性,确保了newCapability只有一个元素
|
||||
showTypeConfirmModal({
|
||||
type: newCapability[0] as ModelType,
|
||||
isUserSelected: true
|
||||
})
|
||||
} else {
|
||||
onUpdateModel({ ...model, type: types as ModelType[] })
|
||||
const disabledTypes = getDifference(selectedTypes, types)
|
||||
const onUpdateType = tempModelTypes?.find((t) => t.type === disabledTypes[0])
|
||||
if (onUpdateType) {
|
||||
const updatedTypes = tempModelTypes?.map((t) => {
|
||||
if (t.type === disabledTypes[0]) {
|
||||
return { ...t, isUserSelected: false }
|
||||
}
|
||||
if (
|
||||
(onUpdateType !== t && onUpdateType.type === 'rerank') ||
|
||||
(onUpdateType.type === 'embedding' && onUpdateType !== t && t.isUserSelected === false)
|
||||
) {
|
||||
return { ...t, isUserSelected: true }
|
||||
}
|
||||
return t
|
||||
})
|
||||
setTempModelTypes(updatedTypes || [])
|
||||
} else {
|
||||
const updatedTypes = tempModelTypes?.map((t) => {
|
||||
if (
|
||||
(disabledTypes[0] === 'rerank' && t.type !== 'rerank') ||
|
||||
(disabledTypes[0] === 'embedding' && t.type !== 'embedding' && t.isUserSelected === false)
|
||||
) {
|
||||
return { ...t, isUserSelected: true }
|
||||
}
|
||||
return t
|
||||
})
|
||||
setTempModelTypes([
|
||||
...(updatedTypes ?? []),
|
||||
{ type: disabledTypes[0] as ModelType, isUserSelected: false }
|
||||
])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const handleResetTypes = () => {
|
||||
setTempModelTypes([])
|
||||
}
|
||||
|
||||
return (
|
||||
<Checkbox.Group
|
||||
value={selectedTypes}
|
||||
onChange={handleTypeChange}
|
||||
options={[
|
||||
{
|
||||
label: t('models.type.vision'),
|
||||
value: 'vision',
|
||||
disabled: isVisionModel(model) && !selectedTypes.includes('vision')
|
||||
},
|
||||
{
|
||||
label: t('models.type.websearch'),
|
||||
value: 'web_search',
|
||||
disabled: isWebSearchModel(model) && !selectedTypes.includes('web_search')
|
||||
},
|
||||
{
|
||||
label: t('models.type.embedding'),
|
||||
value: 'embedding',
|
||||
disabled: isEmbeddingModel(model) && !selectedTypes.includes('embedding')
|
||||
},
|
||||
{
|
||||
label: t('models.type.reasoning'),
|
||||
value: 'reasoning',
|
||||
disabled: isReasoningModel(model) && !selectedTypes.includes('reasoning')
|
||||
},
|
||||
{
|
||||
label: t('models.type.function_calling'),
|
||||
value: 'function_calling',
|
||||
disabled: isFunctionCallingModel(model) && !selectedTypes.includes('function_calling')
|
||||
}
|
||||
]}
|
||||
/>
|
||||
<div>
|
||||
<Flex justify="space-between" align="center" style={{ marginBottom: 8 }}>
|
||||
<Checkbox.Group
|
||||
value={selectedTypes}
|
||||
onChange={handleTypeChange}
|
||||
options={[
|
||||
{
|
||||
label: t('models.type.vision'),
|
||||
value: 'vision',
|
||||
disabled: isDisabled
|
||||
},
|
||||
{
|
||||
label: t('models.type.websearch'),
|
||||
value: 'web_search',
|
||||
disabled: isDisabled
|
||||
},
|
||||
{
|
||||
label: t('models.type.rerank'),
|
||||
value: 'rerank',
|
||||
disabled: isRerankDisabled
|
||||
},
|
||||
{
|
||||
label: t('models.type.embedding'),
|
||||
value: 'embedding',
|
||||
disabled: isEmbeddingDisabled
|
||||
},
|
||||
{
|
||||
label: t('models.type.reasoning'),
|
||||
value: 'reasoning',
|
||||
disabled: isDisabled
|
||||
},
|
||||
{
|
||||
label: t('models.type.function_calling'),
|
||||
value: 'function_calling',
|
||||
disabled: isDisabled
|
||||
}
|
||||
]}
|
||||
/>
|
||||
<Button size="small" onClick={handleResetTypes}>
|
||||
{t('common.reset')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</div>
|
||||
)
|
||||
})()}
|
||||
<TypeTitle>{t('models.price.price')}</TypeTitle>
|
||||
|
||||
@ -145,7 +145,7 @@ import YoudaoLogo from '@renderer/assets/images/providers/netease-youdao.svg'
|
||||
import NomicLogo from '@renderer/assets/images/providers/nomic.png'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { Model } from '@renderer/types'
|
||||
import { getLowerBaseModelName } from '@renderer/utils'
|
||||
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from './prompts'
|
||||
@ -265,16 +265,12 @@ export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
|
||||
)
|
||||
|
||||
export function isFunctionCallingModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
if (!model || isEmbeddingModel(model) || isRerankModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (model.type?.includes('function_calling')) {
|
||||
return true
|
||||
}
|
||||
|
||||
if (isEmbeddingModel(model)) {
|
||||
return false
|
||||
if (isUserSelectedModelType(model, 'function_calling') !== undefined) {
|
||||
return isUserSelectedModelType(model, 'function_calling')!
|
||||
}
|
||||
|
||||
if (model.provider === 'qiniu') {
|
||||
@ -2395,10 +2391,14 @@ export function isTextToImageModel(model: Model): boolean {
|
||||
}
|
||||
|
||||
export function isEmbeddingModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
if (!model || isRerankModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (isUserSelectedModelType(model, 'embedding') !== undefined) {
|
||||
return isUserSelectedModelType(model, 'embedding')!
|
||||
}
|
||||
|
||||
if (['anthropic'].includes(model?.provider)) {
|
||||
return false
|
||||
}
|
||||
@ -2407,31 +2407,33 @@ export function isEmbeddingModel(model: Model): boolean {
|
||||
return EMBEDDING_REGEX.test(model.name)
|
||||
}
|
||||
|
||||
if (isRerankModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false
|
||||
return EMBEDDING_REGEX.test(model.id) || false
|
||||
}
|
||||
|
||||
export function isRerankModel(model: Model): boolean {
|
||||
if (isUserSelectedModelType(model, 'rerank') !== undefined) {
|
||||
return isUserSelectedModelType(model, 'rerank')!
|
||||
}
|
||||
return model ? RERANKING_REGEX.test(model.id) || false : false
|
||||
}
|
||||
|
||||
export function isVisionModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
if (!model || isEmbeddingModel(model) || isRerankModel(model)) {
|
||||
return false
|
||||
}
|
||||
// 新添字段 copilot-vision-request 后可使用 vision
|
||||
// if (model.provider === 'copilot') {
|
||||
// return false
|
||||
// }
|
||||
|
||||
if (model.provider === 'doubao' || model.id.includes('doubao')) {
|
||||
return VISION_REGEX.test(model.name) || VISION_REGEX.test(model.id) || model.type?.includes('vision') || false
|
||||
if (isUserSelectedModelType(model, 'vision') !== undefined) {
|
||||
return isUserSelectedModelType(model, 'vision')!
|
||||
}
|
||||
|
||||
return VISION_REGEX.test(model.id) || model.type?.includes('vision') || false
|
||||
if (model.provider === 'doubao' || model.id.includes('doubao')) {
|
||||
return VISION_REGEX.test(model.name) || VISION_REGEX.test(model.id) || false
|
||||
}
|
||||
|
||||
return VISION_REGEX.test(model.id) || false
|
||||
}
|
||||
|
||||
export function isOpenAIReasoningModel(model: Model): boolean {
|
||||
@ -2599,7 +2601,7 @@ export function isSupportedThinkingTokenQwenModel(model?: Model): boolean {
|
||||
|
||||
const baseName = getLowerBaseModelName(model.id, '/')
|
||||
|
||||
if (baseName.includes('coder')) {
|
||||
if (baseName.includes('coder') || baseName.includes('qwen3-235b-a22b-instruct')) {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -2660,19 +2662,18 @@ export const isHunyuanReasoningModel = (model?: Model): boolean => {
|
||||
}
|
||||
|
||||
export function isReasoningModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
|
||||
return false
|
||||
if (isUserSelectedModelType(model, 'reasoning') !== undefined) {
|
||||
return isUserSelectedModelType(model, 'reasoning')!
|
||||
}
|
||||
|
||||
if (model.provider === 'doubao' || model.id.includes('doubao')) {
|
||||
return (
|
||||
REASONING_REGEX.test(model.id) ||
|
||||
REASONING_REGEX.test(model.name) ||
|
||||
model.type?.includes('reasoning') ||
|
||||
isSupportedThinkingTokenDoubaoModel(model) ||
|
||||
false
|
||||
)
|
||||
@ -2692,7 +2693,7 @@ export function isReasoningModel(model?: Model): boolean {
|
||||
return true
|
||||
}
|
||||
|
||||
return REASONING_REGEX.test(model.id) || model.type?.includes('reasoning') || false
|
||||
return REASONING_REGEX.test(model.id) || false
|
||||
}
|
||||
|
||||
export function isSupportedModel(model: OpenAI.Models.Model): boolean {
|
||||
@ -2716,14 +2717,12 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean {
|
||||
}
|
||||
|
||||
export function isWebSearchModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
if (!model || isEmbeddingModel(model) || isRerankModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (model.type) {
|
||||
if (model.type.includes('web_search')) {
|
||||
return true
|
||||
}
|
||||
if (isUserSelectedModelType(model, 'web_search') !== undefined) {
|
||||
return isUserSelectedModelType(model, 'web_search')!
|
||||
}
|
||||
|
||||
const provider = getProviderByModel(model)
|
||||
|
||||
@ -58,7 +58,7 @@ const persistedReducer = persistReducer(
|
||||
{
|
||||
key: 'cherry-studio',
|
||||
storage,
|
||||
version: 122,
|
||||
version: 123,
|
||||
blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs'],
|
||||
migrate
|
||||
},
|
||||
|
||||
@ -1844,6 +1844,25 @@ const migrateConfig = {
|
||||
logger.error('migrate 122 error', error as Error)
|
||||
return state
|
||||
}
|
||||
},
|
||||
'123': (state: RootState) => {
|
||||
try {
|
||||
state.llm.providers.forEach((provider) => {
|
||||
provider.models.forEach((model) => {
|
||||
if (model.type && Array.isArray(model.type)) {
|
||||
model.capabilities = model.type.map((t) => ({
|
||||
type: t,
|
||||
isUserSelected: true
|
||||
}))
|
||||
delete model.type
|
||||
}
|
||||
})
|
||||
})
|
||||
return state
|
||||
} catch (error) {
|
||||
logger.error('migrate 123 error', error as Error)
|
||||
return state
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -182,7 +182,7 @@ export type ProviderType =
|
||||
| 'vertexai'
|
||||
| 'mistral'
|
||||
|
||||
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search'
|
||||
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank'
|
||||
|
||||
export type EndpointType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
|
||||
|
||||
@ -192,6 +192,15 @@ export type ModelPricing = {
|
||||
currencySymbol?: string
|
||||
}
|
||||
|
||||
export type ModelCapability = {
|
||||
type: ModelType
|
||||
/**
|
||||
* 是否为用户手动选择,如果为true,则表示用户手动选择了该类型,否则表示用户手动禁止了该模型;如果为undefined,则表示使用默认值
|
||||
* Is it manually selected by the user? If true, it means the user manually selected this type; otherwise, it means the user * manually disabled the model.
|
||||
*/
|
||||
isUserSelected?: boolean
|
||||
}
|
||||
|
||||
export type Model = {
|
||||
id: string
|
||||
provider: string
|
||||
@ -199,6 +208,10 @@ export type Model = {
|
||||
group: string
|
||||
owned_by?: string
|
||||
description?: string
|
||||
capabilities?: ModelCapability[]
|
||||
/**
|
||||
* @deprecated
|
||||
*/
|
||||
type?: ModelType[]
|
||||
pricing?: ModelPricing
|
||||
endpoint_type?: EndpointType
|
||||
|
||||
393
src/renderer/src/utils/__tests__/collection.test.ts
Normal file
393
src/renderer/src/utils/__tests__/collection.test.ts
Normal file
@ -0,0 +1,393 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { getDifference, getIntersection, getUnion } from '../collection'
|
||||
|
||||
describe('Collection Utils', () => {
|
||||
// ================== Basic Types Tests ==================
|
||||
|
||||
describe('getIntersection - Basic Types', () => {
|
||||
it('should get intersection of number arrays', () => {
|
||||
const arr1 = [1, 2, 3, 4]
|
||||
const arr2 = [3, 4, 5, 6]
|
||||
const result = getIntersection(arr1, arr2)
|
||||
expect(result).toEqual([3, 4])
|
||||
})
|
||||
|
||||
it('should get intersection of string arrays', () => {
|
||||
const arr1 = ['a', 'b', 'c']
|
||||
const arr2 = ['b', 'c', 'd']
|
||||
const result = getIntersection(arr1, arr2)
|
||||
expect(result).toEqual(['b', 'c'])
|
||||
})
|
||||
|
||||
it('should return empty array when no intersection', () => {
|
||||
const arr1 = [1, 2, 3]
|
||||
const arr2 = [4, 5, 6]
|
||||
const result = getIntersection(arr1, arr2)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('should return empty array when one array is empty', () => {
|
||||
const arr1 = [1, 2, 3]
|
||||
const arr2: number[] = []
|
||||
const result = getIntersection(arr1, arr2)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('should return empty array when both arrays are empty', () => {
|
||||
const arr1: number[] = []
|
||||
const arr2: number[] = []
|
||||
const result = getIntersection(arr1, arr2)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getDifference - Basic Types', () => {
|
||||
it('should get difference of number arrays', () => {
|
||||
const arr1 = [1, 2, 3, 4]
|
||||
const arr2 = [3, 4, 5, 6]
|
||||
const result = getDifference(arr1, arr2)
|
||||
expect(result).toEqual([1, 2])
|
||||
})
|
||||
|
||||
it('should get difference of string arrays', () => {
|
||||
const arr1 = ['a', 'b', 'c']
|
||||
const arr2 = ['b', 'c', 'd']
|
||||
const result = getDifference(arr1, arr2)
|
||||
expect(result).toEqual(['a'])
|
||||
})
|
||||
|
||||
it('should return empty array when no difference', () => {
|
||||
const arr1 = [1, 2, 3]
|
||||
const arr2 = [1, 2, 3, 4, 5]
|
||||
const result = getDifference(arr1, arr2)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('should return first array when second array is empty', () => {
|
||||
const arr1 = [1, 2, 3]
|
||||
const arr2: number[] = []
|
||||
const result = getDifference(arr1, arr2)
|
||||
expect(result).toEqual([1, 2, 3])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getUnion - Basic Types', () => {
|
||||
it('should merge number arrays correctly', () => {
|
||||
const arr1 = [1, 2, 3]
|
||||
const arr2 = [3, 4, 5]
|
||||
const result = getUnion(arr1, arr2)
|
||||
expect(result).toEqual([1, 2, 3, 4, 5])
|
||||
})
|
||||
|
||||
it('should merge string arrays correctly', () => {
|
||||
const arr1 = ['a', 'b']
|
||||
const arr2 = ['b', 'c']
|
||||
const result = getUnion(arr1, arr2)
|
||||
expect(result).toEqual(['a', 'b', 'c'])
|
||||
})
|
||||
|
||||
it('should merge arrays with no duplicates', () => {
|
||||
const arr1 = [1, 2]
|
||||
const arr2 = [3, 4]
|
||||
const result = getUnion(arr1, arr2)
|
||||
expect(result).toEqual([1, 2, 3, 4])
|
||||
})
|
||||
|
||||
it('should return other array when one is empty', () => {
|
||||
const arr1 = [1, 2, 3]
|
||||
const arr2: number[] = []
|
||||
const result = getUnion(arr1, arr2)
|
||||
expect(result).toEqual([1, 2, 3])
|
||||
})
|
||||
})
|
||||
|
||||
// ================== Object Types Tests - Key Selector ==================
|
||||
|
||||
interface User {
|
||||
id: number
|
||||
name: string
|
||||
age: number
|
||||
}
|
||||
|
||||
const users1: User[] = [
|
||||
{ id: 1, name: 'Alice', age: 25 },
|
||||
{ id: 2, name: 'Bob', age: 30 },
|
||||
{ id: 3, name: 'Charlie', age: 35 }
|
||||
]
|
||||
|
||||
const users2: User[] = [
|
||||
{ id: 2, name: 'Bob', age: 30 },
|
||||
{ id: 3, name: 'Charlie', age: 36 },
|
||||
{ id: 4, name: 'David', age: 28 }
|
||||
]
|
||||
|
||||
describe('getIntersection - Object Types (Key Selector)', () => {
|
||||
it('should get user intersection by id', () => {
|
||||
const result = getIntersection(users1, users2, (user) => user.id)
|
||||
expect(result).toEqual([
|
||||
{ id: 2, name: 'Bob', age: 30 },
|
||||
{ id: 3, name: 'Charlie', age: 35 }
|
||||
])
|
||||
})
|
||||
|
||||
it('should get user intersection by name', () => {
|
||||
const result = getIntersection(users1, users2, (user) => user.name)
|
||||
expect(result).toEqual([
|
||||
{ id: 2, name: 'Bob', age: 30 },
|
||||
{ id: 3, name: 'Charlie', age: 35 }
|
||||
])
|
||||
})
|
||||
|
||||
it('should return empty array when no intersection', () => {
|
||||
const users3: User[] = [{ id: 5, name: 'Eve', age: 40 }]
|
||||
const result = getIntersection(users1, users3, (user) => user.id)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getDifference - Object Types (Key Selector)', () => {
|
||||
it('should get user difference by id', () => {
|
||||
const result = getDifference(users1, users2, (user) => user.id)
|
||||
expect(result).toEqual([{ id: 1, name: 'Alice', age: 25 }])
|
||||
})
|
||||
|
||||
it('should get user difference by name', () => {
|
||||
const result = getDifference(users1, users2, (user) => user.name)
|
||||
expect(result).toEqual([{ id: 1, name: 'Alice', age: 25 }])
|
||||
})
|
||||
|
||||
it('should return correct difference', () => {
|
||||
const result = getDifference(users2, users1, (user) => user.id)
|
||||
expect(result).toEqual([{ id: 4, name: 'David', age: 28 }])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getUnion - Object Types (Key Selector)', () => {
|
||||
it('should merge user arrays by id', () => {
|
||||
const result = getUnion(users1, users2, (user) => user.id)
|
||||
expect(result).toEqual([
|
||||
{ id: 1, name: 'Alice', age: 25 },
|
||||
{ id: 2, name: 'Bob', age: 30 },
|
||||
{ id: 3, name: 'Charlie', age: 35 },
|
||||
{ id: 4, name: 'David', age: 28 }
|
||||
])
|
||||
})
|
||||
|
||||
it('should preserve first array element version', () => {
|
||||
const result = getUnion(users1, users2, (user) => user.id)
|
||||
const charlie = result.find((u) => u.id === 3)
|
||||
expect(charlie?.age).toBe(35)
|
||||
})
|
||||
})
|
||||
|
||||
// ================== Object Types Tests - Comparator Function ==================
|
||||
|
||||
describe('getIntersection - Object Types (Comparator)', () => {
|
||||
it('should use custom comparator correctly', () => {
|
||||
const result = getIntersection(users1, users2, (a, b) => a.id === b.id && a.name === b.name)
|
||||
expect(result).toEqual([
|
||||
{ id: 2, name: 'Bob', age: 30 },
|
||||
{ id: 3, name: 'Charlie', age: 35 }
|
||||
])
|
||||
})
|
||||
|
||||
it('should get users with similar age', () => {
|
||||
const youngUsers: User[] = [
|
||||
{ id: 5, name: 'Eve', age: 26 },
|
||||
{ id: 6, name: 'Frank', age: 32 }
|
||||
]
|
||||
|
||||
const result = getIntersection(users1, youngUsers, (a, b) => Math.abs(a.age - b.age) <= 5)
|
||||
|
||||
expect(result).toEqual([
|
||||
{ id: 1, name: 'Alice', age: 25 },
|
||||
{ id: 2, name: 'Bob', age: 30 },
|
||||
{ id: 3, name: 'Charlie', age: 35 }
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getDifference - Object Types (Comparator)', () => {
|
||||
it('should use custom comparator for difference', () => {
|
||||
const result = getDifference(users1, users2, (a, b) => a.id === b.id && a.name === b.name)
|
||||
expect(result).toEqual([{ id: 1, name: 'Alice', age: 25 }])
|
||||
})
|
||||
|
||||
it('should consider all properties in comparison', () => {
|
||||
const result = getDifference(users1, users2, (a, b) => a.id === b.id && a.name === b.name && a.age === b.age)
|
||||
expect(result).toEqual([
|
||||
{ id: 1, name: 'Alice', age: 25 },
|
||||
{ id: 3, name: 'Charlie', age: 35 }
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getUnion - Object Types (Comparator)', () => {
|
||||
it('should merge arrays using custom comparator', () => {
|
||||
const result = getUnion(users1, users2, (a, b) => a.id === b.id)
|
||||
expect(result).toEqual([
|
||||
{ id: 1, name: 'Alice', age: 25 },
|
||||
{ id: 2, name: 'Bob', age: 30 },
|
||||
{ id: 3, name: 'Charlie', age: 35 },
|
||||
{ id: 4, name: 'David', age: 28 }
|
||||
])
|
||||
})
|
||||
|
||||
it('should handle complex comparison logic', () => {
|
||||
const products1 = [
|
||||
{ id: 1, category: 'electronics', price: 100 },
|
||||
{ id: 2, category: 'books', price: 20 }
|
||||
]
|
||||
|
||||
const products2 = [
|
||||
{ id: 3, category: 'electronics', price: 150 },
|
||||
{ id: 4, category: 'clothing', price: 50 }
|
||||
]
|
||||
|
||||
const result = getUnion(products1, products2, (a, b) => a.category === b.category)
|
||||
|
||||
expect(result).toEqual([
|
||||
{ id: 1, category: 'electronics', price: 100 },
|
||||
{ id: 2, category: 'books', price: 20 },
|
||||
{ id: 4, category: 'clothing', price: 50 }
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
// ================== Edge Cases ==================
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle identical arrays', () => {
|
||||
const arr = [1, 2, 3]
|
||||
|
||||
expect(getIntersection(arr, arr)).toEqual([1, 2, 3])
|
||||
expect(getDifference(arr, arr)).toEqual([])
|
||||
expect(getUnion(arr, arr)).toEqual([1, 2, 3])
|
||||
})
|
||||
|
||||
it('should handle arrays with duplicates', () => {
|
||||
const arr1 = [1, 1, 2, 2, 3]
|
||||
const arr2 = [2, 2, 3, 3, 4]
|
||||
|
||||
expect(getIntersection(arr1, arr2)).toEqual([2, 2, 3])
|
||||
expect(getDifference(arr1, arr2)).toEqual([1, 1])
|
||||
expect(getUnion(arr1, arr2)).toEqual([1, 2, 3, 4])
|
||||
})
|
||||
|
||||
it('should handle object array duplicates with key selector', () => {
|
||||
const arr1 = [
|
||||
{ id: 1, name: 'A' },
|
||||
{ id: 1, name: 'A' },
|
||||
{ id: 2, name: 'B' }
|
||||
]
|
||||
const arr2 = [
|
||||
{ id: 2, name: 'B' },
|
||||
{ id: 3, name: 'C' }
|
||||
]
|
||||
|
||||
const intersection = getIntersection(arr1, arr2, (item) => item.id)
|
||||
expect(intersection).toEqual([{ id: 2, name: 'B' }])
|
||||
|
||||
const union = getUnion(arr1, arr2, (item) => item.id)
|
||||
expect(union).toEqual([
|
||||
{ id: 1, name: 'A' },
|
||||
{ id: 2, name: 'B' },
|
||||
{ id: 3, name: 'C' }
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
// ================== Type Safety Tests ==================
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should correctly infer return types', () => {
|
||||
const numbers = [1, 2, 3]
|
||||
const strings = ['a', 'b', 'c']
|
||||
|
||||
const numberResult = getIntersection(numbers, numbers)
|
||||
const stringResult = getIntersection(strings, strings)
|
||||
|
||||
expect(typeof numberResult[0]).toBe('number')
|
||||
expect(typeof stringResult[0]).toBe('string')
|
||||
})
|
||||
|
||||
it('should support complex object types', () => {
|
||||
interface ComplexObject {
|
||||
nested: {
|
||||
value: number
|
||||
}
|
||||
array: string[]
|
||||
}
|
||||
|
||||
const complex1: ComplexObject[] = [
|
||||
{ nested: { value: 1 }, array: ['a'] },
|
||||
{ nested: { value: 2 }, array: ['b'] }
|
||||
]
|
||||
|
||||
const complex2: ComplexObject[] = [
|
||||
{ nested: { value: 2 }, array: ['b'] },
|
||||
{ nested: { value: 3 }, array: ['c'] }
|
||||
]
|
||||
|
||||
const result = getIntersection(complex1, complex2, (obj) => obj.nested.value)
|
||||
expect(result).toEqual([{ nested: { value: 2 }, array: ['b'] }])
|
||||
})
|
||||
|
||||
it('should demonstrate why objects need comparators', () => {
|
||||
const obj1 = [{ id: 1, name: 'Alice' }]
|
||||
const obj2 = [{ id: 1, name: 'Alice' }]
|
||||
|
||||
// Bypass TypeScript type checking with 'any' to show runtime behavior
|
||||
const anyObj1 = obj1 as any
|
||||
const anyObj2 = obj2 as any
|
||||
|
||||
// Without comparator, objects are compared by reference, not content
|
||||
const result = getIntersection(anyObj1, anyObj2)
|
||||
expect(result).toEqual([])
|
||||
|
||||
// With proper key selector, it works correctly
|
||||
const correctResult = getIntersection(obj1, obj2, (item) => item.id)
|
||||
expect(correctResult).toEqual([{ id: 1, name: 'Alice' }])
|
||||
})
|
||||
|
||||
it('should enforce type constraints at compile time', () => {
|
||||
const obj1 = [{ id: 1, name: 'Alice' }]
|
||||
const obj2 = [{ id: 1, name: 'Alice' }]
|
||||
|
||||
// The following would cause TypeScript compilation errors:
|
||||
//
|
||||
// ❌ Error: Type '{ id: number; name: string; }' does not satisfy the constraint 'string | number | boolean | null | undefined'
|
||||
// getIntersection(obj1, obj2)
|
||||
//
|
||||
// ❌ Error: Expected 3 arguments, but got 2. Object types require a comparator.
|
||||
// getDifference(obj1, obj2)
|
||||
//
|
||||
// ❌ Error: Expected 3 arguments, but got 2. Object types require a comparator.
|
||||
// getUnion(obj1, obj2)
|
||||
|
||||
// ✅ Correct usage with key selector
|
||||
const intersection = getIntersection(obj1, obj2, (item) => item.id)
|
||||
const difference = getDifference(obj1, obj2, (item) => item.id)
|
||||
const union = getUnion(obj1, obj2, (item) => item.id)
|
||||
|
||||
expect(intersection).toEqual([{ id: 1, name: 'Alice' }])
|
||||
expect(difference).toEqual([])
|
||||
expect(union).toEqual([{ id: 1, name: 'Alice' }])
|
||||
})
|
||||
|
||||
it('should work correctly with primitive types without comparator', () => {
|
||||
const nums1 = [1, 2, 3]
|
||||
const nums2 = [2, 3, 4]
|
||||
|
||||
const intersection = getIntersection(nums1, nums2)
|
||||
expect(intersection).toEqual([2, 3])
|
||||
|
||||
const difference = getDifference(nums1, nums2)
|
||||
expect(difference).toEqual([1])
|
||||
|
||||
const union = getUnion(nums1, nums2)
|
||||
expect(union).toEqual([1, 2, 3, 4])
|
||||
})
|
||||
})
|
||||
})
|
||||
89
src/renderer/src/utils/collection.ts
Normal file
89
src/renderer/src/utils/collection.ts
Normal file
@ -0,0 +1,89 @@
|
||||
// Type-safe collection operations with strict compile-time enforcement
|
||||
|
||||
type PrimitiveType = string | number | boolean | null | undefined
|
||||
|
||||
// getIntersection - with strict overloads
|
||||
export function getIntersection<T extends PrimitiveType>(arr1: T[], arr2: T[]): T[]
|
||||
export function getIntersection<T extends object, K>(arr1: T[], arr2: T[], keySelector: (item: T) => K): T[]
|
||||
export function getIntersection<T extends object>(arr1: T[], arr2: T[], compareFn: (a: T, b: T) => boolean): T[]
|
||||
export function getIntersection<T>(
|
||||
arr1: T[],
|
||||
arr2: T[],
|
||||
comparator?: ((item: T) => any) | ((a: T, b: T) => boolean)
|
||||
): T[] {
|
||||
if (!comparator) {
|
||||
const set2 = new Set(arr2)
|
||||
return arr1.filter((element) => set2.has(element))
|
||||
}
|
||||
|
||||
if (comparator.length === 1) {
|
||||
const keySelector = comparator as (item: T) => any
|
||||
const keySet = new Set(arr2.map(keySelector))
|
||||
return arr1.filter((item) => keySet.has(keySelector(item)))
|
||||
} else {
|
||||
const compareFn = comparator as (a: T, b: T) => boolean
|
||||
return arr1.filter((item1) => arr2.some((item2) => compareFn(item1, item2)))
|
||||
}
|
||||
}
|
||||
|
||||
// getDifference - with strict overloads
|
||||
export function getDifference<T extends PrimitiveType>(arr1: T[], arr2: T[]): T[]
|
||||
export function getDifference<T extends object, K>(arr1: T[], arr2: T[], keySelector: (item: T) => K): T[]
|
||||
export function getDifference<T extends object>(arr1: T[], arr2: T[], compareFn: (a: T, b: T) => boolean): T[]
|
||||
export function getDifference<T>(
|
||||
arr1: T[],
|
||||
arr2: T[],
|
||||
comparator?: ((item: T) => any) | ((a: T, b: T) => boolean)
|
||||
): T[] {
|
||||
if (!comparator) {
|
||||
const set2 = new Set(arr2)
|
||||
return arr1.filter((element) => !set2.has(element))
|
||||
}
|
||||
|
||||
if (comparator.length === 1) {
|
||||
const keySelector = comparator as (item: T) => any
|
||||
const keySet = new Set(arr2.map(keySelector))
|
||||
return arr1.filter((item) => !keySet.has(keySelector(item)))
|
||||
} else {
|
||||
const compareFn = comparator as (a: T, b: T) => boolean
|
||||
return arr1.filter((item1) => !arr2.some((item2) => compareFn(item1, item2)))
|
||||
}
|
||||
}
|
||||
|
||||
// getUnion - with strict overloads
|
||||
export function getUnion<T extends PrimitiveType>(arr1: T[], arr2: T[]): T[]
|
||||
export function getUnion<T extends object, K>(arr1: T[], arr2: T[], keySelector: (item: T) => K): T[]
|
||||
export function getUnion<T extends object>(arr1: T[], arr2: T[], compareFn: (a: T, b: T) => boolean): T[]
|
||||
export function getUnion<T>(arr1: T[], arr2: T[], comparator?: ((item: T) => any) | ((a: T, b: T) => boolean)): T[] {
|
||||
if (!comparator) {
|
||||
return Array.from(new Set([...arr1, ...arr2]))
|
||||
}
|
||||
|
||||
if (comparator.length === 1) {
|
||||
const keySelector = comparator as (item: T) => any
|
||||
const seen = new Set<any>()
|
||||
const result: T[] = []
|
||||
|
||||
for (const item of [...arr1, ...arr2]) {
|
||||
const key = keySelector(item)
|
||||
if (!seen.has(key)) {
|
||||
seen.add(key)
|
||||
result.push(item)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
} else {
|
||||
const compareFn = comparator as (a: T, b: T) => boolean
|
||||
const result = [...arr1]
|
||||
|
||||
for (const item2 of arr2) {
|
||||
const exists = result.some((item1) => compareFn(item1, item2))
|
||||
if (!exists) {
|
||||
result.push(item2)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
import { Model, ModelType, Provider } from '@renderer/types'
|
||||
import { ModalFuncProps } from 'antd'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
@ -227,7 +227,19 @@ export function isOpenAIProvider(provider: Provider): boolean {
|
||||
return !['anthropic', 'gemini', 'vertexai'].includes(provider.type)
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断模型是否为用户手动选择
|
||||
* @param {Model} model 模型对象
|
||||
* @param {ModelType} type 模型类型
|
||||
* @returns {boolean} 是否为用户手动选择
|
||||
*/
|
||||
export function isUserSelectedModelType(model: Model, type: ModelType): boolean | undefined {
|
||||
const t = model.capabilities?.find((t) => t.type === type)
|
||||
return t ? t.isUserSelected : undefined
|
||||
}
|
||||
|
||||
export * from './api'
|
||||
export * from './collection'
|
||||
export * from './file'
|
||||
export * from './image'
|
||||
export * from './json'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user