mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-06 21:35:52 +08:00
refactor(CodeBlock): closed fence detection for html (#9424)
* refactor(CodeBlock): closed fence detection for html * refactor: improve type, fix test * doc: add comments
This commit is contained in:
parent
ae203b5c7c
commit
c2aff60127
@ -3,7 +3,7 @@ import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
|||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { messageBlocksSelectors } from '@renderer/store/messageBlock'
|
import { messageBlocksSelectors } from '@renderer/store/messageBlock'
|
||||||
import { MessageBlockStatus } from '@renderer/types/newMessage'
|
import { MessageBlockStatus } from '@renderer/types/newMessage'
|
||||||
import { getCodeBlockId } from '@renderer/utils/markdown'
|
import { getCodeBlockId, isOpenFenceBlock } from '@renderer/utils/markdown'
|
||||||
import type { Node } from 'mdast'
|
import type { Node } from 'mdast'
|
||||||
import React, { memo, useCallback, useMemo } from 'react'
|
import React, { memo, useCallback, useMemo } from 'react'
|
||||||
|
|
||||||
@ -16,8 +16,9 @@ interface Props {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const CodeBlock: React.FC<Props> = ({ children, className, node, blockId }) => {
|
const CodeBlock: React.FC<Props> = ({ children, className, node, blockId }) => {
|
||||||
const match = /language-([\w-+]+)/.exec(className || '') || children?.includes('\n')
|
const languageMatch = /language-([\w-+]+)/.exec(className || '')
|
||||||
const language = match?.[1] ?? 'text'
|
const isMultiline = children?.includes('\n')
|
||||||
|
const language = languageMatch?.[1] ?? (isMultiline ? 'text' : null)
|
||||||
|
|
||||||
// 代码块 id
|
// 代码块 id
|
||||||
const id = useMemo(() => getCodeBlockId(node?.position?.start), [node?.position?.start])
|
const id = useMemo(() => getCodeBlockId(node?.position?.start), [node?.position?.start])
|
||||||
@ -39,11 +40,11 @@ const CodeBlock: React.FC<Props> = ({ children, className, node, blockId }) => {
|
|||||||
[blockId, id]
|
[blockId, id]
|
||||||
)
|
)
|
||||||
|
|
||||||
if (match) {
|
if (language !== null) {
|
||||||
// HTML 代码块特殊处理
|
// HTML 代码块特殊处理
|
||||||
// FIXME: 感觉没有必要用 isHtmlCode 判断
|
|
||||||
if (language === 'html') {
|
if (language === 'html') {
|
||||||
return <HtmlArtifactsCard html={children} onSave={handleSave} isStreaming={isStreaming} />
|
const isOpenFence = isOpenFenceBlock(children?.length, languageMatch?.[1]?.length, node?.position)
|
||||||
|
return <HtmlArtifactsCard html={children} onSave={handleSave} isStreaming={isStreaming && isOpenFence} />
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@ -10,6 +10,7 @@ const mocks = vi.hoisted(() => ({
|
|||||||
emit: vi.fn()
|
emit: vi.fn()
|
||||||
},
|
},
|
||||||
getCodeBlockId: vi.fn(),
|
getCodeBlockId: vi.fn(),
|
||||||
|
isOpenFenceBlock: vi.fn(),
|
||||||
selectById: vi.fn(),
|
selectById: vi.fn(),
|
||||||
CodeBlockView: vi.fn(({ onSave, children }) => (
|
CodeBlockView: vi.fn(({ onSave, children }) => (
|
||||||
<div>
|
<div>
|
||||||
@ -36,7 +37,8 @@ vi.mock('@renderer/services/EventService', () => ({
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@renderer/utils/markdown', () => ({
|
vi.mock('@renderer/utils/markdown', () => ({
|
||||||
getCodeBlockId: mocks.getCodeBlockId
|
getCodeBlockId: mocks.getCodeBlockId,
|
||||||
|
isOpenFenceBlock: mocks.isOpenFenceBlock
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@renderer/store', () => ({
|
vi.mock('@renderer/store', () => ({
|
||||||
@ -74,6 +76,7 @@ describe('CodeBlock', () => {
|
|||||||
vi.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
// Default mock return values
|
// Default mock return values
|
||||||
mocks.getCodeBlockId.mockReturnValue('test-code-block-id')
|
mocks.getCodeBlockId.mockReturnValue('test-code-block-id')
|
||||||
|
mocks.isOpenFenceBlock.mockReturnValue(false)
|
||||||
mocks.selectById.mockReturnValue({
|
mocks.selectById.mockReturnValue({
|
||||||
id: 'test-msg-block-id',
|
id: 'test-msg-block-id',
|
||||||
status: MessageBlockStatus.SUCCESS
|
status: MessageBlockStatus.SUCCESS
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import remarkParse from 'remark-parse'
|
|||||||
import remarkStringify from 'remark-stringify'
|
import remarkStringify from 'remark-stringify'
|
||||||
import removeMarkdown from 'remove-markdown'
|
import removeMarkdown from 'remove-markdown'
|
||||||
import { unified } from 'unified'
|
import { unified } from 'unified'
|
||||||
|
import type { Point, Position } from 'unist'
|
||||||
import { visit } from 'unist-util-visit'
|
import { visit } from 'unist-util-visit'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -189,7 +190,7 @@ export function removeTrailingDoubleSpaces(markdown: string): string {
|
|||||||
* @param start 代码块节点的起始位置
|
* @param start 代码块节点的起始位置
|
||||||
* @returns 代码块在 Markdown 字符串中的 ID
|
* @returns 代码块在 Markdown 字符串中的 ID
|
||||||
*/
|
*/
|
||||||
export function getCodeBlockId(start: any): string | null {
|
export function getCodeBlockId(start?: Point): string | null {
|
||||||
return start ? `${start.line}:${start.column}:${start.offset}` : null
|
return start ? `${start.line}:${start.column}:${start.offset}` : null
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -218,6 +219,28 @@ export function updateCodeBlock(raw: string, id: string, newContent: string): st
|
|||||||
return unified().use(remarkStringify).stringify(tree)
|
return unified().use(remarkStringify).stringify(tree)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查代码块是否包含 open fence。
|
||||||
|
* 限制:
|
||||||
|
* - 语言名不能包含空格,因为 remark-math 无法处理,会导致 end.offset 过长。
|
||||||
|
*
|
||||||
|
* 这个算法基于 remark/micromark 解析代码块的原理,所有参数实际上都可以从 node 中获取。
|
||||||
|
* 一个代码块的 node.position 包含 fences,而 children 不包含 fences,通过它们之间的
|
||||||
|
* 差值就可以判断有没有 closed fence。
|
||||||
|
*
|
||||||
|
* @param codeLength 代码长度(不包含语言信息)
|
||||||
|
* @param metaLength 元数据长度(```之后的语言信息)
|
||||||
|
* @param position 位置(unist 节点位置)
|
||||||
|
* @returns 是否为 open fence 代码块
|
||||||
|
*/
|
||||||
|
export function isOpenFenceBlock(codeLength?: number, metaLength?: number, position?: Position): boolean {
|
||||||
|
const contentLength = (codeLength ?? 0) + (metaLength ?? 0)
|
||||||
|
const start = position?.start?.offset ?? 0
|
||||||
|
const end = position?.end?.offset ?? 0
|
||||||
|
// 余量至少是 fence (3) + newlines (2)
|
||||||
|
return end - start <= contentLength + 5
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 检查代码是否具有HTML特征
|
* 检查代码是否具有HTML特征
|
||||||
* @param code 输入的代码字符串
|
* @param code 输入的代码字符串
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user