feat(mermaid): add error handling and fix tool for mermaid diagrams

- Add mermaid_fix tool to handle and fix mermaid diagram errors
- Implement pending state tracking for async operations
- Add error callback to MermaidPreview component
- Update i18n strings for mermaid error handling
- Extend cache schema with pending_map for task tracking
This commit is contained in:
icarus 2025-10-22 06:12:21 +08:00
parent 0fb5480b0a
commit 1d94d56e2a
9 changed files with 258 additions and 24 deletions

View File

@ -8,6 +8,8 @@ export type UseCacheSchema = {
// App state
'app.dist.update_state': CacheValueTypes.CacheAppUpdateState
'app.user.avatar': string
/** Used to indicate whether any asynchronous task with a given ID is currently in progress */
'app.pending_map': Record<string, boolean>
// Chat context
'chat.multi_select_mode': boolean
@ -53,6 +55,7 @@ export const DefaultUseCache: UseCacheSchema = {
available: false
},
'app.user.avatar': '',
'app.pending_map': {},
// Chat context
'chat.multi_select_mode': false,

View File

@ -1,6 +1,6 @@
import type { ActionToolSpec } from './types'
export const TOOL_SPECS: Record<string, ActionToolSpec> = {
export const TOOL_SPECS = {
// Core tools
copy: {
id: 'copy',
@ -72,5 +72,10 @@ export const TOOL_SPECS: Record<string, ActionToolSpec> = {
id: 'zoom-out',
type: 'quick',
order: 41
},
mermaid_fix: {
id: 'mermaid-fix',
type: 'core',
order: 42
}
}
} as const satisfies Record<string, ActionToolSpec>

View File

@ -26,6 +26,7 @@ import React, { memo, startTransition, useCallback, useEffect, useMemo, useRef,
import { useTranslation } from 'react-i18next'
import styled, { css } from 'styled-components'
import { useMermaidFixTool } from '../CodeToolbar/hooks/useMermaidFixTool'
import { SPECIAL_VIEW_COMPONENTS, SPECIAL_VIEWS } from './constants'
import StatusBar from './StatusBar'
import type { ViewMode } from './types'
@ -33,9 +34,12 @@ import type { ViewMode } from './types'
const logger = loggerService.withContext('CodeBlockView')
interface Props {
// FIXME: It's not runtime string!
children: string
language: string
onSave?: (newContent: string) => void
// Message Block ID
blockId: string
onSave: (newContent: string) => void
}
/**
@ -54,7 +58,7 @@ interface Props {
* - quick
* - core
*/
export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave }) => {
export const CodeBlockView: React.FC<Props> = memo(({ children: code, language, blockId, onSave }) => {
const { t } = useTranslation()
const [codeExecutionEnabled] = usePreference('chat.code.execution.enabled')
@ -113,6 +117,8 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
const specialViewRef = useRef<BasicPreviewHandles>(null)
const hasSpecialView = useMemo(() => SPECIAL_VIEWS.includes(language), [language])
const [error, setError] = useState<unknown>(null)
const isMermaid = language === 'mermaid'
const isInSpecialView = useMemo(() => {
return hasSpecialView && viewMode === 'special'
@ -146,16 +152,16 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
}, [])
const handleCopySource = useCallback(() => {
navigator.clipboard.writeText(children)
navigator.clipboard.writeText(code)
window.toast.success(t('code_block.copy.success'))
}, [children, t])
}, [code, t])
const handleDownloadSource = useCallback(() => {
let fileName = ''
// 尝试提取 HTML 标题
if (language === 'html') {
fileName = getFileNameFromHtmlTitle(extractHtmlTitle(children)) || ''
fileName = getFileNameFromHtmlTitle(extractHtmlTitle(code)) || ''
}
// 默认使用日期格式命名
@ -164,15 +170,15 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
}
const ext = getExtensionByLanguage(language)
window.api.file.save(`${fileName}${ext}`, children)
}, [children, language])
window.api.file.save(`${fileName}${ext}`, code)
}, [code, language])
const handleRunScript = useCallback(() => {
setIsRunning(true)
setExecutionResult(null)
pyodideService
.runScript(children, {}, codeExecutionTimeoutMinutes * 60000)
.runScript(code, {}, codeExecutionTimeoutMinutes * 60000)
.then((result) => {
setExecutionResult(result)
})
@ -185,7 +191,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
.finally(() => {
setIsRunning(false)
})
}, [children, codeExecutionTimeoutMinutes])
}, [code, codeExecutionTimeoutMinutes])
const showPreviewTools = useMemo(() => {
return viewMode !== 'source' && hasSpecialView
@ -257,6 +263,19 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
setTools
})
// Mermaid fix tool
useMermaidFixTool({
enabled: isMermaid && error !== undefined && error !== null,
context: {
blockId,
error,
content: code
},
setError,
onSave,
setTools
})
// 源代码视图组件
const sourceView = useMemo(
() =>
@ -266,7 +285,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
ref={sourceViewRef}
theme={activeCmTheme}
fontSize={fontSize - 1}
value={children}
value={code}
language={language}
onSave={onSave}
onHeightChange={handleHeightChange}
@ -278,7 +297,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
) : (
<CodeViewer
className="source-view"
value={children}
value={code}
language={language}
onHeightChange={handleHeightChange}
expanded={shouldExpand}
@ -288,7 +307,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
),
[
activeCmTheme,
children,
code,
codeEditor,
codeShowLineNumbers,
fontSize,
@ -307,11 +326,11 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
if (!SpecialView) return null
return (
<SpecialView ref={specialViewRef} enableToolbar={codeImageTools}>
{children}
<SpecialView ref={specialViewRef} enableToolbar={codeImageTools} onError={setError}>
{code}
</SpecialView>
)
}, [children, codeImageTools, language])
}, [code, codeImageTools, language])
const renderHeader = useMemo(() => {
const langTag = '<' + language.toUpperCase() + '>'

View File

@ -0,0 +1,178 @@
import { loggerService } from '@logger'
import type { ActionTool } from '@renderer/components/ActionTools'
import { TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
import { usePendingMap } from '@renderer/hooks/usePendingMap'
import { useQuickCompletion } from '@renderer/hooks/useQuickCompletion'
import { useSettings } from '@renderer/hooks/useSettings'
import type { Chunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { getErrorMessage, parseJSON } from '@renderer/utils'
import { abortCompletion, readyToAbort } from '@renderer/utils/abortController'
import { WrenchIcon } from 'lucide-react'
import { useCallback, useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import * as z from 'zod'
const logger = loggerService.withContext('useMermaidFixTool')
interface UseMermaidFixTool {
enabled?: boolean
context: {
/** Code block id */
blockId: string
/** Error */
error: unknown
/** Mermaid code */
content: string
}
onSave: (newContent: string) => void
setError: (error: unknown) => void
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
}
const ResultSchema = z.union([
z.object({
fixed: z.literal(true),
result: z.string()
}),
z.object({
fixed: z.literal(false),
reason: z.string()
})
])
type Input = {
// mermaid 代码
mermaid: string
// 错误信息
error: string
// 用户语言代码, 如 zh-cn, en-us
lang: string
}
const SYSTEM_PROMPT = `
You are an AI assistant that fixes Mermaid code. The input is a JSON string with the following structure: {"mermaid": "the Mermaid code", "error": "the error message from rendering", "lang": "the user's language code"}.
Your task is to analyze the error and the Mermaid code. If the error is due to a mistake in the Mermaid code, fix it and output a JSON string with {"fixed": true, "result": "the fixed Mermaid code"}. If the error is not caused by the code (e.g., environment issues, unsupported features, or other non-code errors), output {"fixed": false, "reason": "a brief explanation in the language specified by the 'lang' field"}.
Your output must be a pure JSON string with no additional text, comments, or formatting.
Example input:
{
"mermaid": "graph TD\nA[Start] --> B{Error?}",
"error": "Syntax error: unexpected token",
"lang": "en-us"
}
Example outputs:
- If fixed: {"fixed": true, "result": "graph TD\nA[Start] --> B{Error?}\nB -->|Yes| C[End]"}
- If not fixed: {"fixed": false, "reason": "The error is due to an unsupported feature in the current environment."}
`
export const useMermaidFixTool = ({ enabled, context, onSave, setError, setTools }: UseMermaidFixTool) => {
const { t } = useTranslation()
const { registerTool, removeTool } = useToolManager(setTools)
const { language } = useSettings()
const completion = useQuickCompletion(SYSTEM_PROMPT)
const { error, content, blockId } = context
const abortKeyRef = useRef<string | null>(null)
const { setPending } = usePendingMap()
logger.debug('input', {
mermaid: content,
error: getErrorMessage(error),
lang: language
})
const prompt = JSON.stringify({
mermaid: content,
error: getErrorMessage(error),
lang: language
} satisfies Input)
const fixCode = useCallback(async () => {
setPending(blockId, true)
const abortKey = crypto.randomUUID()
abortKeyRef.current = abortKey
const signal = readyToAbort(abortKey)
let result = ''
const onChunk = (chunk: Chunk) => {
if (chunk.type === ChunkType.TEXT_DELTA) {
result = chunk.text
}
}
try {
await completion({
prompt,
onChunk,
params: {
options: {
signal
}
}
})
} catch (e) {
window.toast.error({ title: t('code_block.mermaid_fix.failed'), description: getErrorMessage(e) })
return
}
result = result.trim()
logger.debug('output', { result })
const parsedJson = parseJSON(result)
if (parsedJson === null) {
window.toast.error({
title: t('code_block.mermaid_fix.failed'),
description: t('code_block.mermaid_fix.invalid_result')
})
} else {
logger.debug('parseJSON success', { parsedJson })
const parsedResult = ResultSchema.safeParse(parsedJson)
logger.debug('validation', { parsedResult })
if (parsedResult.success) {
const validResult = parsedResult.data
if (validResult.fixed) {
onSave(validResult.result)
setError(undefined)
} else {
window.toast.warning({ title: t('code_block.mermaid_fix.failed'), description: validResult.reason })
}
} else {
window.toast.error({
title: t('code_block.mermaid_fix.failed'),
description: t('code_block.mermaid_fix.invalid_result')
})
}
}
setPending(blockId, false)
}, [setPending, blockId, completion, prompt, onSave, t])
// when unmounted
useEffect(() => {
return () => {
const abortKey = abortKeyRef.current
if (abortKey) {
abortCompletion(abortKey)
}
}
}, [])
useEffect(() => {
if (enabled) {
registerTool({
...TOOL_SPECS.mermaid_fix,
icon: <WrenchIcon size={'1rem'} className="tool-icon" />,
tooltip: t('code_block.mermaid_fix.label'),
visible: () => error !== undefined && error !== null,
onClick: fixCode
})
}
return () => removeTool(TOOL_SPECS.expand.id)
}, [enabled, error, fixCode, registerTool, removeTool, t])
}

View File

@ -16,8 +16,9 @@ import { renderSvgInShadowHost } from './utils'
const MermaidPreview = ({
children,
enableToolbar = false,
ref
}: BasicPreviewProps & { ref?: React.RefObject<BasicPreviewHandles | null> }) => {
ref,
onError
}: BasicPreviewProps & { ref?: React.RefObject<BasicPreviewHandles | null>; onError?: (error: unknown) => void }) => {
const { mermaid, isLoading: isLoadingMermaid, error: mermaidError, forceRenderKey } = useMermaid()
const diagramId = useRef<string>(`mermaid-${nanoid(6)}`).current
const [isVisible, setIsVisible] = useState(true)
@ -122,6 +123,12 @@ const MermaidPreview = ({
const isLoading = isLoadingMermaid || isRendering
const error = mermaidError || renderError
useEffect(() => {
if (error !== undefined && error !== null && onError !== undefined) {
onError(error)
}
}, [error, onError])
return (
<ImagePreviewLayout
loading={isLoading}

View File

@ -11,7 +11,8 @@ import { useDefaultModel } from './useAssistant'
*/
export type QuickCompletionParams = {
/**
* The prompt text to send to the model.
* The user message text (not the system prompt) to send to the model.
* The system prompt is set via the `systemPrompt` parameter passed to `useQuickCompletion`.
*/
prompt: string
/**

View File

@ -940,6 +940,11 @@
}
},
"expand": "Expand",
"mermaid_fix": {
"failed": "Failed to fix",
"invalid_result": "Model returned invalid data. Please try again or try changing the quick model.",
"label": "Fix mermaid error"
},
"more": "More",
"run": "Run",
"split": {

View File

@ -1,4 +1,6 @@
import { BlockingOverlay, cn, Spinner } from '@cherrystudio/ui'
import { CodeBlockView, HtmlArtifactsCard } from '@renderer/components/CodeBlockView'
import { usePendingMap } from '@renderer/hooks/usePendingMap'
import { useSettings } from '@renderer/hooks/useSettings'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import store from '@renderer/store'
@ -21,6 +23,8 @@ const CodeBlock: React.FC<Props> = ({ children, className, node, blockId }) => {
const isMultiline = children?.includes('\n')
const language = languageMatch?.[1] ?? (isMultiline ? 'text' : null)
const { codeFancyBlock } = useSettings()
const { isPending } = usePendingMap()
const isBlockPending = isPending(blockId)
// 代码块 id
const id = useMemo(() => getCodeBlockId(node?.position?.start), [node?.position?.start])
@ -52,14 +56,24 @@ const CodeBlock: React.FC<Props> = ({ children, className, node, blockId }) => {
}
return (
<CodeBlockView language={language} onSave={handleSave}>
{children}
</CodeBlockView>
<div className="relative">
<CodeBlockView language={language} onSave={handleSave} blockId={blockId}>
{children}
</CodeBlockView>
{isBlockPending && (
<BlockingOverlay isVisible={isBlockPending}>
<Spinner />
</BlockingOverlay>
)}
{/* <BlockingOverlay isVisible={true}>
<Spinner />
</BlockingOverlay> */}
</div>
)
}
return (
<code className={className} style={{ textWrap: 'wrap', fontSize: '95%', padding: '2px 4px' }}>
<code className={cn('relative', className)} style={{ textWrap: 'wrap', fontSize: '95%', padding: '2px 4px' }}>
{children}
</code>
)

View File

@ -1,5 +1,7 @@
import { loggerService } from '@logger'
// TODO: We may refactor it to a service with pendingMap
const logger = loggerService.withContext('AbortController')
export const abortMap = new Map<string, (() => void)[]>()