mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 12:51:26 +08:00
feat(mermaid): add error handling and fix tool for mermaid diagrams
- Add mermaid_fix tool to handle and fix mermaid diagram errors - Implement pending state tracking for async operations - Add error callback to MermaidPreview component - Update i18n strings for mermaid error handling - Extend cache schema with pending_map for task tracking
This commit is contained in:
parent
0fb5480b0a
commit
1d94d56e2a
3
packages/shared/data/cache/cacheSchemas.ts
vendored
3
packages/shared/data/cache/cacheSchemas.ts
vendored
@ -8,6 +8,8 @@ export type UseCacheSchema = {
|
||||
// App state
|
||||
'app.dist.update_state': CacheValueTypes.CacheAppUpdateState
|
||||
'app.user.avatar': string
|
||||
/** Used to indicate whether any asynchronous task with a given ID is currently in progress */
|
||||
'app.pending_map': Record<string, boolean>
|
||||
|
||||
// Chat context
|
||||
'chat.multi_select_mode': boolean
|
||||
@ -53,6 +55,7 @@ export const DefaultUseCache: UseCacheSchema = {
|
||||
available: false
|
||||
},
|
||||
'app.user.avatar': '',
|
||||
'app.pending_map': {},
|
||||
|
||||
// Chat context
|
||||
'chat.multi_select_mode': false,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import type { ActionToolSpec } from './types'
|
||||
|
||||
export const TOOL_SPECS: Record<string, ActionToolSpec> = {
|
||||
export const TOOL_SPECS = {
|
||||
// Core tools
|
||||
copy: {
|
||||
id: 'copy',
|
||||
@ -72,5 +72,10 @@ export const TOOL_SPECS: Record<string, ActionToolSpec> = {
|
||||
id: 'zoom-out',
|
||||
type: 'quick',
|
||||
order: 41
|
||||
},
|
||||
mermaid_fix: {
|
||||
id: 'mermaid-fix',
|
||||
type: 'core',
|
||||
order: 42
|
||||
}
|
||||
}
|
||||
} as const satisfies Record<string, ActionToolSpec>
|
||||
|
||||
@ -26,6 +26,7 @@ import React, { memo, startTransition, useCallback, useEffect, useMemo, useRef,
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled, { css } from 'styled-components'
|
||||
|
||||
import { useMermaidFixTool } from '../CodeToolbar/hooks/useMermaidFixTool'
|
||||
import { SPECIAL_VIEW_COMPONENTS, SPECIAL_VIEWS } from './constants'
|
||||
import StatusBar from './StatusBar'
|
||||
import type { ViewMode } from './types'
|
||||
@ -33,9 +34,12 @@ import type { ViewMode } from './types'
|
||||
const logger = loggerService.withContext('CodeBlockView')
|
||||
|
||||
interface Props {
|
||||
// FIXME: It's not runtime string!
|
||||
children: string
|
||||
language: string
|
||||
onSave?: (newContent: string) => void
|
||||
// Message Block ID
|
||||
blockId: string
|
||||
onSave: (newContent: string) => void
|
||||
}
|
||||
|
||||
/**
|
||||
@ -54,7 +58,7 @@ interface Props {
|
||||
* - quick 工具
|
||||
* - core 工具
|
||||
*/
|
||||
export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave }) => {
|
||||
export const CodeBlockView: React.FC<Props> = memo(({ children: code, language, blockId, onSave }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const [codeExecutionEnabled] = usePreference('chat.code.execution.enabled')
|
||||
@ -113,6 +117,8 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
|
||||
const specialViewRef = useRef<BasicPreviewHandles>(null)
|
||||
|
||||
const hasSpecialView = useMemo(() => SPECIAL_VIEWS.includes(language), [language])
|
||||
const [error, setError] = useState<unknown>(null)
|
||||
const isMermaid = language === 'mermaid'
|
||||
|
||||
const isInSpecialView = useMemo(() => {
|
||||
return hasSpecialView && viewMode === 'special'
|
||||
@ -146,16 +152,16 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
|
||||
}, [])
|
||||
|
||||
const handleCopySource = useCallback(() => {
|
||||
navigator.clipboard.writeText(children)
|
||||
navigator.clipboard.writeText(code)
|
||||
window.toast.success(t('code_block.copy.success'))
|
||||
}, [children, t])
|
||||
}, [code, t])
|
||||
|
||||
const handleDownloadSource = useCallback(() => {
|
||||
let fileName = ''
|
||||
|
||||
// 尝试提取 HTML 标题
|
||||
if (language === 'html') {
|
||||
fileName = getFileNameFromHtmlTitle(extractHtmlTitle(children)) || ''
|
||||
fileName = getFileNameFromHtmlTitle(extractHtmlTitle(code)) || ''
|
||||
}
|
||||
|
||||
// 默认使用日期格式命名
|
||||
@ -164,15 +170,15 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
|
||||
}
|
||||
|
||||
const ext = getExtensionByLanguage(language)
|
||||
window.api.file.save(`${fileName}${ext}`, children)
|
||||
}, [children, language])
|
||||
window.api.file.save(`${fileName}${ext}`, code)
|
||||
}, [code, language])
|
||||
|
||||
const handleRunScript = useCallback(() => {
|
||||
setIsRunning(true)
|
||||
setExecutionResult(null)
|
||||
|
||||
pyodideService
|
||||
.runScript(children, {}, codeExecutionTimeoutMinutes * 60000)
|
||||
.runScript(code, {}, codeExecutionTimeoutMinutes * 60000)
|
||||
.then((result) => {
|
||||
setExecutionResult(result)
|
||||
})
|
||||
@ -185,7 +191,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
|
||||
.finally(() => {
|
||||
setIsRunning(false)
|
||||
})
|
||||
}, [children, codeExecutionTimeoutMinutes])
|
||||
}, [code, codeExecutionTimeoutMinutes])
|
||||
|
||||
const showPreviewTools = useMemo(() => {
|
||||
return viewMode !== 'source' && hasSpecialView
|
||||
@ -257,6 +263,19 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
|
||||
setTools
|
||||
})
|
||||
|
||||
// Mermaid fix tool
|
||||
useMermaidFixTool({
|
||||
enabled: isMermaid && error !== undefined && error !== null,
|
||||
context: {
|
||||
blockId,
|
||||
error,
|
||||
content: code
|
||||
},
|
||||
setError,
|
||||
onSave,
|
||||
setTools
|
||||
})
|
||||
|
||||
// 源代码视图组件
|
||||
const sourceView = useMemo(
|
||||
() =>
|
||||
@ -266,7 +285,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
|
||||
ref={sourceViewRef}
|
||||
theme={activeCmTheme}
|
||||
fontSize={fontSize - 1}
|
||||
value={children}
|
||||
value={code}
|
||||
language={language}
|
||||
onSave={onSave}
|
||||
onHeightChange={handleHeightChange}
|
||||
@ -278,7 +297,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
|
||||
) : (
|
||||
<CodeViewer
|
||||
className="source-view"
|
||||
value={children}
|
||||
value={code}
|
||||
language={language}
|
||||
onHeightChange={handleHeightChange}
|
||||
expanded={shouldExpand}
|
||||
@ -288,7 +307,7 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
|
||||
),
|
||||
[
|
||||
activeCmTheme,
|
||||
children,
|
||||
code,
|
||||
codeEditor,
|
||||
codeShowLineNumbers,
|
||||
fontSize,
|
||||
@ -307,11 +326,11 @@ export const CodeBlockView: React.FC<Props> = memo(({ children, language, onSave
|
||||
if (!SpecialView) return null
|
||||
|
||||
return (
|
||||
<SpecialView ref={specialViewRef} enableToolbar={codeImageTools}>
|
||||
{children}
|
||||
<SpecialView ref={specialViewRef} enableToolbar={codeImageTools} onError={setError}>
|
||||
{code}
|
||||
</SpecialView>
|
||||
)
|
||||
}, [children, codeImageTools, language])
|
||||
}, [code, codeImageTools, language])
|
||||
|
||||
const renderHeader = useMemo(() => {
|
||||
const langTag = '<' + language.toUpperCase() + '>'
|
||||
|
||||
@ -0,0 +1,178 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { ActionTool } from '@renderer/components/ActionTools'
|
||||
import { TOOL_SPECS, useToolManager } from '@renderer/components/ActionTools'
|
||||
import { usePendingMap } from '@renderer/hooks/usePendingMap'
|
||||
import { useQuickCompletion } from '@renderer/hooks/useQuickCompletion'
|
||||
import { useSettings } from '@renderer/hooks/useSettings'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { getErrorMessage, parseJSON } from '@renderer/utils'
|
||||
import { abortCompletion, readyToAbort } from '@renderer/utils/abortController'
|
||||
import { WrenchIcon } from 'lucide-react'
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import * as z from 'zod'
|
||||
|
||||
const logger = loggerService.withContext('useMermaidFixTool')
|
||||
|
||||
interface UseMermaidFixTool {
|
||||
enabled?: boolean
|
||||
context: {
|
||||
/** Code block id */
|
||||
blockId: string
|
||||
/** Error */
|
||||
error: unknown
|
||||
/** Mermaid code */
|
||||
content: string
|
||||
}
|
||||
onSave: (newContent: string) => void
|
||||
setError: (error: unknown) => void
|
||||
setTools: React.Dispatch<React.SetStateAction<ActionTool[]>>
|
||||
}
|
||||
|
||||
const ResultSchema = z.union([
|
||||
z.object({
|
||||
fixed: z.literal(true),
|
||||
result: z.string()
|
||||
}),
|
||||
z.object({
|
||||
fixed: z.literal(false),
|
||||
reason: z.string()
|
||||
})
|
||||
])
|
||||
|
||||
type Input = {
|
||||
// mermaid 代码
|
||||
mermaid: string
|
||||
// 错误信息
|
||||
error: string
|
||||
// 用户语言代码, 如 zh-cn, en-us
|
||||
lang: string
|
||||
}
|
||||
|
||||
const SYSTEM_PROMPT = `
|
||||
You are an AI assistant that fixes Mermaid code. The input is a JSON string with the following structure: {"mermaid": "the Mermaid code", "error": "the error message from rendering", "lang": "the user's language code"}.
|
||||
|
||||
Your task is to analyze the error and the Mermaid code. If the error is due to a mistake in the Mermaid code, fix it and output a JSON string with {"fixed": true, "result": "the fixed Mermaid code"}. If the error is not caused by the code (e.g., environment issues, unsupported features, or other non-code errors), output {"fixed": false, "reason": "a brief explanation in the language specified by the 'lang' field"}.
|
||||
|
||||
Your output must be a pure JSON string with no additional text, comments, or formatting.
|
||||
|
||||
Example input:
|
||||
{
|
||||
"mermaid": "graph TD\nA[Start] --> B{Error?}",
|
||||
"error": "Syntax error: unexpected token",
|
||||
"lang": "en-us"
|
||||
}
|
||||
|
||||
Example outputs:
|
||||
- If fixed: {"fixed": true, "result": "graph TD\nA[Start] --> B{Error?}\nB -->|Yes| C[End]"}
|
||||
- If not fixed: {"fixed": false, "reason": "The error is due to an unsupported feature in the current environment."}
|
||||
|
||||
`
|
||||
|
||||
export const useMermaidFixTool = ({ enabled, context, onSave, setError, setTools }: UseMermaidFixTool) => {
|
||||
const { t } = useTranslation()
|
||||
const { registerTool, removeTool } = useToolManager(setTools)
|
||||
const { language } = useSettings()
|
||||
const completion = useQuickCompletion(SYSTEM_PROMPT)
|
||||
|
||||
const { error, content, blockId } = context
|
||||
const abortKeyRef = useRef<string | null>(null)
|
||||
|
||||
const { setPending } = usePendingMap()
|
||||
logger.debug('input', {
|
||||
mermaid: content,
|
||||
error: getErrorMessage(error),
|
||||
lang: language
|
||||
})
|
||||
const prompt = JSON.stringify({
|
||||
mermaid: content,
|
||||
error: getErrorMessage(error),
|
||||
lang: language
|
||||
} satisfies Input)
|
||||
|
||||
const fixCode = useCallback(async () => {
|
||||
setPending(blockId, true)
|
||||
const abortKey = crypto.randomUUID()
|
||||
abortKeyRef.current = abortKey
|
||||
const signal = readyToAbort(abortKey)
|
||||
let result = ''
|
||||
|
||||
const onChunk = (chunk: Chunk) => {
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
result = chunk.text
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await completion({
|
||||
prompt,
|
||||
onChunk,
|
||||
params: {
|
||||
options: {
|
||||
signal
|
||||
}
|
||||
}
|
||||
})
|
||||
} catch (e) {
|
||||
window.toast.error({ title: t('code_block.mermaid_fix.failed'), description: getErrorMessage(e) })
|
||||
return
|
||||
}
|
||||
|
||||
result = result.trim()
|
||||
logger.debug('output', { result })
|
||||
|
||||
const parsedJson = parseJSON(result)
|
||||
if (parsedJson === null) {
|
||||
window.toast.error({
|
||||
title: t('code_block.mermaid_fix.failed'),
|
||||
description: t('code_block.mermaid_fix.invalid_result')
|
||||
})
|
||||
} else {
|
||||
logger.debug('parseJSON success', { parsedJson })
|
||||
const parsedResult = ResultSchema.safeParse(parsedJson)
|
||||
logger.debug('validation', { parsedResult })
|
||||
|
||||
if (parsedResult.success) {
|
||||
const validResult = parsedResult.data
|
||||
if (validResult.fixed) {
|
||||
onSave(validResult.result)
|
||||
setError(undefined)
|
||||
} else {
|
||||
window.toast.warning({ title: t('code_block.mermaid_fix.failed'), description: validResult.reason })
|
||||
}
|
||||
} else {
|
||||
window.toast.error({
|
||||
title: t('code_block.mermaid_fix.failed'),
|
||||
description: t('code_block.mermaid_fix.invalid_result')
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
setPending(blockId, false)
|
||||
}, [setPending, blockId, completion, prompt, onSave, t])
|
||||
|
||||
// when unmounted
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
const abortKey = abortKeyRef.current
|
||||
if (abortKey) {
|
||||
abortCompletion(abortKey)
|
||||
}
|
||||
}
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
if (enabled) {
|
||||
registerTool({
|
||||
...TOOL_SPECS.mermaid_fix,
|
||||
icon: <WrenchIcon size={'1rem'} className="tool-icon" />,
|
||||
tooltip: t('code_block.mermaid_fix.label'),
|
||||
visible: () => error !== undefined && error !== null,
|
||||
onClick: fixCode
|
||||
})
|
||||
}
|
||||
|
||||
return () => removeTool(TOOL_SPECS.expand.id)
|
||||
}, [enabled, error, fixCode, registerTool, removeTool, t])
|
||||
}
|
||||
@ -16,8 +16,9 @@ import { renderSvgInShadowHost } from './utils'
|
||||
const MermaidPreview = ({
|
||||
children,
|
||||
enableToolbar = false,
|
||||
ref
|
||||
}: BasicPreviewProps & { ref?: React.RefObject<BasicPreviewHandles | null> }) => {
|
||||
ref,
|
||||
onError
|
||||
}: BasicPreviewProps & { ref?: React.RefObject<BasicPreviewHandles | null>; onError?: (error: unknown) => void }) => {
|
||||
const { mermaid, isLoading: isLoadingMermaid, error: mermaidError, forceRenderKey } = useMermaid()
|
||||
const diagramId = useRef<string>(`mermaid-${nanoid(6)}`).current
|
||||
const [isVisible, setIsVisible] = useState(true)
|
||||
@ -122,6 +123,12 @@ const MermaidPreview = ({
|
||||
const isLoading = isLoadingMermaid || isRendering
|
||||
const error = mermaidError || renderError
|
||||
|
||||
useEffect(() => {
|
||||
if (error !== undefined && error !== null && onError !== undefined) {
|
||||
onError(error)
|
||||
}
|
||||
}, [error, onError])
|
||||
|
||||
return (
|
||||
<ImagePreviewLayout
|
||||
loading={isLoading}
|
||||
|
||||
@ -11,7 +11,8 @@ import { useDefaultModel } from './useAssistant'
|
||||
*/
|
||||
export type QuickCompletionParams = {
|
||||
/**
|
||||
* The prompt text to send to the model.
|
||||
* The user message text (not the system prompt) to send to the model.
|
||||
* The system prompt is set via the `systemPrompt` parameter passed to `useQuickCompletion`.
|
||||
*/
|
||||
prompt: string
|
||||
/**
|
||||
|
||||
@ -940,6 +940,11 @@
|
||||
}
|
||||
},
|
||||
"expand": "Expand",
|
||||
"mermaid_fix": {
|
||||
"failed": "Failed to fix",
|
||||
"invalid_result": "Model returned invalid data. Please try again or try changing the quick model.",
|
||||
"label": "Fix mermaid error"
|
||||
},
|
||||
"more": "More",
|
||||
"run": "Run",
|
||||
"split": {
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import { BlockingOverlay, cn, Spinner } from '@cherrystudio/ui'
|
||||
import { CodeBlockView, HtmlArtifactsCard } from '@renderer/components/CodeBlockView'
|
||||
import { usePendingMap } from '@renderer/hooks/usePendingMap'
|
||||
import { useSettings } from '@renderer/hooks/useSettings'
|
||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
||||
import store from '@renderer/store'
|
||||
@ -21,6 +23,8 @@ const CodeBlock: React.FC<Props> = ({ children, className, node, blockId }) => {
|
||||
const isMultiline = children?.includes('\n')
|
||||
const language = languageMatch?.[1] ?? (isMultiline ? 'text' : null)
|
||||
const { codeFancyBlock } = useSettings()
|
||||
const { isPending } = usePendingMap()
|
||||
const isBlockPending = isPending(blockId)
|
||||
|
||||
// 代码块 id
|
||||
const id = useMemo(() => getCodeBlockId(node?.position?.start), [node?.position?.start])
|
||||
@ -52,14 +56,24 @@ const CodeBlock: React.FC<Props> = ({ children, className, node, blockId }) => {
|
||||
}
|
||||
|
||||
return (
|
||||
<CodeBlockView language={language} onSave={handleSave}>
|
||||
{children}
|
||||
</CodeBlockView>
|
||||
<div className="relative">
|
||||
<CodeBlockView language={language} onSave={handleSave} blockId={blockId}>
|
||||
{children}
|
||||
</CodeBlockView>
|
||||
{isBlockPending && (
|
||||
<BlockingOverlay isVisible={isBlockPending}>
|
||||
<Spinner />
|
||||
</BlockingOverlay>
|
||||
)}
|
||||
{/* <BlockingOverlay isVisible={true}>
|
||||
<Spinner />
|
||||
</BlockingOverlay> */}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<code className={className} style={{ textWrap: 'wrap', fontSize: '95%', padding: '2px 4px' }}>
|
||||
<code className={cn('relative', className)} style={{ textWrap: 'wrap', fontSize: '95%', padding: '2px 4px' }}>
|
||||
{children}
|
||||
</code>
|
||||
)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
// TODO: We may refactor it to a service with pendingMap
|
||||
|
||||
const logger = loggerService.withContext('AbortController')
|
||||
|
||||
export const abortMap = new Map<string, (() => void)[]>()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user