diff --git a/src/renderer/src/components/ModelList/ModelEditContent.tsx b/src/renderer/src/components/ModelList/ModelEditContent.tsx index fb539d94bc..640c44fae6 100644 --- a/src/renderer/src/components/ModelList/ModelEditContent.tsx +++ b/src/renderer/src/components/ModelList/ModelEditContent.tsx @@ -4,12 +4,13 @@ import { isEmbeddingModel, isFunctionCallingModel, isReasoningModel, + isRerankModel, isVisionModel, isWebSearchModel } from '@renderer/config/models' import { useDynamicLabelWidth } from '@renderer/hooks/useDynamicLabelWidth' -import { Model, ModelType, Provider } from '@renderer/types' -import { getDefaultGroupName } from '@renderer/utils' +import { Model, ModelCapability, ModelType, Provider } from '@renderer/types' +import { getDefaultGroupName, getDifference, getUnion } from '@renderer/utils' import { Button, Checkbox, Divider, Flex, Form, Input, InputNumber, message, Modal, Select } from 'antd' import { ChevronDown, ChevronUp } from 'lucide-react' import { FC, useState } from 'react' @@ -31,6 +32,7 @@ const ModelEditContent: FC = ({ provider, model, onUpdate const [showMoreSettings, setShowMoreSettings] = useState(false) const [currencySymbol, setCurrencySymbol] = useState(model.pricing?.currencySymbol || '$') const [isCustomCurrency, setIsCustomCurrency] = useState(!symbols.includes(model.pricing?.currencySymbol || '$')) + const [tempModelTypes, setTempModelTypes] = useState(model.capabilities || []) const labelWidth = useDynamicLabelWidth([t('settings.models.add.endpoint_type')]) @@ -42,6 +44,7 @@ const ModelEditContent: FC = ({ provider, model, onUpdate name: values.name || model.name, group: values.group || model.group, endpoint_type: provider.id === 'new-api' ? values.endpointType : model.endpoint_type, + newCapability: tempModelTypes, pricing: { input_per_million_tokens: Number(values.input_per_million_tokens) || 0, output_per_million_tokens: Number(values.output_per_million_tokens) || 0, @@ -55,6 +58,7 @@ const ModelEditContent: FC = ({ provider, model, onUpdate const handleClose = () => { setShowMoreSettings(false) + setTempModelTypes(model.capabilities || []) onClose() } @@ -179,16 +183,29 @@ const ModelEditContent: FC = ({ provider, model, onUpdate {(() => { const defaultTypes = [ ...(isVisionModel(model) ? ['vision'] : []), - ...(isEmbeddingModel(model) ? ['embedding'] : []), ...(isReasoningModel(model) ? ['reasoning'] : []), ...(isFunctionCallingModel(model) ? ['function_calling'] : []), - ...(isWebSearchModel(model) ? ['web_search'] : []) - ] as ModelType[] + ...(isWebSearchModel(model) ? ['web_search'] : []), + ...(isEmbeddingModel(model) ? ['embedding'] : []), + ...(isRerankModel(model) ? ['rerank'] : []) + ] - // 合并现有选择和默认类型 - const selectedTypes = [...new Set([...(model.type || []), ...defaultTypes])] + // 合并现有选择和默认类型用于前端展示 + const selectedTypes = getUnion( + tempModelTypes?.filter((t) => t.isUserSelected).map((t) => t.type) || [], + getDifference( + defaultTypes, + tempModelTypes?.filter((t) => t.isUserSelected === false).map((t) => t.type) || [] + ) + ) - const showTypeConfirmModal = (type: string) => { + const isDisabled = selectedTypes.includes('rerank') || selectedTypes.includes('embedding') + + const isRerankDisabled = selectedTypes.includes('embedding') + const isEmbeddingDisabled = selectedTypes.includes('rerank') + + const showTypeConfirmModal = (newCapability: ModelCapability) => { + const onUpdateType = selectedTypes?.find((t) => t === newCapability.type) window.modal.confirm({ title: t('settings.moresetting.warn'), content: t('settings.moresetting.check.warn'), @@ -196,54 +213,130 @@ const ModelEditContent: FC = ({ provider, model, onUpdate cancelText: t('common.cancel'), okButtonProps: { danger: true }, cancelButtonProps: { type: 'primary' }, - onOk: () => onUpdateModel({ ...model, type: [...selectedTypes, type] as ModelType[] }), + onOk: () => { + if (onUpdateType) { + const updatedTypes = selectedTypes?.map((t) => { + if (t === newCapability.type) { + return { type: t, isUserSelected: true } + } + if ( + (onUpdateType !== t && onUpdateType === 'rerank') || + (onUpdateType === 'embedding' && onUpdateType !== t) + ) { + return { type: t, isUserSelected: false } + } + return { type: t } + }) + setTempModelTypes(updatedTypes as ModelCapability[]) + } else { + const updatedTypes = selectedTypes?.map((t) => { + if ( + (newCapability.type !== t && newCapability.type === 'rerank') || + (newCapability.type === 'embedding' && newCapability.type !== t) + ) { + return { type: t, isUserSelected: false } + } + return { type: t } + }) + setTempModelTypes([...(updatedTypes as ModelCapability[]), newCapability]) + } + }, onCancel: () => {}, centered: true }) } const handleTypeChange = (types: string[]) => { - const newType = types.find((type) => !selectedTypes.includes(type as ModelType)) - - if (newType) { - showTypeConfirmModal(newType) + const diff = types.length > selectedTypes.length + if (diff) { + const newCapability = getDifference(types, selectedTypes) // checkbox的特性,确保了newCapability只有一个元素 + showTypeConfirmModal({ + type: newCapability[0] as ModelType, + isUserSelected: true + }) } else { - onUpdateModel({ ...model, type: types as ModelType[] }) + const disabledTypes = getDifference(selectedTypes, types) + const onUpdateType = tempModelTypes?.find((t) => t.type === disabledTypes[0]) + if (onUpdateType) { + const updatedTypes = tempModelTypes?.map((t) => { + if (t.type === disabledTypes[0]) { + return { ...t, isUserSelected: false } + } + if ( + (onUpdateType !== t && onUpdateType.type === 'rerank') || + (onUpdateType.type === 'embedding' && onUpdateType !== t && t.isUserSelected === false) + ) { + return { ...t, isUserSelected: true } + } + return t + }) + setTempModelTypes(updatedTypes || []) + } else { + const updatedTypes = tempModelTypes?.map((t) => { + if ( + (disabledTypes[0] === 'rerank' && t.type !== 'rerank') || + (disabledTypes[0] === 'embedding' && t.type !== 'embedding' && t.isUserSelected === false) + ) { + return { ...t, isUserSelected: true } + } + return t + }) + setTempModelTypes([ + ...(updatedTypes ?? []), + { type: disabledTypes[0] as ModelType, isUserSelected: false } + ]) + } } } + const handleResetTypes = () => { + setTempModelTypes([]) + } + return ( - +
+ + + + +
) })()} {t('models.price.price')} diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index ef3ab76a04..bce85fff4f 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -145,7 +145,7 @@ import YoudaoLogo from '@renderer/assets/images/providers/netease-youdao.svg' import NomicLogo from '@renderer/assets/images/providers/nomic.png' import { getProviderByModel } from '@renderer/services/AssistantService' import { Model } from '@renderer/types' -import { getLowerBaseModelName } from '@renderer/utils' +import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils' import OpenAI from 'openai' import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from './prompts' @@ -265,16 +265,12 @@ export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp( ) export function isFunctionCallingModel(model?: Model): boolean { - if (!model) { + if (!model || isEmbeddingModel(model) || isRerankModel(model)) { return false } - if (model.type?.includes('function_calling')) { - return true - } - - if (isEmbeddingModel(model)) { - return false + if (isUserSelectedModelType(model, 'function_calling') !== undefined) { + return isUserSelectedModelType(model, 'function_calling')! } if (model.provider === 'qiniu') { @@ -2395,10 +2391,14 @@ export function isTextToImageModel(model: Model): boolean { } export function isEmbeddingModel(model: Model): boolean { - if (!model) { + if (!model || isRerankModel(model)) { return false } + if (isUserSelectedModelType(model, 'embedding') !== undefined) { + return isUserSelectedModelType(model, 'embedding')! + } + if (['anthropic'].includes(model?.provider)) { return false } @@ -2407,31 +2407,33 @@ export function isEmbeddingModel(model: Model): boolean { return EMBEDDING_REGEX.test(model.name) } - if (isRerankModel(model)) { - return false - } - - return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false + return EMBEDDING_REGEX.test(model.id) || false } export function isRerankModel(model: Model): boolean { + if (isUserSelectedModelType(model, 'rerank') !== undefined) { + return isUserSelectedModelType(model, 'rerank')! + } return model ? RERANKING_REGEX.test(model.id) || false : false } export function isVisionModel(model: Model): boolean { - if (!model) { + if (!model || isEmbeddingModel(model) || isRerankModel(model)) { return false } // 新添字段 copilot-vision-request 后可使用 vision // if (model.provider === 'copilot') { // return false // } - - if (model.provider === 'doubao' || model.id.includes('doubao')) { - return VISION_REGEX.test(model.name) || VISION_REGEX.test(model.id) || model.type?.includes('vision') || false + if (isUserSelectedModelType(model, 'vision') !== undefined) { + return isUserSelectedModelType(model, 'vision')! } - return VISION_REGEX.test(model.id) || model.type?.includes('vision') || false + if (model.provider === 'doubao' || model.id.includes('doubao')) { + return VISION_REGEX.test(model.name) || VISION_REGEX.test(model.id) || false + } + + return VISION_REGEX.test(model.id) || false } export function isOpenAIReasoningModel(model: Model): boolean { @@ -2599,7 +2601,7 @@ export function isSupportedThinkingTokenQwenModel(model?: Model): boolean { const baseName = getLowerBaseModelName(model.id, '/') - if (baseName.includes('coder')) { + if (baseName.includes('coder') || baseName.includes('qwen3-235b-a22b-instruct')) { return false } @@ -2660,19 +2662,18 @@ export const isHunyuanReasoningModel = (model?: Model): boolean => { } export function isReasoningModel(model?: Model): boolean { - if (!model) { + if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) { return false } - if (isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) { - return false + if (isUserSelectedModelType(model, 'reasoning') !== undefined) { + return isUserSelectedModelType(model, 'reasoning')! } if (model.provider === 'doubao' || model.id.includes('doubao')) { return ( REASONING_REGEX.test(model.id) || REASONING_REGEX.test(model.name) || - model.type?.includes('reasoning') || isSupportedThinkingTokenDoubaoModel(model) || false ) @@ -2692,7 +2693,7 @@ export function isReasoningModel(model?: Model): boolean { return true } - return REASONING_REGEX.test(model.id) || model.type?.includes('reasoning') || false + return REASONING_REGEX.test(model.id) || false } export function isSupportedModel(model: OpenAI.Models.Model): boolean { @@ -2716,14 +2717,12 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean { } export function isWebSearchModel(model: Model): boolean { - if (!model) { + if (!model || isEmbeddingModel(model) || isRerankModel(model)) { return false } - if (model.type) { - if (model.type.includes('web_search')) { - return true - } + if (isUserSelectedModelType(model, 'web_search') !== undefined) { + return isUserSelectedModelType(model, 'web_search')! } const provider = getProviderByModel(model) diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index 24c07b71ab..9288d87fff 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -58,7 +58,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 122, + version: 123, blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs'], migrate }, diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 51b52c4d57..e835c25631 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -1844,6 +1844,25 @@ const migrateConfig = { logger.error('migrate 122 error', error as Error) return state } + }, + '123': (state: RootState) => { + try { + state.llm.providers.forEach((provider) => { + provider.models.forEach((model) => { + if (model.type && Array.isArray(model.type)) { + model.capabilities = model.type.map((t) => ({ + type: t, + isUserSelected: true + })) + delete model.type + } + }) + }) + return state + } catch (error) { + logger.error('migrate 123 error', error as Error) + return state + } } } diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index ad87aca488..163da0eb63 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -182,7 +182,7 @@ export type ProviderType = | 'vertexai' | 'mistral' -export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' +export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank' export type EndpointType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank' @@ -192,6 +192,15 @@ export type ModelPricing = { currencySymbol?: string } +export type ModelCapability = { + type: ModelType + /** + * 是否为用户手动选择,如果为true,则表示用户手动选择了该类型,否则表示用户手动禁止了该模型;如果为undefined,则表示使用默认值 + * Is it manually selected by the user? If true, it means the user manually selected this type; otherwise, it means the user * manually disabled the model. + */ + isUserSelected?: boolean +} + export type Model = { id: string provider: string @@ -199,6 +208,10 @@ export type Model = { group: string owned_by?: string description?: string + capabilities?: ModelCapability[] + /** + * @deprecated + */ type?: ModelType[] pricing?: ModelPricing endpoint_type?: EndpointType diff --git a/src/renderer/src/utils/__tests__/collection.test.ts b/src/renderer/src/utils/__tests__/collection.test.ts new file mode 100644 index 0000000000..8fe6c04272 --- /dev/null +++ b/src/renderer/src/utils/__tests__/collection.test.ts @@ -0,0 +1,393 @@ +import { describe, expect, it } from 'vitest' + +import { getDifference, getIntersection, getUnion } from '../collection' + +describe('Collection Utils', () => { + // ================== Basic Types Tests ================== + + describe('getIntersection - Basic Types', () => { + it('should get intersection of number arrays', () => { + const arr1 = [1, 2, 3, 4] + const arr2 = [3, 4, 5, 6] + const result = getIntersection(arr1, arr2) + expect(result).toEqual([3, 4]) + }) + + it('should get intersection of string arrays', () => { + const arr1 = ['a', 'b', 'c'] + const arr2 = ['b', 'c', 'd'] + const result = getIntersection(arr1, arr2) + expect(result).toEqual(['b', 'c']) + }) + + it('should return empty array when no intersection', () => { + const arr1 = [1, 2, 3] + const arr2 = [4, 5, 6] + const result = getIntersection(arr1, arr2) + expect(result).toEqual([]) + }) + + it('should return empty array when one array is empty', () => { + const arr1 = [1, 2, 3] + const arr2: number[] = [] + const result = getIntersection(arr1, arr2) + expect(result).toEqual([]) + }) + + it('should return empty array when both arrays are empty', () => { + const arr1: number[] = [] + const arr2: number[] = [] + const result = getIntersection(arr1, arr2) + expect(result).toEqual([]) + }) + }) + + describe('getDifference - Basic Types', () => { + it('should get difference of number arrays', () => { + const arr1 = [1, 2, 3, 4] + const arr2 = [3, 4, 5, 6] + const result = getDifference(arr1, arr2) + expect(result).toEqual([1, 2]) + }) + + it('should get difference of string arrays', () => { + const arr1 = ['a', 'b', 'c'] + const arr2 = ['b', 'c', 'd'] + const result = getDifference(arr1, arr2) + expect(result).toEqual(['a']) + }) + + it('should return empty array when no difference', () => { + const arr1 = [1, 2, 3] + const arr2 = [1, 2, 3, 4, 5] + const result = getDifference(arr1, arr2) + expect(result).toEqual([]) + }) + + it('should return first array when second array is empty', () => { + const arr1 = [1, 2, 3] + const arr2: number[] = [] + const result = getDifference(arr1, arr2) + expect(result).toEqual([1, 2, 3]) + }) + }) + + describe('getUnion - Basic Types', () => { + it('should merge number arrays correctly', () => { + const arr1 = [1, 2, 3] + const arr2 = [3, 4, 5] + const result = getUnion(arr1, arr2) + expect(result).toEqual([1, 2, 3, 4, 5]) + }) + + it('should merge string arrays correctly', () => { + const arr1 = ['a', 'b'] + const arr2 = ['b', 'c'] + const result = getUnion(arr1, arr2) + expect(result).toEqual(['a', 'b', 'c']) + }) + + it('should merge arrays with no duplicates', () => { + const arr1 = [1, 2] + const arr2 = [3, 4] + const result = getUnion(arr1, arr2) + expect(result).toEqual([1, 2, 3, 4]) + }) + + it('should return other array when one is empty', () => { + const arr1 = [1, 2, 3] + const arr2: number[] = [] + const result = getUnion(arr1, arr2) + expect(result).toEqual([1, 2, 3]) + }) + }) + + // ================== Object Types Tests - Key Selector ================== + + interface User { + id: number + name: string + age: number + } + + const users1: User[] = [ + { id: 1, name: 'Alice', age: 25 }, + { id: 2, name: 'Bob', age: 30 }, + { id: 3, name: 'Charlie', age: 35 } + ] + + const users2: User[] = [ + { id: 2, name: 'Bob', age: 30 }, + { id: 3, name: 'Charlie', age: 36 }, + { id: 4, name: 'David', age: 28 } + ] + + describe('getIntersection - Object Types (Key Selector)', () => { + it('should get user intersection by id', () => { + const result = getIntersection(users1, users2, (user) => user.id) + expect(result).toEqual([ + { id: 2, name: 'Bob', age: 30 }, + { id: 3, name: 'Charlie', age: 35 } + ]) + }) + + it('should get user intersection by name', () => { + const result = getIntersection(users1, users2, (user) => user.name) + expect(result).toEqual([ + { id: 2, name: 'Bob', age: 30 }, + { id: 3, name: 'Charlie', age: 35 } + ]) + }) + + it('should return empty array when no intersection', () => { + const users3: User[] = [{ id: 5, name: 'Eve', age: 40 }] + const result = getIntersection(users1, users3, (user) => user.id) + expect(result).toEqual([]) + }) + }) + + describe('getDifference - Object Types (Key Selector)', () => { + it('should get user difference by id', () => { + const result = getDifference(users1, users2, (user) => user.id) + expect(result).toEqual([{ id: 1, name: 'Alice', age: 25 }]) + }) + + it('should get user difference by name', () => { + const result = getDifference(users1, users2, (user) => user.name) + expect(result).toEqual([{ id: 1, name: 'Alice', age: 25 }]) + }) + + it('should return correct difference', () => { + const result = getDifference(users2, users1, (user) => user.id) + expect(result).toEqual([{ id: 4, name: 'David', age: 28 }]) + }) + }) + + describe('getUnion - Object Types (Key Selector)', () => { + it('should merge user arrays by id', () => { + const result = getUnion(users1, users2, (user) => user.id) + expect(result).toEqual([ + { id: 1, name: 'Alice', age: 25 }, + { id: 2, name: 'Bob', age: 30 }, + { id: 3, name: 'Charlie', age: 35 }, + { id: 4, name: 'David', age: 28 } + ]) + }) + + it('should preserve first array element version', () => { + const result = getUnion(users1, users2, (user) => user.id) + const charlie = result.find((u) => u.id === 3) + expect(charlie?.age).toBe(35) + }) + }) + + // ================== Object Types Tests - Comparator Function ================== + + describe('getIntersection - Object Types (Comparator)', () => { + it('should use custom comparator correctly', () => { + const result = getIntersection(users1, users2, (a, b) => a.id === b.id && a.name === b.name) + expect(result).toEqual([ + { id: 2, name: 'Bob', age: 30 }, + { id: 3, name: 'Charlie', age: 35 } + ]) + }) + + it('should get users with similar age', () => { + const youngUsers: User[] = [ + { id: 5, name: 'Eve', age: 26 }, + { id: 6, name: 'Frank', age: 32 } + ] + + const result = getIntersection(users1, youngUsers, (a, b) => Math.abs(a.age - b.age) <= 5) + + expect(result).toEqual([ + { id: 1, name: 'Alice', age: 25 }, + { id: 2, name: 'Bob', age: 30 }, + { id: 3, name: 'Charlie', age: 35 } + ]) + }) + }) + + describe('getDifference - Object Types (Comparator)', () => { + it('should use custom comparator for difference', () => { + const result = getDifference(users1, users2, (a, b) => a.id === b.id && a.name === b.name) + expect(result).toEqual([{ id: 1, name: 'Alice', age: 25 }]) + }) + + it('should consider all properties in comparison', () => { + const result = getDifference(users1, users2, (a, b) => a.id === b.id && a.name === b.name && a.age === b.age) + expect(result).toEqual([ + { id: 1, name: 'Alice', age: 25 }, + { id: 3, name: 'Charlie', age: 35 } + ]) + }) + }) + + describe('getUnion - Object Types (Comparator)', () => { + it('should merge arrays using custom comparator', () => { + const result = getUnion(users1, users2, (a, b) => a.id === b.id) + expect(result).toEqual([ + { id: 1, name: 'Alice', age: 25 }, + { id: 2, name: 'Bob', age: 30 }, + { id: 3, name: 'Charlie', age: 35 }, + { id: 4, name: 'David', age: 28 } + ]) + }) + + it('should handle complex comparison logic', () => { + const products1 = [ + { id: 1, category: 'electronics', price: 100 }, + { id: 2, category: 'books', price: 20 } + ] + + const products2 = [ + { id: 3, category: 'electronics', price: 150 }, + { id: 4, category: 'clothing', price: 50 } + ] + + const result = getUnion(products1, products2, (a, b) => a.category === b.category) + + expect(result).toEqual([ + { id: 1, category: 'electronics', price: 100 }, + { id: 2, category: 'books', price: 20 }, + { id: 4, category: 'clothing', price: 50 } + ]) + }) + }) + + // ================== Edge Cases ================== + + describe('Edge Cases', () => { + it('should handle identical arrays', () => { + const arr = [1, 2, 3] + + expect(getIntersection(arr, arr)).toEqual([1, 2, 3]) + expect(getDifference(arr, arr)).toEqual([]) + expect(getUnion(arr, arr)).toEqual([1, 2, 3]) + }) + + it('should handle arrays with duplicates', () => { + const arr1 = [1, 1, 2, 2, 3] + const arr2 = [2, 2, 3, 3, 4] + + expect(getIntersection(arr1, arr2)).toEqual([2, 2, 3]) + expect(getDifference(arr1, arr2)).toEqual([1, 1]) + expect(getUnion(arr1, arr2)).toEqual([1, 2, 3, 4]) + }) + + it('should handle object array duplicates with key selector', () => { + const arr1 = [ + { id: 1, name: 'A' }, + { id: 1, name: 'A' }, + { id: 2, name: 'B' } + ] + const arr2 = [ + { id: 2, name: 'B' }, + { id: 3, name: 'C' } + ] + + const intersection = getIntersection(arr1, arr2, (item) => item.id) + expect(intersection).toEqual([{ id: 2, name: 'B' }]) + + const union = getUnion(arr1, arr2, (item) => item.id) + expect(union).toEqual([ + { id: 1, name: 'A' }, + { id: 2, name: 'B' }, + { id: 3, name: 'C' } + ]) + }) + }) + + // ================== Type Safety Tests ================== + + describe('Type Safety', () => { + it('should correctly infer return types', () => { + const numbers = [1, 2, 3] + const strings = ['a', 'b', 'c'] + + const numberResult = getIntersection(numbers, numbers) + const stringResult = getIntersection(strings, strings) + + expect(typeof numberResult[0]).toBe('number') + expect(typeof stringResult[0]).toBe('string') + }) + + it('should support complex object types', () => { + interface ComplexObject { + nested: { + value: number + } + array: string[] + } + + const complex1: ComplexObject[] = [ + { nested: { value: 1 }, array: ['a'] }, + { nested: { value: 2 }, array: ['b'] } + ] + + const complex2: ComplexObject[] = [ + { nested: { value: 2 }, array: ['b'] }, + { nested: { value: 3 }, array: ['c'] } + ] + + const result = getIntersection(complex1, complex2, (obj) => obj.nested.value) + expect(result).toEqual([{ nested: { value: 2 }, array: ['b'] }]) + }) + + it('should demonstrate why objects need comparators', () => { + const obj1 = [{ id: 1, name: 'Alice' }] + const obj2 = [{ id: 1, name: 'Alice' }] + + // Bypass TypeScript type checking with 'any' to show runtime behavior + const anyObj1 = obj1 as any + const anyObj2 = obj2 as any + + // Without comparator, objects are compared by reference, not content + const result = getIntersection(anyObj1, anyObj2) + expect(result).toEqual([]) + + // With proper key selector, it works correctly + const correctResult = getIntersection(obj1, obj2, (item) => item.id) + expect(correctResult).toEqual([{ id: 1, name: 'Alice' }]) + }) + + it('should enforce type constraints at compile time', () => { + const obj1 = [{ id: 1, name: 'Alice' }] + const obj2 = [{ id: 1, name: 'Alice' }] + + // The following would cause TypeScript compilation errors: + // + // ❌ Error: Type '{ id: number; name: string; }' does not satisfy the constraint 'string | number | boolean | null | undefined' + // getIntersection(obj1, obj2) + // + // ❌ Error: Expected 3 arguments, but got 2. Object types require a comparator. + // getDifference(obj1, obj2) + // + // ❌ Error: Expected 3 arguments, but got 2. Object types require a comparator. + // getUnion(obj1, obj2) + + // ✅ Correct usage with key selector + const intersection = getIntersection(obj1, obj2, (item) => item.id) + const difference = getDifference(obj1, obj2, (item) => item.id) + const union = getUnion(obj1, obj2, (item) => item.id) + + expect(intersection).toEqual([{ id: 1, name: 'Alice' }]) + expect(difference).toEqual([]) + expect(union).toEqual([{ id: 1, name: 'Alice' }]) + }) + + it('should work correctly with primitive types without comparator', () => { + const nums1 = [1, 2, 3] + const nums2 = [2, 3, 4] + + const intersection = getIntersection(nums1, nums2) + expect(intersection).toEqual([2, 3]) + + const difference = getDifference(nums1, nums2) + expect(difference).toEqual([1]) + + const union = getUnion(nums1, nums2) + expect(union).toEqual([1, 2, 3, 4]) + }) + }) +}) diff --git a/src/renderer/src/utils/collection.ts b/src/renderer/src/utils/collection.ts new file mode 100644 index 0000000000..7fd033624e --- /dev/null +++ b/src/renderer/src/utils/collection.ts @@ -0,0 +1,89 @@ +// Type-safe collection operations with strict compile-time enforcement + +type PrimitiveType = string | number | boolean | null | undefined + +// getIntersection - with strict overloads +export function getIntersection(arr1: T[], arr2: T[]): T[] +export function getIntersection(arr1: T[], arr2: T[], keySelector: (item: T) => K): T[] +export function getIntersection(arr1: T[], arr2: T[], compareFn: (a: T, b: T) => boolean): T[] +export function getIntersection( + arr1: T[], + arr2: T[], + comparator?: ((item: T) => any) | ((a: T, b: T) => boolean) +): T[] { + if (!comparator) { + const set2 = new Set(arr2) + return arr1.filter((element) => set2.has(element)) + } + + if (comparator.length === 1) { + const keySelector = comparator as (item: T) => any + const keySet = new Set(arr2.map(keySelector)) + return arr1.filter((item) => keySet.has(keySelector(item))) + } else { + const compareFn = comparator as (a: T, b: T) => boolean + return arr1.filter((item1) => arr2.some((item2) => compareFn(item1, item2))) + } +} + +// getDifference - with strict overloads +export function getDifference(arr1: T[], arr2: T[]): T[] +export function getDifference(arr1: T[], arr2: T[], keySelector: (item: T) => K): T[] +export function getDifference(arr1: T[], arr2: T[], compareFn: (a: T, b: T) => boolean): T[] +export function getDifference( + arr1: T[], + arr2: T[], + comparator?: ((item: T) => any) | ((a: T, b: T) => boolean) +): T[] { + if (!comparator) { + const set2 = new Set(arr2) + return arr1.filter((element) => !set2.has(element)) + } + + if (comparator.length === 1) { + const keySelector = comparator as (item: T) => any + const keySet = new Set(arr2.map(keySelector)) + return arr1.filter((item) => !keySet.has(keySelector(item))) + } else { + const compareFn = comparator as (a: T, b: T) => boolean + return arr1.filter((item1) => !arr2.some((item2) => compareFn(item1, item2))) + } +} + +// getUnion - with strict overloads +export function getUnion(arr1: T[], arr2: T[]): T[] +export function getUnion(arr1: T[], arr2: T[], keySelector: (item: T) => K): T[] +export function getUnion(arr1: T[], arr2: T[], compareFn: (a: T, b: T) => boolean): T[] +export function getUnion(arr1: T[], arr2: T[], comparator?: ((item: T) => any) | ((a: T, b: T) => boolean)): T[] { + if (!comparator) { + return Array.from(new Set([...arr1, ...arr2])) + } + + if (comparator.length === 1) { + const keySelector = comparator as (item: T) => any + const seen = new Set() + const result: T[] = [] + + for (const item of [...arr1, ...arr2]) { + const key = keySelector(item) + if (!seen.has(key)) { + seen.add(key) + result.push(item) + } + } + + return result + } else { + const compareFn = comparator as (a: T, b: T) => boolean + const result = [...arr1] + + for (const item2 of arr2) { + const exists = result.some((item1) => compareFn(item1, item2)) + if (!exists) { + result.push(item2) + } + } + + return result + } +} diff --git a/src/renderer/src/utils/index.ts b/src/renderer/src/utils/index.ts index dca7dfd148..9be1edb7c6 100644 --- a/src/renderer/src/utils/index.ts +++ b/src/renderer/src/utils/index.ts @@ -1,5 +1,5 @@ import { loggerService } from '@logger' -import { Model, Provider } from '@renderer/types' +import { Model, ModelType, Provider } from '@renderer/types' import { ModalFuncProps } from 'antd' import { v4 as uuidv4 } from 'uuid' @@ -227,7 +227,19 @@ export function isOpenAIProvider(provider: Provider): boolean { return !['anthropic', 'gemini', 'vertexai'].includes(provider.type) } +/** + * 判断模型是否为用户手动选择 + * @param {Model} model 模型对象 + * @param {ModelType} type 模型类型 + * @returns {boolean} 是否为用户手动选择 + */ +export function isUserSelectedModelType(model: Model, type: ModelType): boolean | undefined { + const t = model.capabilities?.find((t) => t.type === type) + return t ? t.isUserSelected : undefined +} + export * from './api' +export * from './collection' export * from './file' export * from './image' export * from './json'