diff --git a/packages/shared/data/cache/cacheSchemas.ts b/packages/shared/data/cache/cacheSchemas.ts index 9be41126c0..51b83b513f 100644 --- a/packages/shared/data/cache/cacheSchemas.ts +++ b/packages/shared/data/cache/cacheSchemas.ts @@ -8,6 +8,8 @@ export type UseCacheSchema = { // App state 'app.dist.update_state': CacheValueTypes.CacheAppUpdateState 'app.user.avatar': string + /** Used to indicate whether any asynchronous task with a given ID is currently in progress */ + 'app.pending_map': Record // Chat context 'chat.multi_select_mode': boolean @@ -53,6 +55,7 @@ export const DefaultUseCache: UseCacheSchema = { available: false }, 'app.user.avatar': '', + 'app.pending_map': {}, // Chat context 'chat.multi_select_mode': false, diff --git a/src/renderer/src/components/ActionTools/constants.ts b/src/renderer/src/components/ActionTools/constants.ts index bade7d123a..30b3197d6a 100644 --- a/src/renderer/src/components/ActionTools/constants.ts +++ b/src/renderer/src/components/ActionTools/constants.ts @@ -1,6 +1,6 @@ import type { ActionToolSpec } from './types' -export const TOOL_SPECS: Record = { +export const TOOL_SPECS = { // Core tools copy: { id: 'copy', @@ -72,5 +72,10 @@ export const TOOL_SPECS: Record = { id: 'zoom-out', type: 'quick', order: 41 + }, + mermaid_fix: { + id: 'mermaid-fix', + type: 'core', + order: 42 } -} +} as const satisfies Record diff --git a/src/renderer/src/components/CodeBlockView/view.tsx b/src/renderer/src/components/CodeBlockView/view.tsx index cc978b3f8c..c703578382 100644 --- a/src/renderer/src/components/CodeBlockView/view.tsx +++ b/src/renderer/src/components/CodeBlockView/view.tsx @@ -26,6 +26,7 @@ import React, { memo, startTransition, useCallback, useEffect, useMemo, useRef, import { useTranslation } from 'react-i18next' import styled, { css } from 'styled-components' +import { useMermaidFixTool } from '../CodeToolbar/hooks/useMermaidFixTool' import { SPECIAL_VIEW_COMPONENTS, SPECIAL_VIEWS } from './constants' import StatusBar from './StatusBar' import type { ViewMode } from './types' @@ -33,9 +34,12 @@ import type { ViewMode } from './types' const logger = loggerService.withContext('CodeBlockView') interface Props { + // FIXME: It's not runtime string! children: string language: string - onSave?: (newContent: string) => void + // Message Block ID + blockId: string + onSave: (newContent: string) => void } /** @@ -54,7 +58,7 @@ interface Props { * - quick 工具 * - core 工具 */ -export const CodeBlockView: React.FC = memo(({ children, language, onSave }) => { +export const CodeBlockView: React.FC = memo(({ children: code, language, blockId, onSave }) => { const { t } = useTranslation() const [codeExecutionEnabled] = usePreference('chat.code.execution.enabled') @@ -113,6 +117,8 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave const specialViewRef = useRef(null) const hasSpecialView = useMemo(() => SPECIAL_VIEWS.includes(language), [language]) + const [error, setError] = useState(null) + const isMermaid = language === 'mermaid' const isInSpecialView = useMemo(() => { return hasSpecialView && viewMode === 'special' @@ -146,16 +152,16 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave }, []) const handleCopySource = useCallback(() => { - navigator.clipboard.writeText(children) + navigator.clipboard.writeText(code) window.toast.success(t('code_block.copy.success')) - }, [children, t]) + }, [code, t]) const handleDownloadSource = useCallback(() => { let fileName = '' // 尝试提取 HTML 标题 if (language === 'html') { - fileName = getFileNameFromHtmlTitle(extractHtmlTitle(children)) || '' + fileName = getFileNameFromHtmlTitle(extractHtmlTitle(code)) || '' } // 默认使用日期格式命名 @@ -164,15 +170,15 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave } const ext = getExtensionByLanguage(language) - window.api.file.save(`${fileName}${ext}`, children) - }, [children, language]) + window.api.file.save(`${fileName}${ext}`, code) + }, [code, language]) const handleRunScript = useCallback(() => { setIsRunning(true) setExecutionResult(null) pyodideService - .runScript(children, {}, codeExecutionTimeoutMinutes * 60000) + .runScript(code, {}, codeExecutionTimeoutMinutes * 60000) .then((result) => { setExecutionResult(result) }) @@ -185,7 +191,7 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave .finally(() => { setIsRunning(false) }) - }, [children, codeExecutionTimeoutMinutes]) + }, [code, codeExecutionTimeoutMinutes]) const showPreviewTools = useMemo(() => { return viewMode !== 'source' && hasSpecialView @@ -257,6 +263,19 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave setTools }) + // Mermaid fix tool + useMermaidFixTool({ + enabled: isMermaid && error !== undefined && error !== null, + context: { + blockId, + error, + content: code + }, + setError, + onSave, + setTools + }) + // 源代码视图组件 const sourceView = useMemo( () => @@ -266,7 +285,7 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave ref={sourceViewRef} theme={activeCmTheme} fontSize={fontSize - 1} - value={children} + value={code} language={language} onSave={onSave} onHeightChange={handleHeightChange} @@ -278,7 +297,7 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave ) : ( = memo(({ children, language, onSave ), [ activeCmTheme, - children, + code, codeEditor, codeShowLineNumbers, fontSize, @@ -307,11 +326,11 @@ export const CodeBlockView: React.FC = memo(({ children, language, onSave if (!SpecialView) return null return ( - - {children} + + {code} ) - }, [children, codeImageTools, language]) + }, [code, codeImageTools, language]) const renderHeader = useMemo(() => { const langTag = '<' + language.toUpperCase() + '>' diff --git a/src/renderer/src/components/CodeToolbar/hooks/useMermaidFixTool.tsx b/src/renderer/src/components/CodeToolbar/hooks/useMermaidFixTool.tsx new file mode 100644 index 0000000000..5df00c86ce --- /dev/null +++ b/src/renderer/src/components/CodeToolbar/hooks/useMermaidFixTool.tsx @@ -0,0 +1,178 @@ +import { loggerService } from '@logger' +import type { ActionTool } from '@renderer/components/ActionTools' +import { TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools' +import { usePendingMap } from '@renderer/hooks/usePendingMap' +import { useQuickCompletion } from '@renderer/hooks/useQuickCompletion' +import { useSettings } from '@renderer/hooks/useSettings' +import type { Chunk } from '@renderer/types/chunk' +import { ChunkType } from '@renderer/types/chunk' +import { getErrorMessage, parseJSON } from '@renderer/utils' +import { abortCompletion, readyToAbort } from '@renderer/utils/abortController' +import { WrenchIcon } from 'lucide-react' +import { useCallback, useEffect, useRef } from 'react' +import { useTranslation } from 'react-i18next' +import * as z from 'zod' + +const logger = loggerService.withContext('useMermaidFixTool') + +interface UseMermaidFixTool { + enabled?: boolean + context: { + /** Code block id */ + blockId: string + /** Error */ + error: unknown + /** Mermaid code */ + content: string + } + onSave: (newContent: string) => void + setError: (error: unknown) => void + setTools: React.Dispatch> +} + +const ResultSchema = z.union([ + z.object({ + fixed: z.literal(true), + result: z.string() + }), + z.object({ + fixed: z.literal(false), + reason: z.string() + }) +]) + +type Input = { + // mermaid 代码 + mermaid: string + // 错误信息 + error: string + // 用户语言代码, 如 zh-cn, en-us + lang: string +} + +const SYSTEM_PROMPT = ` +You are an AI assistant that fixes Mermaid code. The input is a JSON string with the following structure: {"mermaid": "the Mermaid code", "error": "the error message from rendering", "lang": "the user's language code"}. + +Your task is to analyze the error and the Mermaid code. If the error is due to a mistake in the Mermaid code, fix it and output a JSON string with {"fixed": true, "result": "the fixed Mermaid code"}. If the error is not caused by the code (e.g., environment issues, unsupported features, or other non-code errors), output {"fixed": false, "reason": "a brief explanation in the language specified by the 'lang' field"}. + +Your output must be a pure JSON string with no additional text, comments, or formatting. + +Example input: +{ + "mermaid": "graph TD\nA[Start] --> B{Error?}", + "error": "Syntax error: unexpected token", + "lang": "en-us" +} + +Example outputs: +- If fixed: {"fixed": true, "result": "graph TD\nA[Start] --> B{Error?}\nB -->|Yes| C[End]"} +- If not fixed: {"fixed": false, "reason": "The error is due to an unsupported feature in the current environment."} + +` + +export const useMermaidFixTool = ({ enabled, context, onSave, setError, setTools }: UseMermaidFixTool) => { + const { t } = useTranslation() + const { registerTool, removeTool } = useToolManager(setTools) + const { language } = useSettings() + const completion = useQuickCompletion(SYSTEM_PROMPT) + + const { error, content, blockId } = context + const abortKeyRef = useRef(null) + + const { setPending } = usePendingMap() + logger.debug('input', { + mermaid: content, + error: getErrorMessage(error), + lang: language + }) + const prompt = JSON.stringify({ + mermaid: content, + error: getErrorMessage(error), + lang: language + } satisfies Input) + + const fixCode = useCallback(async () => { + setPending(blockId, true) + const abortKey = crypto.randomUUID() + abortKeyRef.current = abortKey + const signal = readyToAbort(abortKey) + let result = '' + + const onChunk = (chunk: Chunk) => { + if (chunk.type === ChunkType.TEXT_DELTA) { + result = chunk.text + } + } + + try { + await completion({ + prompt, + onChunk, + params: { + options: { + signal + } + } + }) + } catch (e) { + window.toast.error({ title: t('code_block.mermaid_fix.failed'), description: getErrorMessage(e) }) + return + } + + result = result.trim() + logger.debug('output', { result }) + + const parsedJson = parseJSON(result) + if (parsedJson === null) { + window.toast.error({ + title: t('code_block.mermaid_fix.failed'), + description: t('code_block.mermaid_fix.invalid_result') + }) + } else { + logger.debug('parseJSON success', { parsedJson }) + const parsedResult = ResultSchema.safeParse(parsedJson) + logger.debug('validation', { parsedResult }) + + if (parsedResult.success) { + const validResult = parsedResult.data + if (validResult.fixed) { + onSave(validResult.result) + setError(undefined) + } else { + window.toast.warning({ title: t('code_block.mermaid_fix.failed'), description: validResult.reason }) + } + } else { + window.toast.error({ + title: t('code_block.mermaid_fix.failed'), + description: t('code_block.mermaid_fix.invalid_result') + }) + } + } + + setPending(blockId, false) + }, [setPending, blockId, completion, prompt, onSave, t]) + + // when unmounted + useEffect(() => { + return () => { + const abortKey = abortKeyRef.current + if (abortKey) { + abortCompletion(abortKey) + } + } + }, []) + + useEffect(() => { + if (enabled) { + registerTool({ + ...TOOL_SPECS.mermaid_fix, + icon: , + tooltip: t('code_block.mermaid_fix.label'), + visible: () => error !== undefined && error !== null, + onClick: fixCode + }) + } + + return () => removeTool(TOOL_SPECS.expand.id) + }, [enabled, error, fixCode, registerTool, removeTool, t]) +} diff --git a/src/renderer/src/components/Preview/MermaidPreview.tsx b/src/renderer/src/components/Preview/MermaidPreview.tsx index fcb09af64a..2b7a777139 100644 --- a/src/renderer/src/components/Preview/MermaidPreview.tsx +++ b/src/renderer/src/components/Preview/MermaidPreview.tsx @@ -16,8 +16,9 @@ import { renderSvgInShadowHost } from './utils' const MermaidPreview = ({ children, enableToolbar = false, - ref -}: BasicPreviewProps & { ref?: React.RefObject }) => { + ref, + onError +}: BasicPreviewProps & { ref?: React.RefObject; onError?: (error: unknown) => void }) => { const { mermaid, isLoading: isLoadingMermaid, error: mermaidError, forceRenderKey } = useMermaid() const diagramId = useRef(`mermaid-${nanoid(6)}`).current const [isVisible, setIsVisible] = useState(true) @@ -122,6 +123,12 @@ const MermaidPreview = ({ const isLoading = isLoadingMermaid || isRendering const error = mermaidError || renderError + useEffect(() => { + if (error !== undefined && error !== null && onError !== undefined) { + onError(error) + } + }, [error, onError]) + return ( = ({ children, className, node, blockId }) => { const isMultiline = children?.includes('\n') const language = languageMatch?.[1] ?? (isMultiline ? 'text' : null) const { codeFancyBlock } = useSettings() + const { isPending } = usePendingMap() + const isBlockPending = isPending(blockId) // 代码块 id const id = useMemo(() => getCodeBlockId(node?.position?.start), [node?.position?.start]) @@ -52,14 +56,24 @@ const CodeBlock: React.FC = ({ children, className, node, blockId }) => { } return ( - - {children} - +
+ + {children} + + {isBlockPending && ( + + + + )} + {/* + + */} +
) } return ( - + {children} ) diff --git a/src/renderer/src/utils/abortController.ts b/src/renderer/src/utils/abortController.ts index 11bd7a791d..6c853be53b 100644 --- a/src/renderer/src/utils/abortController.ts +++ b/src/renderer/src/utils/abortController.ts @@ -1,5 +1,7 @@ import { loggerService } from '@logger' +// TODO: We may refactor it to a service with pendingMap + const logger = loggerService.withContext('AbortController') export const abortMap = new Map void)[]>()