diff --git a/electron.vite.config.ts b/electron.vite.config.ts index 9b22ffc33b..fa4b234a54 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -68,12 +68,16 @@ export default defineConfig({ } }, optimizeDeps: { - exclude: ['pyodide'] + exclude: ['pyodide'], + esbuildOptions: { + target: 'esnext' // for dev + } }, worker: { format: 'es' }, build: { + target: 'esnext', // for build rollupOptions: { input: { index: resolve(__dirname, 'src/renderer/index.html'), diff --git a/package.json b/package.json index 3ee50ee69a..72e1f64856 100644 --- a/package.json +++ b/package.json @@ -168,6 +168,7 @@ "husky": "^9.1.7", "i18next": "^23.11.5", "jest-styled-components": "^7.2.0", + "linguist-languages": "^8.0.0", "lint-staged": "^15.5.0", "lodash": "^4.17.21", "lru-cache": "^11.1.0", diff --git a/src/renderer/src/components/CodeBlockView/index.tsx b/src/renderer/src/components/CodeBlockView/index.tsx index 944e6f66e3..811b8665cc 100644 --- a/src/renderer/src/components/CodeBlockView/index.tsx +++ b/src/renderer/src/components/CodeBlockView/index.tsx @@ -4,7 +4,7 @@ import { CodeTool, CodeToolbar, TOOL_SPECS, useCodeTool } from '@renderer/compon import { useSettings } from '@renderer/hooks/useSettings' import { pyodideService } from '@renderer/services/PyodideService' import { extractTitle } from '@renderer/utils/formats' -import { isValidPlantUML } from '@renderer/utils/markdown' +import { getExtensionByLanguage, isValidPlantUML } from '@renderer/utils/markdown' import dayjs from 'dayjs' import { CirclePlay, CodeXml, Copy, Download, Eye, Square, SquarePen, SquareSplitHorizontal } from 'lucide-react' import React, { memo, useCallback, useEffect, useMemo, useState } from 'react' @@ -67,23 +67,21 @@ const CodeBlockView: React.FC = ({ children, language, onSave }) => { window.message.success({ content: t('code_block.copy.success'), key: 'copy-code' }) }, [children, t]) - const handleDownloadSource = useCallback(() => { + const handleDownloadSource = useCallback(async () => { let fileName = '' - // 尝试提取标题 + // 尝试提取 HTML 标题 if (language === 'html' && children.includes('')) { - const title = extractTitle(children) - if (title) { - fileName = `${title}.html` - } + fileName = extractTitle(children) || '' } // 默认使用日期格式命名 if (!fileName) { - fileName = `${dayjs().format('YYYYMMDDHHmm')}.${language}` + fileName = `${dayjs().format('YYYYMMDDHHmm')}` } - window.api.file.save(fileName, children) + const ext = await getExtensionByLanguage(language) + window.api.file.save(`${fileName}${ext}`, children) }, [children, language]) const handleRunScript = useCallback(() => { diff --git a/src/renderer/src/utils/__tests__/markdown.test.ts b/src/renderer/src/utils/__tests__/markdown.test.ts index cbde058d24..7b82816f17 100644 --- a/src/renderer/src/utils/__tests__/markdown.test.ts +++ b/src/renderer/src/utils/__tests__/markdown.test.ts @@ -7,6 +7,7 @@ import { convertMathFormula, findCitationInChildren, getCodeBlockId, + getExtensionByLanguage, markdownToPlainText, removeTrailingDoubleSpaces, updateCodeBlock @@ -143,6 +144,67 @@ describe('markdown', () => { }) }) + describe('getExtensionByLanguage', () => { + // 批量测试语言名称到扩展名的映射 + const testLanguageExtensions = async (testCases: Record) => { + for (const [language, expectedExtension] of Object.entries(testCases)) { + const result = await getExtensionByLanguage(language) + expect(result).toBe(expectedExtension) + } + } + + it('should return extension for exact language name match', async () => { + await testLanguageExtensions({ + '4D': '.4dm', + 'C#': '.cs', + JavaScript: '.js', + TypeScript: '.ts', + 'Objective-C++': '.mm', + Python: '.py', + SVG: '.svg', + 'Visual Basic .NET': '.vb' + }) + }) + + it('should return extension for case-insensitive language name match', async () => { + await testLanguageExtensions({ + '4d': '.4dm', + 'c#': '.cs', + javascript: '.js', + typescript: '.ts', + 'objective-c++': '.mm', + python: '.py', + svg: '.svg', + 'visual basic .net': '.vb' + }) + }) + + it('should return extension for language aliases', async () => { + await testLanguageExtensions({ + js: '.js', + node: '.js', + 'obj-c++': '.mm', + 'objc++': '.mm', + 'objectivec++': '.mm', + py: '.py', + 'visual basic': '.vb' + }) + }) + + it('should return fallback extension for unknown languages', async () => { + await testLanguageExtensions({ + 'unknown-language': '.unknown-language', + custom: '.custom' + }) + }) + + it('should handle empty string input', async () => { + await testLanguageExtensions({ + '': '.' + }) + }) + }) + describe('getCodeBlockId', () => { it('should generate ID from position information', () => { // 从位置信息生成ID diff --git a/src/renderer/src/utils/markdown.ts b/src/renderer/src/utils/markdown.ts index e4881ed062..3c409d39b4 100644 --- a/src/renderer/src/utils/markdown.ts +++ b/src/renderer/src/utils/markdown.ts @@ -54,6 +54,60 @@ export function removeTrailingDoubleSpaces(markdown: string): string { return markdown.replace(/ {2}$/gm, '') } +const predefinedExtensionMap: Record = { + html: '.html', + javascript: '.js', + typescript: '.ts', + python: '.py', + json: '.json', + markdown: '.md', + text: '.txt' +} + +/** + * 根据语言名称获取文件扩展名 + * - 先精确匹配,再忽略大小写,最后匹配别名 + * - 返回第一个扩展名 + * @param language 语言名称 + * @returns 文件扩展名 + */ +export async function getExtensionByLanguage(language: string): Promise { + const lowerLanguage = language.toLowerCase() + + // 常用的扩展名 + const predefined = predefinedExtensionMap[lowerLanguage] + if (predefined) { + return predefined + } + + const languages = await import('linguist-languages') + + // 精确匹配语言名称 + const directMatch = languages[language as keyof typeof languages] as any + if (directMatch?.extensions?.[0]) { + return directMatch.extensions[0] + } + + // 大小写不敏感的语言名称匹配 + for (const [langName, data] of Object.entries(languages)) { + const languageData = data as any + if (langName.toLowerCase() === lowerLanguage && languageData.extensions?.[0]) { + return languageData.extensions[0] + } + } + + // 通过别名匹配 + for (const [, data] of Object.entries(languages)) { + const languageData = data as any + if (languageData.aliases?.includes(lowerLanguage)) { + return languageData.extensions?.[0] || `.${language}` + } + } + + // 回退到语言名称 + return `.${language}` +} + /** * 根据代码块节点的起始位置生成 ID * @param start 代码块节点的起始位置 diff --git a/yarn.lock b/yarn.lock index a0daa3b3ce..688f7f8f5d 100644 --- a/yarn.lock +++ b/yarn.lock @@ -5688,6 +5688,7 @@ __metadata: i18next: "npm:^23.11.5" jest-styled-components: "npm:^7.2.0" jsdom: "npm:26.1.0" + linguist-languages: "npm:^8.0.0" lint-staged: "npm:^15.5.0" lodash: "npm:^4.17.21" lru-cache: "npm:^11.1.0" @@ -11874,6 +11875,13 @@ __metadata: languageName: node linkType: hard +"linguist-languages@npm:^8.0.0": + version: 8.0.0 + resolution: "linguist-languages@npm:8.0.0" + checksum: 10c0/eaae46254247b9aa5b287ac98e062e7fe859314328ce305e34e152bc7bb172d69633999320cb47dc2a710388179712a76bb1ddd6e39e249af2684a4f0a66256c + languageName: node + linkType: hard + "linkify-it@npm:^5.0.0": version: 5.0.0 resolution: "linkify-it@npm:5.0.0"