fix(ModelEdit): enhance model type management and introduce new selection logic (#8420)

* fix(ModelEdit): enhance model type management and introduce new selection logic

- Added support for 'rerank' model type in the ModelEditContent component.
- Refactored type selection logic to utilize new utility functions for finding differences and unions in model types.
- Updated model type handling to include user selection status, improving user experience in type management.
- Adjusted migration logic to initialize newType for existing models, ensuring backward compatibility.
- Introduced isUserSelectedModelType utility to streamline model type checks across the application.

* refactor(isFunctionCallingModel): simplify model type check logic

- Replaced the inline check for 'function_calling' model type with a call to the new utility function isUserSelectedModelType, enhancing code clarity and maintainability.

* feat(collection): add utility functions for array operations

- Introduced `findIntersection`, `findDifference`, and `findUnion` functions to handle array operations with support for custom key selectors and comparison functions.
- Removed previous implementations from `index.ts` to streamline utility exports.
- Added comprehensive tests for new functions covering basic types and object types with various edge cases.

* refactor(collection): rename utility functions for clarity

- Renamed `findIntersection`, `findDifference`, and `findUnion` to `getIntersection`, `getDifference`, and `getUnion` respectively for improved clarity and consistency in naming.
- Updated corresponding tests to reflect the new function names, ensuring all tests pass with the updated utility functions.

* refactor(ModelEditContent): update model type management and improve selection logic

- Replaced utility function calls to `findDifference` and `findUnion` with `getDifference` and `getUnion` for consistency.
- Introduced temporary state management for model types to enhance user selection handling.
- Added a reset functionality for model type selections, improving user experience.
- Updated the rendering logic to conditionally disable certain model types based on user selections.

* fix(ModelEditContent): enhance model type selection logic with conditional disabling

- Introduced logic to conditionally disable 'rerank' and 'embedding' model types based on user selections.
- Updated the state management for model types to ensure correct user selection handling.
- Improved the confirmation modal to reflect the updated selection logic for better user experience.

* fix(ModelEditContent): refine model type selection and update confirmation logic

- Enhanced the logic for model type selection to ensure accurate user selections for 'rerank' and 'embedding'.
- Updated the confirmation modal to reflect changes in selection handling, improving user experience.
- Adjusted state management to correctly handle updates based on selected model types.

* fix(models): update model support logic to include 'qwen3-235b-a22b-instruct'

* refactor(models): rename 'newType' to 'capabilities' and update related logic in ModelEditContent and migration scripts
This commit is contained in:
SuYao 2025-07-24 17:17:26 +08:00 committed by GitHub
parent 0453402242
commit 6c44f7fe24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 696 additions and 78 deletions

View File

@ -4,12 +4,13 @@ import {
isEmbeddingModel,
isFunctionCallingModel,
isReasoningModel,
isRerankModel,
isVisionModel,
isWebSearchModel
} from '@renderer/config/models'
import { useDynamicLabelWidth } from '@renderer/hooks/useDynamicLabelWidth'
import { Model, ModelType, Provider } from '@renderer/types'
import { getDefaultGroupName } from '@renderer/utils'
import { Model, ModelCapability, ModelType, Provider } from '@renderer/types'
import { getDefaultGroupName, getDifference, getUnion } from '@renderer/utils'
import { Button, Checkbox, Divider, Flex, Form, Input, InputNumber, message, Modal, Select } from 'antd'
import { ChevronDown, ChevronUp } from 'lucide-react'
import { FC, useState } from 'react'
@ -31,6 +32,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
const [showMoreSettings, setShowMoreSettings] = useState(false)
const [currencySymbol, setCurrencySymbol] = useState(model.pricing?.currencySymbol || '$')
const [isCustomCurrency, setIsCustomCurrency] = useState(!symbols.includes(model.pricing?.currencySymbol || '$'))
const [tempModelTypes, setTempModelTypes] = useState(model.capabilities || [])
const labelWidth = useDynamicLabelWidth([t('settings.models.add.endpoint_type')])
@ -42,6 +44,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
name: values.name || model.name,
group: values.group || model.group,
endpoint_type: provider.id === 'new-api' ? values.endpointType : model.endpoint_type,
newCapability: tempModelTypes,
pricing: {
input_per_million_tokens: Number(values.input_per_million_tokens) || 0,
output_per_million_tokens: Number(values.output_per_million_tokens) || 0,
@ -55,6 +58,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
const handleClose = () => {
setShowMoreSettings(false)
setTempModelTypes(model.capabilities || [])
onClose()
}
@ -179,16 +183,29 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
{(() => {
const defaultTypes = [
...(isVisionModel(model) ? ['vision'] : []),
...(isEmbeddingModel(model) ? ['embedding'] : []),
...(isReasoningModel(model) ? ['reasoning'] : []),
...(isFunctionCallingModel(model) ? ['function_calling'] : []),
...(isWebSearchModel(model) ? ['web_search'] : [])
] as ModelType[]
...(isWebSearchModel(model) ? ['web_search'] : []),
...(isEmbeddingModel(model) ? ['embedding'] : []),
...(isRerankModel(model) ? ['rerank'] : [])
]
// 合并现有选择和默认类型
const selectedTypes = [...new Set([...(model.type || []), ...defaultTypes])]
// 合并现有选择和默认类型用于前端展示
const selectedTypes = getUnion(
tempModelTypes?.filter((t) => t.isUserSelected).map((t) => t.type) || [],
getDifference(
defaultTypes,
tempModelTypes?.filter((t) => t.isUserSelected === false).map((t) => t.type) || []
)
)
const showTypeConfirmModal = (type: string) => {
const isDisabled = selectedTypes.includes('rerank') || selectedTypes.includes('embedding')
const isRerankDisabled = selectedTypes.includes('embedding')
const isEmbeddingDisabled = selectedTypes.includes('rerank')
const showTypeConfirmModal = (newCapability: ModelCapability) => {
const onUpdateType = selectedTypes?.find((t) => t === newCapability.type)
window.modal.confirm({
title: t('settings.moresetting.warn'),
content: t('settings.moresetting.check.warn'),
@ -196,54 +213,130 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ provider, model, onUpdate
cancelText: t('common.cancel'),
okButtonProps: { danger: true },
cancelButtonProps: { type: 'primary' },
onOk: () => onUpdateModel({ ...model, type: [...selectedTypes, type] as ModelType[] }),
onOk: () => {
if (onUpdateType) {
const updatedTypes = selectedTypes?.map((t) => {
if (t === newCapability.type) {
return { type: t, isUserSelected: true }
}
if (
(onUpdateType !== t && onUpdateType === 'rerank') ||
(onUpdateType === 'embedding' && onUpdateType !== t)
) {
return { type: t, isUserSelected: false }
}
return { type: t }
})
setTempModelTypes(updatedTypes as ModelCapability[])
} else {
const updatedTypes = selectedTypes?.map((t) => {
if (
(newCapability.type !== t && newCapability.type === 'rerank') ||
(newCapability.type === 'embedding' && newCapability.type !== t)
) {
return { type: t, isUserSelected: false }
}
return { type: t }
})
setTempModelTypes([...(updatedTypes as ModelCapability[]), newCapability])
}
},
onCancel: () => {},
centered: true
})
}
const handleTypeChange = (types: string[]) => {
const newType = types.find((type) => !selectedTypes.includes(type as ModelType))
if (newType) {
showTypeConfirmModal(newType)
const diff = types.length > selectedTypes.length
if (diff) {
const newCapability = getDifference(types, selectedTypes) // checkbox的特性确保了newCapability只有一个元素
showTypeConfirmModal({
type: newCapability[0] as ModelType,
isUserSelected: true
})
} else {
onUpdateModel({ ...model, type: types as ModelType[] })
const disabledTypes = getDifference(selectedTypes, types)
const onUpdateType = tempModelTypes?.find((t) => t.type === disabledTypes[0])
if (onUpdateType) {
const updatedTypes = tempModelTypes?.map((t) => {
if (t.type === disabledTypes[0]) {
return { ...t, isUserSelected: false }
}
if (
(onUpdateType !== t && onUpdateType.type === 'rerank') ||
(onUpdateType.type === 'embedding' && onUpdateType !== t && t.isUserSelected === false)
) {
return { ...t, isUserSelected: true }
}
return t
})
setTempModelTypes(updatedTypes || [])
} else {
const updatedTypes = tempModelTypes?.map((t) => {
if (
(disabledTypes[0] === 'rerank' && t.type !== 'rerank') ||
(disabledTypes[0] === 'embedding' && t.type !== 'embedding' && t.isUserSelected === false)
) {
return { ...t, isUserSelected: true }
}
return t
})
setTempModelTypes([
...(updatedTypes ?? []),
{ type: disabledTypes[0] as ModelType, isUserSelected: false }
])
}
}
}
const handleResetTypes = () => {
setTempModelTypes([])
}
return (
<Checkbox.Group
value={selectedTypes}
onChange={handleTypeChange}
options={[
{
label: t('models.type.vision'),
value: 'vision',
disabled: isVisionModel(model) && !selectedTypes.includes('vision')
},
{
label: t('models.type.websearch'),
value: 'web_search',
disabled: isWebSearchModel(model) && !selectedTypes.includes('web_search')
},
{
label: t('models.type.embedding'),
value: 'embedding',
disabled: isEmbeddingModel(model) && !selectedTypes.includes('embedding')
},
{
label: t('models.type.reasoning'),
value: 'reasoning',
disabled: isReasoningModel(model) && !selectedTypes.includes('reasoning')
},
{
label: t('models.type.function_calling'),
value: 'function_calling',
disabled: isFunctionCallingModel(model) && !selectedTypes.includes('function_calling')
}
]}
/>
<div>
<Flex justify="space-between" align="center" style={{ marginBottom: 8 }}>
<Checkbox.Group
value={selectedTypes}
onChange={handleTypeChange}
options={[
{
label: t('models.type.vision'),
value: 'vision',
disabled: isDisabled
},
{
label: t('models.type.websearch'),
value: 'web_search',
disabled: isDisabled
},
{
label: t('models.type.rerank'),
value: 'rerank',
disabled: isRerankDisabled
},
{
label: t('models.type.embedding'),
value: 'embedding',
disabled: isEmbeddingDisabled
},
{
label: t('models.type.reasoning'),
value: 'reasoning',
disabled: isDisabled
},
{
label: t('models.type.function_calling'),
value: 'function_calling',
disabled: isDisabled
}
]}
/>
<Button size="small" onClick={handleResetTypes}>
{t('common.reset')}
</Button>
</Flex>
</div>
)
})()}
<TypeTitle>{t('models.price.price')}</TypeTitle>

View File

@ -145,7 +145,7 @@ import YoudaoLogo from '@renderer/assets/images/providers/netease-youdao.svg'
import NomicLogo from '@renderer/assets/images/providers/nomic.png'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { Model } from '@renderer/types'
import { getLowerBaseModelName } from '@renderer/utils'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import OpenAI from 'openai'
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from './prompts'
@ -265,16 +265,12 @@ export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
)
export function isFunctionCallingModel(model?: Model): boolean {
if (!model) {
if (!model || isEmbeddingModel(model) || isRerankModel(model)) {
return false
}
if (model.type?.includes('function_calling')) {
return true
}
if (isEmbeddingModel(model)) {
return false
if (isUserSelectedModelType(model, 'function_calling') !== undefined) {
return isUserSelectedModelType(model, 'function_calling')!
}
if (model.provider === 'qiniu') {
@ -2395,10 +2391,14 @@ export function isTextToImageModel(model: Model): boolean {
}
export function isEmbeddingModel(model: Model): boolean {
if (!model) {
if (!model || isRerankModel(model)) {
return false
}
if (isUserSelectedModelType(model, 'embedding') !== undefined) {
return isUserSelectedModelType(model, 'embedding')!
}
if (['anthropic'].includes(model?.provider)) {
return false
}
@ -2407,31 +2407,33 @@ export function isEmbeddingModel(model: Model): boolean {
return EMBEDDING_REGEX.test(model.name)
}
if (isRerankModel(model)) {
return false
}
return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false
return EMBEDDING_REGEX.test(model.id) || false
}
export function isRerankModel(model: Model): boolean {
if (isUserSelectedModelType(model, 'rerank') !== undefined) {
return isUserSelectedModelType(model, 'rerank')!
}
return model ? RERANKING_REGEX.test(model.id) || false : false
}
export function isVisionModel(model: Model): boolean {
if (!model) {
if (!model || isEmbeddingModel(model) || isRerankModel(model)) {
return false
}
// 新添字段 copilot-vision-request 后可使用 vision
// if (model.provider === 'copilot') {
// return false
// }
if (model.provider === 'doubao' || model.id.includes('doubao')) {
return VISION_REGEX.test(model.name) || VISION_REGEX.test(model.id) || model.type?.includes('vision') || false
if (isUserSelectedModelType(model, 'vision') !== undefined) {
return isUserSelectedModelType(model, 'vision')!
}
return VISION_REGEX.test(model.id) || model.type?.includes('vision') || false
if (model.provider === 'doubao' || model.id.includes('doubao')) {
return VISION_REGEX.test(model.name) || VISION_REGEX.test(model.id) || false
}
return VISION_REGEX.test(model.id) || false
}
export function isOpenAIReasoningModel(model: Model): boolean {
@ -2599,7 +2601,7 @@ export function isSupportedThinkingTokenQwenModel(model?: Model): boolean {
const baseName = getLowerBaseModelName(model.id, '/')
if (baseName.includes('coder')) {
if (baseName.includes('coder') || baseName.includes('qwen3-235b-a22b-instruct')) {
return false
}
@ -2660,19 +2662,18 @@ export const isHunyuanReasoningModel = (model?: Model): boolean => {
}
export function isReasoningModel(model?: Model): boolean {
if (!model) {
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false
}
if (isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false
if (isUserSelectedModelType(model, 'reasoning') !== undefined) {
return isUserSelectedModelType(model, 'reasoning')!
}
if (model.provider === 'doubao' || model.id.includes('doubao')) {
return (
REASONING_REGEX.test(model.id) ||
REASONING_REGEX.test(model.name) ||
model.type?.includes('reasoning') ||
isSupportedThinkingTokenDoubaoModel(model) ||
false
)
@ -2692,7 +2693,7 @@ export function isReasoningModel(model?: Model): boolean {
return true
}
return REASONING_REGEX.test(model.id) || model.type?.includes('reasoning') || false
return REASONING_REGEX.test(model.id) || false
}
export function isSupportedModel(model: OpenAI.Models.Model): boolean {
@ -2716,14 +2717,12 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean {
}
export function isWebSearchModel(model: Model): boolean {
if (!model) {
if (!model || isEmbeddingModel(model) || isRerankModel(model)) {
return false
}
if (model.type) {
if (model.type.includes('web_search')) {
return true
}
if (isUserSelectedModelType(model, 'web_search') !== undefined) {
return isUserSelectedModelType(model, 'web_search')!
}
const provider = getProviderByModel(model)

View File

@ -58,7 +58,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 122,
version: 123,
blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs'],
migrate
},

View File

@ -1844,6 +1844,25 @@ const migrateConfig = {
logger.error('migrate 122 error', error as Error)
return state
}
},
'123': (state: RootState) => {
try {
state.llm.providers.forEach((provider) => {
provider.models.forEach((model) => {
if (model.type && Array.isArray(model.type)) {
model.capabilities = model.type.map((t) => ({
type: t,
isUserSelected: true
}))
delete model.type
}
})
})
return state
} catch (error) {
logger.error('migrate 123 error', error as Error)
return state
}
}
}

View File

@ -182,7 +182,7 @@ export type ProviderType =
| 'vertexai'
| 'mistral'
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search'
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank'
export type EndpointType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
@ -192,6 +192,15 @@ export type ModelPricing = {
currencySymbol?: string
}
export type ModelCapability = {
type: ModelType
/**
* trueundefined使
* Is it manually selected by the user? If true, it means the user manually selected this type; otherwise, it means the user * manually disabled the model.
*/
isUserSelected?: boolean
}
export type Model = {
id: string
provider: string
@ -199,6 +208,10 @@ export type Model = {
group: string
owned_by?: string
description?: string
capabilities?: ModelCapability[]
/**
* @deprecated
*/
type?: ModelType[]
pricing?: ModelPricing
endpoint_type?: EndpointType

View File

@ -0,0 +1,393 @@
import { describe, expect, it } from 'vitest'
import { getDifference, getIntersection, getUnion } from '../collection'
describe('Collection Utils', () => {
// ================== Basic Types Tests ==================
describe('getIntersection - Basic Types', () => {
it('should get intersection of number arrays', () => {
const arr1 = [1, 2, 3, 4]
const arr2 = [3, 4, 5, 6]
const result = getIntersection(arr1, arr2)
expect(result).toEqual([3, 4])
})
it('should get intersection of string arrays', () => {
const arr1 = ['a', 'b', 'c']
const arr2 = ['b', 'c', 'd']
const result = getIntersection(arr1, arr2)
expect(result).toEqual(['b', 'c'])
})
it('should return empty array when no intersection', () => {
const arr1 = [1, 2, 3]
const arr2 = [4, 5, 6]
const result = getIntersection(arr1, arr2)
expect(result).toEqual([])
})
it('should return empty array when one array is empty', () => {
const arr1 = [1, 2, 3]
const arr2: number[] = []
const result = getIntersection(arr1, arr2)
expect(result).toEqual([])
})
it('should return empty array when both arrays are empty', () => {
const arr1: number[] = []
const arr2: number[] = []
const result = getIntersection(arr1, arr2)
expect(result).toEqual([])
})
})
describe('getDifference - Basic Types', () => {
it('should get difference of number arrays', () => {
const arr1 = [1, 2, 3, 4]
const arr2 = [3, 4, 5, 6]
const result = getDifference(arr1, arr2)
expect(result).toEqual([1, 2])
})
it('should get difference of string arrays', () => {
const arr1 = ['a', 'b', 'c']
const arr2 = ['b', 'c', 'd']
const result = getDifference(arr1, arr2)
expect(result).toEqual(['a'])
})
it('should return empty array when no difference', () => {
const arr1 = [1, 2, 3]
const arr2 = [1, 2, 3, 4, 5]
const result = getDifference(arr1, arr2)
expect(result).toEqual([])
})
it('should return first array when second array is empty', () => {
const arr1 = [1, 2, 3]
const arr2: number[] = []
const result = getDifference(arr1, arr2)
expect(result).toEqual([1, 2, 3])
})
})
describe('getUnion - Basic Types', () => {
it('should merge number arrays correctly', () => {
const arr1 = [1, 2, 3]
const arr2 = [3, 4, 5]
const result = getUnion(arr1, arr2)
expect(result).toEqual([1, 2, 3, 4, 5])
})
it('should merge string arrays correctly', () => {
const arr1 = ['a', 'b']
const arr2 = ['b', 'c']
const result = getUnion(arr1, arr2)
expect(result).toEqual(['a', 'b', 'c'])
})
it('should merge arrays with no duplicates', () => {
const arr1 = [1, 2]
const arr2 = [3, 4]
const result = getUnion(arr1, arr2)
expect(result).toEqual([1, 2, 3, 4])
})
it('should return other array when one is empty', () => {
const arr1 = [1, 2, 3]
const arr2: number[] = []
const result = getUnion(arr1, arr2)
expect(result).toEqual([1, 2, 3])
})
})
// ================== Object Types Tests - Key Selector ==================
interface User {
id: number
name: string
age: number
}
const users1: User[] = [
{ id: 1, name: 'Alice', age: 25 },
{ id: 2, name: 'Bob', age: 30 },
{ id: 3, name: 'Charlie', age: 35 }
]
const users2: User[] = [
{ id: 2, name: 'Bob', age: 30 },
{ id: 3, name: 'Charlie', age: 36 },
{ id: 4, name: 'David', age: 28 }
]
describe('getIntersection - Object Types (Key Selector)', () => {
it('should get user intersection by id', () => {
const result = getIntersection(users1, users2, (user) => user.id)
expect(result).toEqual([
{ id: 2, name: 'Bob', age: 30 },
{ id: 3, name: 'Charlie', age: 35 }
])
})
it('should get user intersection by name', () => {
const result = getIntersection(users1, users2, (user) => user.name)
expect(result).toEqual([
{ id: 2, name: 'Bob', age: 30 },
{ id: 3, name: 'Charlie', age: 35 }
])
})
it('should return empty array when no intersection', () => {
const users3: User[] = [{ id: 5, name: 'Eve', age: 40 }]
const result = getIntersection(users1, users3, (user) => user.id)
expect(result).toEqual([])
})
})
describe('getDifference - Object Types (Key Selector)', () => {
it('should get user difference by id', () => {
const result = getDifference(users1, users2, (user) => user.id)
expect(result).toEqual([{ id: 1, name: 'Alice', age: 25 }])
})
it('should get user difference by name', () => {
const result = getDifference(users1, users2, (user) => user.name)
expect(result).toEqual([{ id: 1, name: 'Alice', age: 25 }])
})
it('should return correct difference', () => {
const result = getDifference(users2, users1, (user) => user.id)
expect(result).toEqual([{ id: 4, name: 'David', age: 28 }])
})
})
describe('getUnion - Object Types (Key Selector)', () => {
it('should merge user arrays by id', () => {
const result = getUnion(users1, users2, (user) => user.id)
expect(result).toEqual([
{ id: 1, name: 'Alice', age: 25 },
{ id: 2, name: 'Bob', age: 30 },
{ id: 3, name: 'Charlie', age: 35 },
{ id: 4, name: 'David', age: 28 }
])
})
it('should preserve first array element version', () => {
const result = getUnion(users1, users2, (user) => user.id)
const charlie = result.find((u) => u.id === 3)
expect(charlie?.age).toBe(35)
})
})
// ================== Object Types Tests - Comparator Function ==================
describe('getIntersection - Object Types (Comparator)', () => {
it('should use custom comparator correctly', () => {
const result = getIntersection(users1, users2, (a, b) => a.id === b.id && a.name === b.name)
expect(result).toEqual([
{ id: 2, name: 'Bob', age: 30 },
{ id: 3, name: 'Charlie', age: 35 }
])
})
it('should get users with similar age', () => {
const youngUsers: User[] = [
{ id: 5, name: 'Eve', age: 26 },
{ id: 6, name: 'Frank', age: 32 }
]
const result = getIntersection(users1, youngUsers, (a, b) => Math.abs(a.age - b.age) <= 5)
expect(result).toEqual([
{ id: 1, name: 'Alice', age: 25 },
{ id: 2, name: 'Bob', age: 30 },
{ id: 3, name: 'Charlie', age: 35 }
])
})
})
describe('getDifference - Object Types (Comparator)', () => {
it('should use custom comparator for difference', () => {
const result = getDifference(users1, users2, (a, b) => a.id === b.id && a.name === b.name)
expect(result).toEqual([{ id: 1, name: 'Alice', age: 25 }])
})
it('should consider all properties in comparison', () => {
const result = getDifference(users1, users2, (a, b) => a.id === b.id && a.name === b.name && a.age === b.age)
expect(result).toEqual([
{ id: 1, name: 'Alice', age: 25 },
{ id: 3, name: 'Charlie', age: 35 }
])
})
})
describe('getUnion - Object Types (Comparator)', () => {
it('should merge arrays using custom comparator', () => {
const result = getUnion(users1, users2, (a, b) => a.id === b.id)
expect(result).toEqual([
{ id: 1, name: 'Alice', age: 25 },
{ id: 2, name: 'Bob', age: 30 },
{ id: 3, name: 'Charlie', age: 35 },
{ id: 4, name: 'David', age: 28 }
])
})
it('should handle complex comparison logic', () => {
const products1 = [
{ id: 1, category: 'electronics', price: 100 },
{ id: 2, category: 'books', price: 20 }
]
const products2 = [
{ id: 3, category: 'electronics', price: 150 },
{ id: 4, category: 'clothing', price: 50 }
]
const result = getUnion(products1, products2, (a, b) => a.category === b.category)
expect(result).toEqual([
{ id: 1, category: 'electronics', price: 100 },
{ id: 2, category: 'books', price: 20 },
{ id: 4, category: 'clothing', price: 50 }
])
})
})
// ================== Edge Cases ==================
describe('Edge Cases', () => {
it('should handle identical arrays', () => {
const arr = [1, 2, 3]
expect(getIntersection(arr, arr)).toEqual([1, 2, 3])
expect(getDifference(arr, arr)).toEqual([])
expect(getUnion(arr, arr)).toEqual([1, 2, 3])
})
it('should handle arrays with duplicates', () => {
const arr1 = [1, 1, 2, 2, 3]
const arr2 = [2, 2, 3, 3, 4]
expect(getIntersection(arr1, arr2)).toEqual([2, 2, 3])
expect(getDifference(arr1, arr2)).toEqual([1, 1])
expect(getUnion(arr1, arr2)).toEqual([1, 2, 3, 4])
})
it('should handle object array duplicates with key selector', () => {
const arr1 = [
{ id: 1, name: 'A' },
{ id: 1, name: 'A' },
{ id: 2, name: 'B' }
]
const arr2 = [
{ id: 2, name: 'B' },
{ id: 3, name: 'C' }
]
const intersection = getIntersection(arr1, arr2, (item) => item.id)
expect(intersection).toEqual([{ id: 2, name: 'B' }])
const union = getUnion(arr1, arr2, (item) => item.id)
expect(union).toEqual([
{ id: 1, name: 'A' },
{ id: 2, name: 'B' },
{ id: 3, name: 'C' }
])
})
})
// ================== Type Safety Tests ==================
describe('Type Safety', () => {
it('should correctly infer return types', () => {
const numbers = [1, 2, 3]
const strings = ['a', 'b', 'c']
const numberResult = getIntersection(numbers, numbers)
const stringResult = getIntersection(strings, strings)
expect(typeof numberResult[0]).toBe('number')
expect(typeof stringResult[0]).toBe('string')
})
it('should support complex object types', () => {
interface ComplexObject {
nested: {
value: number
}
array: string[]
}
const complex1: ComplexObject[] = [
{ nested: { value: 1 }, array: ['a'] },
{ nested: { value: 2 }, array: ['b'] }
]
const complex2: ComplexObject[] = [
{ nested: { value: 2 }, array: ['b'] },
{ nested: { value: 3 }, array: ['c'] }
]
const result = getIntersection(complex1, complex2, (obj) => obj.nested.value)
expect(result).toEqual([{ nested: { value: 2 }, array: ['b'] }])
})
it('should demonstrate why objects need comparators', () => {
const obj1 = [{ id: 1, name: 'Alice' }]
const obj2 = [{ id: 1, name: 'Alice' }]
// Bypass TypeScript type checking with 'any' to show runtime behavior
const anyObj1 = obj1 as any
const anyObj2 = obj2 as any
// Without comparator, objects are compared by reference, not content
const result = getIntersection(anyObj1, anyObj2)
expect(result).toEqual([])
// With proper key selector, it works correctly
const correctResult = getIntersection(obj1, obj2, (item) => item.id)
expect(correctResult).toEqual([{ id: 1, name: 'Alice' }])
})
it('should enforce type constraints at compile time', () => {
const obj1 = [{ id: 1, name: 'Alice' }]
const obj2 = [{ id: 1, name: 'Alice' }]
// The following would cause TypeScript compilation errors:
//
// ❌ Error: Type '{ id: number; name: string; }' does not satisfy the constraint 'string | number | boolean | null | undefined'
// getIntersection(obj1, obj2)
//
// ❌ Error: Expected 3 arguments, but got 2. Object types require a comparator.
// getDifference(obj1, obj2)
//
// ❌ Error: Expected 3 arguments, but got 2. Object types require a comparator.
// getUnion(obj1, obj2)
// ✅ Correct usage with key selector
const intersection = getIntersection(obj1, obj2, (item) => item.id)
const difference = getDifference(obj1, obj2, (item) => item.id)
const union = getUnion(obj1, obj2, (item) => item.id)
expect(intersection).toEqual([{ id: 1, name: 'Alice' }])
expect(difference).toEqual([])
expect(union).toEqual([{ id: 1, name: 'Alice' }])
})
it('should work correctly with primitive types without comparator', () => {
const nums1 = [1, 2, 3]
const nums2 = [2, 3, 4]
const intersection = getIntersection(nums1, nums2)
expect(intersection).toEqual([2, 3])
const difference = getDifference(nums1, nums2)
expect(difference).toEqual([1])
const union = getUnion(nums1, nums2)
expect(union).toEqual([1, 2, 3, 4])
})
})
})

View File

@ -0,0 +1,89 @@
// Type-safe collection operations with strict compile-time enforcement
type PrimitiveType = string | number | boolean | null | undefined
// getIntersection - with strict overloads
export function getIntersection<T extends PrimitiveType>(arr1: T[], arr2: T[]): T[]
export function getIntersection<T extends object, K>(arr1: T[], arr2: T[], keySelector: (item: T) => K): T[]
export function getIntersection<T extends object>(arr1: T[], arr2: T[], compareFn: (a: T, b: T) => boolean): T[]
export function getIntersection<T>(
arr1: T[],
arr2: T[],
comparator?: ((item: T) => any) | ((a: T, b: T) => boolean)
): T[] {
if (!comparator) {
const set2 = new Set(arr2)
return arr1.filter((element) => set2.has(element))
}
if (comparator.length === 1) {
const keySelector = comparator as (item: T) => any
const keySet = new Set(arr2.map(keySelector))
return arr1.filter((item) => keySet.has(keySelector(item)))
} else {
const compareFn = comparator as (a: T, b: T) => boolean
return arr1.filter((item1) => arr2.some((item2) => compareFn(item1, item2)))
}
}
// getDifference - with strict overloads
export function getDifference<T extends PrimitiveType>(arr1: T[], arr2: T[]): T[]
export function getDifference<T extends object, K>(arr1: T[], arr2: T[], keySelector: (item: T) => K): T[]
export function getDifference<T extends object>(arr1: T[], arr2: T[], compareFn: (a: T, b: T) => boolean): T[]
export function getDifference<T>(
arr1: T[],
arr2: T[],
comparator?: ((item: T) => any) | ((a: T, b: T) => boolean)
): T[] {
if (!comparator) {
const set2 = new Set(arr2)
return arr1.filter((element) => !set2.has(element))
}
if (comparator.length === 1) {
const keySelector = comparator as (item: T) => any
const keySet = new Set(arr2.map(keySelector))
return arr1.filter((item) => !keySet.has(keySelector(item)))
} else {
const compareFn = comparator as (a: T, b: T) => boolean
return arr1.filter((item1) => !arr2.some((item2) => compareFn(item1, item2)))
}
}
// getUnion - with strict overloads
export function getUnion<T extends PrimitiveType>(arr1: T[], arr2: T[]): T[]
export function getUnion<T extends object, K>(arr1: T[], arr2: T[], keySelector: (item: T) => K): T[]
export function getUnion<T extends object>(arr1: T[], arr2: T[], compareFn: (a: T, b: T) => boolean): T[]
export function getUnion<T>(arr1: T[], arr2: T[], comparator?: ((item: T) => any) | ((a: T, b: T) => boolean)): T[] {
if (!comparator) {
return Array.from(new Set([...arr1, ...arr2]))
}
if (comparator.length === 1) {
const keySelector = comparator as (item: T) => any
const seen = new Set<any>()
const result: T[] = []
for (const item of [...arr1, ...arr2]) {
const key = keySelector(item)
if (!seen.has(key)) {
seen.add(key)
result.push(item)
}
}
return result
} else {
const compareFn = comparator as (a: T, b: T) => boolean
const result = [...arr1]
for (const item2 of arr2) {
const exists = result.some((item1) => compareFn(item1, item2))
if (!exists) {
result.push(item2)
}
}
return result
}
}

View File

@ -1,5 +1,5 @@
import { loggerService } from '@logger'
import { Model, Provider } from '@renderer/types'
import { Model, ModelType, Provider } from '@renderer/types'
import { ModalFuncProps } from 'antd'
import { v4 as uuidv4 } from 'uuid'
@ -227,7 +227,19 @@ export function isOpenAIProvider(provider: Provider): boolean {
return !['anthropic', 'gemini', 'vertexai'].includes(provider.type)
}
/**
*
* @param {Model} model
* @param {ModelType} type
* @returns {boolean}
*/
export function isUserSelectedModelType(model: Model, type: ModelType): boolean | undefined {
const t = model.capabilities?.find((t) => t.type === type)
return t ? t.isUserSelected : undefined
}
export * from './api'
export * from './collection'
export * from './file'
export * from './image'
export * from './json'