feat: enhance citation handling and add metadata support in citation blocks

This commit is contained in:
suyao 2025-05-09 21:24:48 +08:00
parent 30696e1ef1
commit bf819a7142
No known key found for this signature in database
11 changed files with 165 additions and 75 deletions

View File

@ -12,14 +12,14 @@ import CitationsList from '../CitationsList'
function CitationBlock({ block }: { block: CitationMessageBlock }) { function CitationBlock({ block }: { block: CitationMessageBlock }) {
const formattedCitations = useSelector((state: RootState) => selectFormattedCitationsByBlockId(state, block.id)) const formattedCitations = useSelector((state: RootState) => selectFormattedCitationsByBlockId(state, block.id))
const hasCitations = useMemo(() => {
const hasGeminiBlock = block.response?.source === WebSearchSource.GEMINI const hasGeminiBlock = block.response?.source === WebSearchSource.GEMINI
const hasCitations = useMemo(() => {
return ( return (
(formattedCitations && formattedCitations.length > 0) || (formattedCitations && formattedCitations.length > 0) ||
hasGeminiBlock || hasGeminiBlock ||
(block.knowledge && block.knowledge.length > 0) (block.knowledge && block.knowledge.length > 0)
) )
}, [formattedCitations, block.response, block.knowledge]) }, [formattedCitations, block.knowledge, hasGeminiBlock])
if (block.status === MessageBlockStatus.PROCESSING) { if (block.status === MessageBlockStatus.PROCESSING) {
return <Spinner text="message.searching" /> return <Spinner text="message.searching" />
@ -29,12 +29,10 @@ function CitationBlock({ block }: { block: CitationMessageBlock }) {
return null return null
} }
const isGemini = block.response?.source === WebSearchSource.GEMINI
return ( return (
<> <>
{block.status === MessageBlockStatus.SUCCESS && {block.status === MessageBlockStatus.SUCCESS &&
(isGemini ? ( (hasGeminiBlock ? (
<> <>
<CitationsList citations={formattedCitations} /> <CitationsList citations={formattedCitations} />
<SearchEntryPoint <SearchEntryPoint

View File

@ -1,8 +1,9 @@
import { GroundingSupport } from '@google/genai'
import { useSettings } from '@renderer/hooks/useSettings' import { useSettings } from '@renderer/hooks/useSettings'
import { getModelUniqId } from '@renderer/services/ModelService' import { getModelUniqId } from '@renderer/services/ModelService'
import type { RootState } from '@renderer/store' import type { RootState } from '@renderer/store'
import { selectFormattedCitationsByBlockId } from '@renderer/store/messageBlock' import { selectFormattedCitationsByBlockId } from '@renderer/store/messageBlock'
import type { Model } from '@renderer/types' import { type Model, WebSearchSource } from '@renderer/types'
import type { MainTextMessageBlock, Message } from '@renderer/types/newMessage' import type { MainTextMessageBlock, Message } from '@renderer/types/newMessage'
import { Flex } from 'antd' import { Flex } from 'antd'
import React, { useMemo } from 'react' import React, { useMemo } from 'react'
@ -47,8 +48,9 @@ const MainTextBlock: React.FC<Props> = ({ block, citationBlockId, role, mentions
return content return content
} }
// FIXME性能问题需要优化 switch (block.citationReferences[0].citationBlockSource) {
// Replace all citation numbers in the content with formatted citations case WebSearchSource.OPENAI_COMPATIBLE:
case WebSearchSource.OPENAI: {
formattedCitations.forEach((citation) => { formattedCitations.forEach((citation) => {
const citationNum = citation.number const citationNum = citation.number
const supData = { const supData = {
@ -58,12 +60,83 @@ const MainTextBlock: React.FC<Props> = ({ block, citationBlockId, role, mentions
content: citation.content?.substring(0, 200) content: citation.content?.substring(0, 200)
} }
const citationJson = encodeHTML(JSON.stringify(supData)) const citationJson = encodeHTML(JSON.stringify(supData))
// Handle[<sup>N</sup>](url)
const preFormattedRegex = new RegExp(`\\[<sup>${citationNum}</sup>\\]\\(.*?\\)`, 'g')
const citationTag = `[<sup data-citation='${citationJson}'>${citationNum}</sup>](${citation.url})` const citationTag = `[<sup data-citation='${citationJson}'>${citationNum}</sup>](${citation.url})`
// Replace all occurrences of [citationNum] with the formatted citation content = content.replace(preFormattedRegex, citationTag)
const regex = new RegExp(`\\[${citationNum}\\]`, 'g')
content = content.replace(regex, citationTag)
}) })
break
}
case WebSearchSource.GEMINI: {
// First pass: Add basic citation marks using metadata
let processedContent = content
const firstCitation = formattedCitations[0]
if (firstCitation?.metadata) {
console.log('groundingSupport, ', firstCitation.metadata)
firstCitation.metadata.forEach((support: GroundingSupport) => {
const citationNums = support.groundingChunkIndices!
if (support.segment) {
const text = support.segment.text!
// 生成引用标记
const basicTag = citationNums
.map((citationNum) => {
const citation = formattedCitations.find((c) => c.number === citationNum + 1)
return citation ? `[<sup>${citationNum + 1}</sup>](${citation.url})` : ''
})
.join('')
// 在文本后面添加引用标记,而不是替换
if (text && basicTag) {
processedContent = processedContent.replace(text, `${text}${basicTag}`)
}
}
})
content = processedContent
}
// Second pass: Replace basic citations with full citation data
formattedCitations.forEach((citation) => {
const citationNum = citation.number
const supData = {
id: citationNum,
url: citation.url,
title: citation.title || citation.hostname || '',
content: citation.content?.substring(0, 200)
}
const citationJson = encodeHTML(JSON.stringify(supData))
// Replace basic citation with full citation including data
const basicCitationRegex = new RegExp(`\\[<sup>${citationNum}</sup>\\]\\(${citation.url}\\)`, 'g')
const fullCitationTag = `[<sup data-citation='${citationJson}'>${citationNum}</sup>](${citation.url})`
content = content.replace(basicCitationRegex, fullCitationTag)
})
break
}
default: {
// FIXME性能问题需要优化
// Replace all citation numbers and pre-formatted links with formatted citations
formattedCitations.forEach((citation) => {
const citationNum = citation.number
const supData = {
id: citationNum,
url: citation.url,
title: citation.title || citation.hostname || '',
content: citation.content?.substring(0, 200)
}
const citationJson = encodeHTML(JSON.stringify(supData))
// Handle both plain references [N] and pre-formatted links [<sup>N</sup>](url)
const plainRefRegex = new RegExp(`\\[${citationNum}\\]`, 'g')
const citationTag = `[<sup data-citation='${citationJson}'>${citationNum}</sup>](${citation.url})`
content = content.replace(plainRefRegex, citationTag)
})
}
}
return content return content
}, [block.content, block.citationReferences, citationBlockId, formattedCitations]) }, [block.content, block.citationReferences, citationBlockId, formattedCitations])

View File

@ -16,6 +16,7 @@ export interface Citation {
content?: string content?: string
showFavicon?: boolean showFavicon?: boolean
type?: string type?: string
metadata?: Record<string, any>
} }
interface CitationsListProps { interface CitationsListProps {

View File

@ -331,7 +331,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
const model = assistant.model || defaultModel const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const isEnabledWebSearch = assistant.enableWebSearch || !!assistant.webSearchProviderId const isEnabledBultinWebSearch = assistant.enableWebSearch
messages = addImageFileToContents(messages) messages = addImageFileToContents(messages)
const enableReasoning = const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) && ((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
@ -597,7 +597,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
} }
} }
if ( if (
isEnabledWebSearch && isEnabledBultinWebSearch &&
isZhipuModel(model) && isZhipuModel(model) &&
finishReason === 'stop' && finishReason === 'stop' &&
originalFinishRawChunk?.web_search originalFinishRawChunk?.web_search
@ -611,7 +611,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
} as LLMWebSearchCompleteChunk) } as LLMWebSearchCompleteChunk)
} }
if ( if (
isEnabledWebSearch && isEnabledBultinWebSearch &&
isHunyuanSearchModel(model) && isHunyuanSearchModel(model) &&
originalFinishRawChunk?.search_info?.search_results originalFinishRawChunk?.search_info?.search_results
) { ) {

View File

@ -287,7 +287,7 @@ export default class OpenAIProvider extends BaseProvider {
const model = assistant.model || defaultModel const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const isEnabledWebSearch = assistant.enableWebSearch || !!assistant.webSearchProviderId const isEnabledBuiltinWebSearch = assistant.enableWebSearch
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
// 退回到 OpenAI 兼容模式 // 退回到 OpenAI 兼容模式
if (isOpenAIWebSearch(model)) { if (isOpenAIWebSearch(model)) {
@ -342,7 +342,7 @@ export default class OpenAIProvider extends BaseProvider {
const delta = chunk.choices[0]?.delta const delta = chunk.choices[0]?.delta
const finishReason = chunk.choices[0]?.finish_reason const finishReason = chunk.choices[0]?.finish_reason
if (delta?.content) { if (delta?.content) {
if (delta?.annotations) { if (isOpenAIWebSearch(model)) {
delta.content = convertLinks(delta.content || '', isFirstChunk) delta.content = convertLinks(delta.content || '', isFirstChunk)
} }
if (isFirstChunk) { if (isFirstChunk) {
@ -388,7 +388,10 @@ export default class OpenAIProvider extends BaseProvider {
return return
} }
const tools: OpenAI.Responses.Tool[] = [] const tools: OpenAI.Responses.Tool[] = []
if (isEnabledWebSearch) { const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
type: 'web_search_preview'
}
if (isEnabledBuiltinWebSearch) {
tools.push({ tools.push({
type: 'web_search_preview' type: 'web_search_preview'
}) })
@ -558,17 +561,22 @@ export default class OpenAIProvider extends BaseProvider {
thinking_millsec: new Date().getTime() - time_first_token_millsec thinking_millsec: new Date().getTime() - time_first_token_millsec
}) })
break break
case 'response.output_text.delta': case 'response.output_text.delta': {
let delta = chunk.delta
if (isEnabledBuiltinWebSearch) {
delta = convertLinks(delta)
}
onChunk({ onChunk({
type: ChunkType.TEXT_DELTA, type: ChunkType.TEXT_DELTA,
text: chunk.delta text: delta
}) })
content += chunk.delta content += delta
break break
}
case 'response.output_text.done': case 'response.output_text.done':
onChunk({ onChunk({
type: ChunkType.TEXT_COMPLETE, type: ChunkType.TEXT_COMPLETE,
text: chunk.text text: content
}) })
break break
case 'response.content_part.done': case 'response.content_part.done':
@ -633,6 +641,7 @@ export default class OpenAIProvider extends BaseProvider {
max_output_tokens: maxTokens, max_output_tokens: maxTokens,
stream: streamOutput, stream: streamOutput,
tools: tools.length > 0 ? tools : undefined, tools: tools.length > 0 ? tools : undefined,
tool_choice: isEnabledBuiltinWebSearch ? toolChoices : undefined,
service_tier: this.getServiceTier(model), service_tier: this.getServiceTier(model),
...this.getResponseReasoningEffort(assistant, model), ...this.getResponseReasoningEffort(assistant, model),
...this.getCustomParameters(assistant) ...this.getCustomParameters(assistant)

View File

@ -85,19 +85,22 @@ const formatCitationsFromBlock = (block: CitationMessageBlock | undefined): Cita
if (!block) return [] if (!block) return []
let formattedCitations: Citation[] = [] let formattedCitations: Citation[] = []
// 1. Handle Web Search Responses (Non-Gemini) // 1. Handle Web Search Responses
if (block.response) { if (block.response) {
switch (block.response.source) { switch (block.response.source) {
case WebSearchSource.GEMINI: case WebSearchSource.GEMINI: {
const groundingMetadata = block.response.results as GroundingMetadata
formattedCitations = formattedCitations =
(block.response?.results as GroundingMetadata)?.groundingChunks?.map((chunk, index) => ({ groundingMetadata?.groundingChunks?.map((chunk, index) => ({
number: index + 1, number: index + 1,
url: chunk?.web?.uri || '', url: chunk?.web?.uri || '',
title: chunk?.web?.title, title: chunk?.web?.title,
showFavicon: false, showFavicon: true,
metadata: groundingMetadata.groundingSupports,
type: 'websearch' type: 'websearch'
})) || [] })) || []
break break
}
case WebSearchSource.OPENAI: case WebSearchSource.OPENAI:
formattedCitations = formattedCitations =
(block.response.results as OpenAI.Responses.ResponseOutputText.URLCitation[])?.map((result, index) => { (block.response.results as OpenAI.Responses.ResponseOutputText.URLCitation[])?.map((result, index) => {

View File

@ -470,7 +470,6 @@ const fetchAndProcessAssistantResponseImpl = async (
saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState) saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
}, },
onExternalToolComplete: (externalToolResult: ExternalToolResult) => { onExternalToolComplete: (externalToolResult: ExternalToolResult) => {
console.warn('onExternalToolComplete received.', externalToolResult)
if (citationBlockId) { if (citationBlockId) {
const changes: Partial<CitationMessageBlock> = { const changes: Partial<CitationMessageBlock> = {
response: externalToolResult.webSearch, response: externalToolResult.webSearch,
@ -505,19 +504,24 @@ const fetchAndProcessAssistantResponseImpl = async (
) )
citationBlockId = citationBlock.id citationBlockId = citationBlock.id
handleBlockTransition(citationBlock, MessageBlockType.CITATION) handleBlockTransition(citationBlock, MessageBlockType.CITATION)
}
if (mainTextBlockId) { if (mainTextBlockId) {
const state = getState() const state = getState()
const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId] const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId]
if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) { if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) {
const currentRefs = existingMainTextBlock.citationReferences || [] const currentRefs = existingMainTextBlock.citationReferences || []
if (!currentRefs.some((ref) => ref.citationBlockId === citationBlockId)) { if (!currentRefs.some((ref) => ref.citationBlockId === citationBlockId)) {
const mainTextChanges = { citationReferences: [...currentRefs, { citationBlockId }] } const mainTextChanges = {
citationReferences: [
...currentRefs,
{ citationBlockId, citationBlockSource: llmWebSearchResult.source }
]
}
dispatch(updateOneBlock({ id: mainTextBlockId, changes: mainTextChanges })) dispatch(updateOneBlock({ id: mainTextBlockId, changes: mainTextChanges }))
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState) saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
} }
} }
} }
}
}, },
onImageCreated: () => { onImageCreated: () => {
const imageBlock = createImageBlock(assistantMsgId, { const imageBlock = createImageBlock(assistantMsgId, {

View File

@ -11,7 +11,8 @@ import type {
Model, Model,
Topic, Topic,
Usage, Usage,
WebSearchResponse WebSearchResponse,
WebSearchSource
} from '.' } from '.'
// MessageBlock 类型枚举 - 根据实际API返回特性优化 // MessageBlock 类型枚举 - 根据实际API返回特性优化
@ -63,6 +64,7 @@ export interface MainTextMessageBlock extends BaseMessageBlock {
// Citation references // Citation references
citationReferences?: { citationReferences?: {
citationBlockId?: string citationBlockId?: string
citationBlockSource?: WebSearchSource
}[] }[]
} }

View File

@ -99,12 +99,6 @@ describe('linkConverter', () => {
expect(result).toBe('这里有链接 [<sup>1</sup>](https://example.com)') expect(result).toBe('这里有链接 [<sup>1</sup>](https://example.com)')
}) })
it('should preserve non-domain link text', () => {
const input = '点击[这里](https://example.com)查看更多'
const result = convertLinks(input, true)
expect(result).toBe('点击这里[<sup>1</sup>](https://example.com)查看更多')
})
it('should use the same counter for duplicate URLs', () => { it('should use the same counter for duplicate URLs', () => {
const input = const input =
'第一个链接 [example.com](https://example.com) 和第二个相同链接 [subdomain.example.com](https://example.com)' '第一个链接 [example.com](https://example.com) 和第二个相同链接 [subdomain.example.com](https://example.com)'
@ -113,24 +107,6 @@ describe('linkConverter', () => {
'第一个链接 [<sup>1</sup>](https://example.com) 和第二个相同链接 [<sup>1</sup>](https://example.com)' '第一个链接 [<sup>1</sup>](https://example.com) 和第二个相同链接 [<sup>1</sup>](https://example.com)'
) )
}) })
it('should correctly convert links in Zhipu mode', () => {
const input = '这里是引用 [ref_1]'
const result = convertLinks(input, true, true)
expect(result).toBe('这里是引用 [<sup>1</sup>]()')
})
it('should handle incomplete links in chunked input', () => {
// 第一个块包含未完成的链接
const chunk1 = '这是链接 ['
const result1 = convertLinks(chunk1, true)
expect(result1).toBe('这是链接 ')
// 第二个块完成链接
const chunk2 = 'example.com](https://example.com)'
const result2 = convertLinks(chunk2, false)
expect(result2).toBe('[<sup>1</sup>](https://example.com)')
})
}) })
describe('convertLinksToOpenRouter', () => { describe('convertLinksToOpenRouter', () => {

View File

@ -126,3 +126,20 @@ export async function fetchWebContent(
} }
} }
} }
export async function fetchRedirectUrl(url: string) {
try {
const response = await fetch(url, {
method: 'HEAD',
redirect: 'follow',
headers: {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
}
})
return response.url
} catch (e) {
console.error(`Failed to fetch redirect url: ${e}`)
return url
}
}

View File

@ -113,6 +113,7 @@ export function convertLinksToHunyuan(text: string, webSearch: any[], resetCount
* Converts Markdown links in the text to numbered links based on the rules: * Converts Markdown links in the text to numbered links based on the rules:
* 1. ([host](url)) -> [cnt](url) * 1. ([host](url)) -> [cnt](url)
* 2. [host](url) -> [cnt](url) * 2. [host](url) -> [cnt](url)
* 3. [any text except url](url)-> any text [cnt](url)
* *
* @param text The current chunk of text to process * @param text The current chunk of text to process
* @param resetCounter Whether to reset the counter and buffer * @param resetCounter Whether to reset the counter and buffer
@ -131,7 +132,6 @@ export function convertLinks(text: string, resetCounter = false): string {
// Find the safe point - the position after which we might have incomplete patterns // Find the safe point - the position after which we might have incomplete patterns
let safePoint = buffer.length let safePoint = buffer.length
// Check for potentially incomplete patterns from the end // Check for potentially incomplete patterns from the end
for (let i = buffer.length - 1; i >= 0; i--) { for (let i = buffer.length - 1; i >= 0; i--) {
if (buffer[i] === '(') { if (buffer[i] === '(') {
@ -198,6 +198,7 @@ export function convertLinks(text: string, resetCounter = false): string {
if (match) { if (match) {
// Found complete regular link // Found complete regular link
const linkText = match[1]
const url = match[2] const url = match[2]
// Check if this URL has been seen before // Check if this URL has been seen before
@ -209,7 +210,13 @@ export function convertLinks(text: string, resetCounter = false): string {
urlToCounterMap.set(url, counter) urlToCounterMap.set(url, counter)
} }
// Rule 3: If the link text is not a URL/host, keep the text and add the numbered link
if (!isHost(linkText)) {
result += `${linkText} [<sup>${counter}</sup>](${url})`
} else {
// Rule 2: If the link text is a URL/host, replace with numbered link
result += `[<sup>${counter}</sup>](${url})` result += `[<sup>${counter}</sup>](${url})`
}
position += match[0].length position += match[0].length
continue continue
@ -317,7 +324,7 @@ export function extractUrlsFromMarkdown(text: string): string[] {
// 匹配所有Markdown链接格式 // 匹配所有Markdown链接格式
const linkPattern = /\[(?:[^[\]]*)\]\(([^()]+)\)/g const linkPattern = /\[(?:[^[\]]*)\]\(([^()]+)\)/g
let match let match: RegExpExecArray | null
while ((match = linkPattern.exec(text)) !== null) { while ((match = linkPattern.exec(text)) !== null) {
const url = match[1].trim() const url = match[1].trim()