From e51a37cc74d37fbcaa04705311c4fc14f1e37b6e Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat <43230886+MyPrototypeWhat@users.noreply.github.com> Date: Sat, 17 May 2025 23:53:57 +0800 Subject: [PATCH 01/17] fix: group message resend (#6106) --- src/renderer/src/store/thunk/messageThunk.ts | 40 +++++++++++++------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts index c2c51b25e4..78f2876212 100644 --- a/src/renderer/src/store/thunk/messageThunk.ts +++ b/src/renderer/src/store/thunk/messageThunk.ts @@ -805,21 +805,19 @@ export const deleteMessageGroupThunk = const currentState = getState() const topicMessageIds = currentState.messages.messageIdsByTopic[topicId] || [] const messagesToDelete: Message[] = [] - const idsToDelete: string[] = [] topicMessageIds.forEach((id) => { const msg = currentState.messages.entities[id] if (msg && msg.askId === askId) { messagesToDelete.push(msg) - idsToDelete.push(id) } }) - const userQuery = currentState.messages.entities[askId] - if (userQuery && userQuery.topicId === topicId && !idsToDelete.includes(askId)) { - messagesToDelete.push(userQuery) - idsToDelete.push(askId) - } + // const userQuery = currentState.messages.entities[askId] + // if (userQuery && userQuery.topicId === topicId && !idsToDelete.includes(askId)) { + // messagesToDelete.push(userQuery) + // idsToDelete.push(askId) + // } if (messagesToDelete.length === 0) { console.warn(`[deleteMessageGroup] No messages found with askId ${askId} in topic ${topicId}.`) @@ -894,13 +892,29 @@ export const resendMessageThunk = const resetDataList: Message[] = [] if (assistantMessagesToReset.length === 0) { - // 没有用户消息,就创建一个 - const assistantMessage = createAssistantMessage(assistant.id, topicId, { - askId: userMessageToResend.id, - model: assistant.model + // 没有用户消息,就创建一个或多个 + + if (userMessageToResend?.mentions?.length) { + console.log('userMessageToResend.mentions', userMessageToResend.mentions) + for (const mention of userMessageToResend.mentions) { + const assistantMessage = createAssistantMessage(assistant.id, topicId, { + askId: userMessageToResend.id, + model: mention, + modelId: mention.id + }) + resetDataList.push(assistantMessage) + } + } else { + const assistantMessage = createAssistantMessage(assistant.id, topicId, { + askId: userMessageToResend.id, + model: assistant.model + }) + resetDataList.push(assistantMessage) + } + + resetDataList.forEach((message) => { + dispatch(newMessagesActions.addMessage({ topicId, message })) }) - resetDataList.push(assistantMessage) - dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage })) } const allBlockIdsToDelete: string[] = [] From ba88a24455b4dbd3c6a4fdd420d104de6ce0f762 Mon Sep 17 00:00:00 2001 From: eeee0717 Date: Sun, 18 May 2025 09:14:57 +0800 Subject: [PATCH 02/17] fix(knowledge): remove topN --- src/main/reranker/BaseReranker.ts | 2 +- .../components/KnowledgeSearchPopup.tsx | 28 +--- .../components/KnowledgeSettingsPopup.tsx | 21 +-- src/renderer/src/services/KnowledgeService.ts | 130 ++++++++++-------- src/renderer/src/types/index.ts | 4 +- 5 files changed, 76 insertions(+), 109 deletions(-) diff --git a/src/main/reranker/BaseReranker.ts b/src/main/reranker/BaseReranker.ts index a88d0883ae..f956a0573f 100644 --- a/src/main/reranker/BaseReranker.ts +++ b/src/main/reranker/BaseReranker.ts @@ -38,7 +38,7 @@ export default abstract class BaseReranker { protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) { const provider = this.base.rerankModelProvider const documents = searchResults.map((doc) => doc.pageContent) - const topN = this.base.topN || 10 + const topN = this.base.documentCount if (provider === 'voyageai') { return { diff --git a/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx b/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx index fd76d870c5..34f750c908 100644 --- a/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx +++ b/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx @@ -1,8 +1,7 @@ import { CopyOutlined } from '@ant-design/icons' import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import { TopView } from '@renderer/components/TopView' -import { DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant' -import { getFileFromUrl, getKnowledgeBaseParams } from '@renderer/services/KnowledgeService' +import { searchKnowledgeBase } from '@renderer/services/KnowledgeService' import { FileType, KnowledgeBase } from '@renderer/types' import { Input, List, message, Modal, Spin, Tooltip, Typography } from 'antd' import { useRef, useState } from 'react' @@ -38,29 +37,8 @@ const PopupContainer: React.FC = ({ base, resolve }) => { setSearchKeyword(value.trim()) setLoading(true) try { - const searchResults = await window.api.knowledgeBase.search({ - search: value, - base: getKnowledgeBaseParams(base) - }) - let rerankResult = searchResults - if (base.rerankModel) { - rerankResult = await window.api.knowledgeBase.rerank({ - search: value, - base: getKnowledgeBaseParams(base), - results: searchResults - }) - } - const results = await Promise.all( - rerankResult.map(async (item) => { - const file = await getFileFromUrl(item.metadata.source) - return { ...item, file } - }) - ) - const filteredResults = results.filter((item) => { - const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD - return item.score >= threshold - }) - setResults(filteredResults) + const searchResults = await searchKnowledgeBase(value, base) + setResults(searchResults) } catch (error) { console.error('Search failed:', error) } finally { diff --git a/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx b/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx index f409990094..237c4a8bcb 100644 --- a/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx +++ b/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx @@ -29,7 +29,6 @@ interface FormData { chunkOverlap?: number threshold?: number rerankModel?: string - topN?: number } interface Props extends ShowParams { @@ -95,8 +94,7 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { threshold: values.threshold ?? undefined, rerankModel: values.rerankModel ? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel) - : undefined, - topN: values.topN + : undefined } updateKnowledgeBase(newBase) setOpen(false) @@ -283,23 +281,6 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { - 30)) { - return Promise.reject(new Error(t('knowledge.topN_too_large_or_small'))) - } - return Promise.resolve() - } - } - ]}> - - > => { + try { + const baseParams = getKnowledgeBaseParams(base) + const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT + const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD + + // 执行搜索 + const searchResults = await window.api.knowledgeBase.search({ + search: query, + base: baseParams + }) + + // 过滤阈值不达标的结果 + const filteredResults = searchResults.filter((item) => item.score >= threshold) + + // 如果有rerank模型,执行重排 + let rerankResults = filteredResults + if (base.rerankModel && filteredResults.length > 0) { + rerankResults = await window.api.knowledgeBase.rerank({ + search: rewrite || query, + base: baseParams, + results: filteredResults + }) + } + + // 限制文档数量 + const limitedResults = rerankResults.slice(0, documentCount) + + // 处理文件信息 + return await Promise.all( + limitedResults.map(async (item) => { + const file = await getFileFromUrl(item.metadata.source) + return { ...item, file } + }) + ) + } catch (error) { + Logger.error(`Error searching knowledge base ${base.name}:`, error) + return [] + } +} + export const processKnowledgeSearch = async ( extractResults: ExtractResults, knowledgeBaseIds: string[] | undefined @@ -100,6 +145,7 @@ export const processKnowledgeSearch = async ( Logger.log('No valid question found in extractResults.knowledge') return [] } + const questions = extractResults.knowledge.question const rewrite = extractResults.knowledge.rewrite @@ -109,73 +155,35 @@ export const processKnowledgeSearch = async ( return [] } - const referencesPromises = bases.map(async (base) => { - try { - const baseParams = getKnowledgeBaseParams(base) - const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT + // 为每个知识库执行多问题搜索 + const baseSearchPromises = bases.map(async (base) => { + // 为每个问题搜索并合并结果 + const allResults = await Promise.all(questions.map((question) => searchKnowledgeBase(question, base, rewrite))) - const allSearchResultsPromises = questions.map((question) => - window.api.knowledgeBase - .search({ - search: question, - base: baseParams - }) - .then((results) => - results.filter((item) => { - const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD - return item.score >= threshold - }) - ) - ) + // 合并结果并去重 + const flatResults = allResults.flat() + const uniqueResults = Array.from( + new Map(flatResults.map((item) => [item.metadata.uniqueId || item.pageContent, item])).values() + ).sort((a, b) => b.score - a.score) - const allSearchResults = await Promise.all(allSearchResultsPromises) - - const searchResults = Array.from( - new Map(allSearchResults.flat().map((item) => [item.metadata.uniqueId || item.pageContent, item])).values() - ).sort((a, b) => b.score - a.score) - - Logger.log(`Knowledge base ${base.name} search results:`, searchResults) - - let rerankResults = searchResults - if (base.rerankModel && searchResults.length > 0) { - rerankResults = await window.api.knowledgeBase.rerank({ - search: rewrite, - base: baseParams, - results: searchResults - }) - } - - if (rerankResults.length > 0) { - rerankResults = rerankResults.slice(0, documentCount) - } - - const processdResults = await Promise.all( - rerankResults.map(async (item) => { - const file = await getFileFromUrl(item.metadata.source) - return { ...item, file } - }) - ) - - return await Promise.all( - processdResults.map(async (item, index) => { - // const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId) - return { - id: index + 1, // 搜索多个库会导致ID重复 + // 转换为引用格式 + return await Promise.all( + uniqueResults.map( + async (item, index) => + ({ + id: index + 1, content: item.pageContent, sourceUrl: await getKnowledgeSourceUrl(item), - type: 'file' // 需要映射 baseItem.type是'localPathLoader' -> 'file' - } as KnowledgeReference - }) + type: 'file' + }) as KnowledgeReference ) - } catch (error) { - Logger.error(`Error searching knowledge base ${base.name}:`, error) - return [] - } + ) }) - const resultsPerBase = await Promise.all(referencesPromises) - + // 汇总所有知识库的结果 + const resultsPerBase = await Promise.all(baseSearchPromises) const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref) + // 重新为引用分配ID return allReferencesRaw.map((ref, index) => ({ ...ref, diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index ac10c11b3d..0169bfd83f 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -372,7 +372,7 @@ export interface KnowledgeBase { chunkOverlap?: number threshold?: number rerankModel?: Model - topN?: number + // topN?: number } export type KnowledgeBaseParams = { @@ -388,7 +388,7 @@ export type KnowledgeBaseParams = { rerankBaseURL?: string rerankModel?: string rerankModelProvider?: string - topN?: number + documentCount?: number } export type GenerateImageParams = { From f43c16b85f0f256cbcf86cdf7d6bc6fbc46e5223 Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Sun, 18 May 2025 14:41:10 +0800 Subject: [PATCH 03/17] fix(WindowService): add backgroundThrottling option to Electron configuration --- src/main/services/WindowService.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts index aff511d748..9bfefab68a 100644 --- a/src/main/services/WindowService.ts +++ b/src/main/services/WindowService.ts @@ -75,7 +75,8 @@ export class WindowService { sandbox: false, webSecurity: false, webviewTag: true, - allowRunningInsecureContent: true + allowRunningInsecureContent: true, + backgroundThrottling: false } }) From 711888a8972cf0aee1eda7e367a131a6b6e8c51e Mon Sep 17 00:00:00 2001 From: purefkh Date: Sat, 17 May 2025 21:34:17 +0800 Subject: [PATCH 04/17] fix: Prevent sending message during input method composition in mini window (#6104) --- src/renderer/src/windows/mini/home/HomeWindow.tsx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/renderer/src/windows/mini/home/HomeWindow.tsx b/src/renderer/src/windows/mini/home/HomeWindow.tsx index b1fd9b18e5..f5f0171935 100644 --- a/src/renderer/src/windows/mini/home/HomeWindow.tsx +++ b/src/renderer/src/windows/mini/home/HomeWindow.tsx @@ -90,6 +90,9 @@ const HomeWindow: FC = () => { // 例子,中文输入法候选词过程使用`Enter`直接上屏字母,日文输入法候选词过程使用`Enter`输入假名 // 输入法可以`Esc`终止候选词过程 // 这两个例子的`Enter`和`Esc`快捷助手都不应该响应 + if (e.nativeEvent.isComposing) { + return + } if (e.key === 'Process') { return } From 114d4850b18dae0f0fad5ddec62d541ecb2183b6 Mon Sep 17 00:00:00 2001 From: Konjac-XZ <1951801592@qq.com> Date: Sat, 17 May 2025 16:08:29 +0800 Subject: [PATCH 05/17] fix: Summary for single message export doesn't work. --- src/renderer/src/services/MessagesService.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/renderer/src/services/MessagesService.ts b/src/renderer/src/services/MessagesService.ts index 91a676e7a1..a20f481f6c 100644 --- a/src/renderer/src/services/MessagesService.ts +++ b/src/renderer/src/services/MessagesService.ts @@ -215,10 +215,9 @@ export async function getMessageTitle(message: Message, length = 30): Promise Date: Sun, 18 May 2025 17:16:05 +0800 Subject: [PATCH 06/17] Hotfix/rich text paste (#6122) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix(Inputbar): 优化粘贴处理逻辑,优先处理文本粘贴 --- .../src/pages/home/Inputbar/Inputbar.tsx | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index 39621cc85d..2458e35a49 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -580,7 +580,27 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = const onPaste = useCallback( async (event: ClipboardEvent) => { - // 1. 文件/图片粘贴 + // 优先处理文本粘贴 + const clipboardText = event.clipboardData?.getData('text') + if (clipboardText) { + // 1. 文本粘贴 + if (pasteLongTextAsFile && clipboardText.length > pasteLongTextThreshold) { + // 长文本直接转文件,阻止默认粘贴 + event.preventDefault() + + const tempFilePath = await window.api.file.create('pasted_text.txt') + await window.api.file.write(tempFilePath, clipboardText) + const selectedFile = await window.api.file.get(tempFilePath) + selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile]) + setText(text) // 保持输入框内容不变 + setTimeout(() => resizeTextArea(), 50) + return + } + // 短文本走默认粘贴行为,直接返回 + return + } + + // 2. 文件/图片粘贴(仅在无文本时处理) if (event.clipboardData?.files && event.clipboardData.files.length > 0) { event.preventDefault() for (const file of event.clipboardData.files) { @@ -616,23 +636,7 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = } return } - - // 2. 文本粘贴 - const clipboardText = event.clipboardData?.getData('text') - if (pasteLongTextAsFile && clipboardText && clipboardText.length > pasteLongTextThreshold) { - // 长文本直接转文件,阻止默认粘贴 - event.preventDefault() - - const tempFilePath = await window.api.file.create('pasted_text.txt') - await window.api.file.write(tempFilePath, clipboardText) - const selectedFile = await window.api.file.get(tempFilePath) - selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile]) - setText(text) // 保持输入框内容不变 - setTimeout(() => resizeTextArea(), 50) - return - } - - // 短文本走默认粘贴行为 + // 其他情况默认粘贴 }, [model, pasteLongTextAsFile, pasteLongTextThreshold, resizeTextArea, supportExts, t, text] ) From fed855a4aed0d8ec8a5911ee34b7798cdb22b32e Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat <43230886+MyPrototypeWhat@users.noreply.github.com> Date: Sun, 18 May 2025 18:08:13 +0800 Subject: [PATCH 07/17] =?UTF-8?q?refactor(MessageGroup):=20optimize=20sele?= =?UTF-8?q?cted=20message=20handling=20with=20useMe=E2=80=A6=20(#6124)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(MessageGroup): optimize selected message handling with useMemo and clean up unused code * refactor(MainTextBlock): optimize citation handling with useMemo and improve code clarity * fix:del console --- .../home/Messages/Blocks/MainTextBlock.tsx | 9 ++--- .../src/pages/home/Messages/MessageGroup.tsx | 34 +++++++++---------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/renderer/src/pages/home/Messages/Blocks/MainTextBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/MainTextBlock.tsx index b004e784a9..25c6223c9d 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/MainTextBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/MainTextBlock.tsx @@ -38,13 +38,14 @@ const MainTextBlock: React.FC = ({ block, citationBlockId, role, mentions // Use the passed citationBlockId directly in the selector const { renderInputMessageAsMarkdown } = useSettings() - const formattedCitations = useSelector((state: RootState) => { - const citations = selectFormattedCitationsByBlockId(state, citationBlockId) - return citations.map((citation) => ({ + const rawCitations = useSelector((state: RootState) => selectFormattedCitationsByBlockId(state, citationBlockId)) + + const formattedCitations = useMemo(() => { + return rawCitations.map((citation) => ({ ...citation, content: citation.content ? cleanMarkdownContent(citation.content) : citation.content })) - }) + }, [rawCitations]) const processedContent = useMemo(() => { let content = block.content diff --git a/src/renderer/src/pages/home/Messages/MessageGroup.tsx b/src/renderer/src/pages/home/Messages/MessageGroup.tsx index 405ced9b6c..de183547e1 100644 --- a/src/renderer/src/pages/home/Messages/MessageGroup.tsx +++ b/src/renderer/src/pages/home/Messages/MessageGroup.tsx @@ -7,7 +7,7 @@ import type { Topic } from '@renderer/types' import type { Message } from '@renderer/types/newMessage' import { classNames } from '@renderer/utils' import { Popover } from 'antd' -import { memo, useCallback, useEffect, useRef, useState } from 'react' +import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react' import styled, { css } from 'styled-components' import MessageItem from './Message' @@ -31,7 +31,8 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => { const prevMessageLengthRef = useRef(messageLength) const [selectedIndex, setSelectedIndex] = useState(messageLength - 1) - const getSelectedMessageId = useCallback(() => { + const selectedMessageId = useMemo(() => { + if (messages.length === 1) return messages[0]?.id const selectedMessage = messages.find((message) => message.foldSelected) if (selectedMessage) { return selectedMessage.id @@ -41,9 +42,10 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => { const setSelectedMessage = useCallback( (message: Message) => { - messages.forEach(async (m) => { - await editMessage(m.id, { foldSelected: m.id === message.id }) - }) + // 前一个 + editMessage(selectedMessageId, { foldSelected: false }) + // 当前选中的消息 + editMessage(message.id, { foldSelected: true }) setTimeout(() => { const messageElement = document.getElementById(`message-${message.id}`) @@ -52,7 +54,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => { } }, 200) }, - [editMessage, messages] + [editMessage, selectedMessageId] ) const isGrouped = messageLength > 1 && messages.every((m) => m.role === 'assistant') @@ -67,8 +69,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => { setSelectedMessage(lastMessage) } } else { - const selectedId = getSelectedMessageId() - const newIndex = messages.findIndex((msg) => msg.id === selectedId) + const newIndex = messages.findIndex((msg) => msg.id === selectedMessageId) if (newIndex !== -1) { setSelectedIndex(newIndex) } @@ -147,7 +148,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => { }, [messages, setSelectedMessage]) const renderMessage = useCallback( - (message: Message & { index: number }, index: number) => { + (message: Message & { index: number }) => { const isGridGroupMessage = isGrid && message.role === 'assistant' && isGrouped const messageProps = { isGrouped, @@ -164,15 +165,15 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => { - + {message.id === selectedMessageId && } ) @@ -183,7 +184,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => { content={ @@ -204,11 +205,10 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => { isGrouped, isHorizontal, multiModelMessageStyle, - selectedIndex, topic, hidePresetMessages, gridPopoverTrigger, - getSelectedMessageId + selectedMessageId ] ) @@ -235,7 +235,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => { }) }} messages={messages} - selectMessageId={getSelectedMessageId()} + selectMessageId={selectedMessageId} setSelectedMessage={setSelectedMessage} topic={topic} /> @@ -297,7 +297,7 @@ const GridContainer = styled.div<{ $count: number; $layout: MultiModelMessageSty interface MessageWrapperProps { $layout: 'fold' | 'horizontal' | 'vertical' | 'grid' - $selected: boolean + // $selected: boolean $isGrouped: boolean $isInPopover?: boolean } From 2202b82f33af8fbbbfd8a4e738b7b60c6e210cb4 Mon Sep 17 00:00:00 2001 From: George Zhao <38124587+CreatorZZY@users.noreply.github.com> Date: Sun, 18 May 2025 18:37:08 +0800 Subject: [PATCH 08/17] fix: remove infinite token display for max count and simplify context count rendering. (#6103) Co-authored-by: George Zhao --- .../src/pages/home/Inputbar/TokenCount.tsx | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/renderer/src/pages/home/Inputbar/TokenCount.tsx b/src/renderer/src/pages/home/Inputbar/TokenCount.tsx index c2a0ee5068..b7ca3b351b 100644 --- a/src/renderer/src/pages/home/Inputbar/TokenCount.tsx +++ b/src/renderer/src/pages/home/Inputbar/TokenCount.tsx @@ -22,18 +22,6 @@ const TokenCount: FC = ({ estimateTokenCount, inputTokenCount, contextCou } const formatMaxCount = (max: number) => { - if (max == 100) { - return ( - - ∞ - - ) - } return max.toString() } @@ -43,7 +31,7 @@ const TokenCount: FC = ({ estimateTokenCount, inputTokenCount, contextCou {t('chat.input.context_count.tip')} - {contextCount.current} / {contextCount.max == 20 ? '∞' : contextCount.max} + {contextCount.current} / {contextCount.max} From 03fa6b5a74e86872b2636160e317425554bff9ac Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Sun, 18 May 2025 19:43:40 +0800 Subject: [PATCH 09/17] fix(WindowService): handle fullscreen toggle before hiding window --- src/main/services/WindowService.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts index 9bfefab68a..eb4e2f104a 100644 --- a/src/main/services/WindowService.ts +++ b/src/main/services/WindowService.ts @@ -324,6 +324,11 @@ export class WindowService { event.preventDefault() + if (mainWindow.isFullScreen()) { + mainWindow.setFullScreen(false) + return + } + mainWindow.hide() //for mac users, should hide dock icon if close to tray From 4b2417ce375f945415235cd9cdd1d4034fe1b917 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=87=AA=E7=94=B1=E7=9A=84=E4=B8=96=E7=95=8C=E4=BA=BA?= <3196812536@qq.com> Date: Sun, 18 May 2025 20:16:33 +0800 Subject: [PATCH 10/17] hotfix: github models check error (#6128) --- src/renderer/src/providers/AiProvider/OpenAIProvider.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts index 545c92ed3e..cae673437d 100644 --- a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts +++ b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts @@ -1136,13 +1136,16 @@ export default class OpenAIProvider extends BaseOpenAIProvider { return { valid: false, error: new Error('No model found') } } - const body = { + const body: any = { model: model.id, messages: [{ role: 'user', content: 'hi' }], - enable_thinking: false, // qwen3 stream } + if (this.provider.id !== 'github') { + body.enable_thinking = false; // qwen3 + } + try { await this.checkIsCopilot() if (!stream) { From 8bfbbd497cff3ffc693e29bc9f66f4586908178f Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Sun, 18 May 2025 18:47:01 +0800 Subject: [PATCH 11/17] refactor: streamline MCP service handling and improve IPC registration * Refactored MCPService to implement a singleton pattern for better instance management. * Updated IPC registration to utilize the new getMcpInstance method for handling MCP-related requests. * Removed redundant IPC handlers from the main index file and centralized them in the ipc module. * Added background throttling option in WindowService configuration to enhance performance. * Introduced delays in MCPToolsButton to optimize resource and prompt fetching after initial load. --- src/main/index.ts | 19 +- src/main/ipc.ts | 30 +- src/main/services/MCPService.ts | 440 ++++++++++-------- src/main/services/WindowService.ts | 3 +- src/main/services/mcp/shell-env.ts | 2 +- src/renderer/src/hooks/useMCPServers.ts | 11 +- .../pages/home/Inputbar/MCPToolsButton.tsx | 67 ++- 7 files changed, 320 insertions(+), 252 deletions(-) diff --git a/src/main/index.ts b/src/main/index.ts index f85803ed84..44d516a5ca 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -2,12 +2,11 @@ import '@main/config' import { electronApp, optimizer } from '@electron-toolkit/utils' import { replaceDevtoolsFont } from '@main/utils/windowUtil' -import { IpcChannel } from '@shared/IpcChannel' -import { app, BrowserWindow, ipcMain } from 'electron' +import { app } from 'electron' import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer' import Logger from 'electron-log' -import { isDev, isMac, isWin } from './constant' +import { isDev } from './constant' import { registerIpc } from './ipc' import { configManager } from './services/ConfigManager' import mcpService from './services/MCPService' @@ -85,18 +84,6 @@ if (!app.requestSingleInstanceLock()) { .then((name) => console.log(`Added Extension: ${name}`)) .catch((err) => console.log('An error occurred: ', err)) } - ipcMain.handle(IpcChannel.System_GetDeviceType, () => { - return isMac ? 'mac' : isWin ? 'windows' : 'linux' - }) - - ipcMain.handle(IpcChannel.System_GetHostname, () => { - return require('os').hostname() - }) - - ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => { - const win = BrowserWindow.fromWebContents(e.sender) - win && win.webContents.toggleDevTools() - }) }) registerProtocolClient(app) @@ -128,7 +115,7 @@ if (!app.requestSingleInstanceLock()) { app.on('will-quit', async () => { // event.preventDefault() try { - await mcpService.cleanup() + await mcpService().cleanup() } catch (error) { Logger.error('Error cleaning up MCP service:', error) } diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 665e8114b7..439ed63e3d 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -19,7 +19,7 @@ import FileService from './services/FileService' import FileStorage from './services/FileStorage' import { GeminiService } from './services/GeminiService' import KnowledgeService from './services/KnowledgeService' -import mcpService from './services/MCPService' +import { getMcpInstance } from './services/MCPService' import * as NutstoreService from './services/NutstoreService' import ObsidianVaultService from './services/ObsidianVaultService' import { ProxyConfig, proxyManager } from './services/ProxyManager' @@ -204,6 +204,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text)) ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text)) + // system + ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux')) + ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname()) + ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => { + const win = BrowserWindow.fromWebContents(e.sender) + win && win.webContents.toggleDevTools() + }) + // backup ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup) ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore) @@ -301,16 +309,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ) // Register MCP handlers - ipcMain.handle(IpcChannel.Mcp_RemoveServer, mcpService.removeServer) - ipcMain.handle(IpcChannel.Mcp_RestartServer, mcpService.restartServer) - ipcMain.handle(IpcChannel.Mcp_StopServer, mcpService.stopServer) - ipcMain.handle(IpcChannel.Mcp_ListTools, mcpService.listTools) - ipcMain.handle(IpcChannel.Mcp_CallTool, mcpService.callTool) - ipcMain.handle(IpcChannel.Mcp_ListPrompts, mcpService.listPrompts) - ipcMain.handle(IpcChannel.Mcp_GetPrompt, mcpService.getPrompt) - ipcMain.handle(IpcChannel.Mcp_ListResources, mcpService.listResources) - ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource) - ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo) + ipcMain.handle(IpcChannel.Mcp_RemoveServer, (event, server) => getMcpInstance().removeServer(event, server)) + ipcMain.handle(IpcChannel.Mcp_RestartServer, (event, server) => getMcpInstance().restartServer(event, server)) + ipcMain.handle(IpcChannel.Mcp_StopServer, (event, server) => getMcpInstance().stopServer(event, server)) + ipcMain.handle(IpcChannel.Mcp_ListTools, (event, server) => getMcpInstance().listTools(event, server)) + ipcMain.handle(IpcChannel.Mcp_CallTool, (event, params) => getMcpInstance().callTool(event, params)) + ipcMain.handle(IpcChannel.Mcp_ListPrompts, (event, server) => getMcpInstance().listPrompts(event, server)) + ipcMain.handle(IpcChannel.Mcp_GetPrompt, (event, params) => getMcpInstance().getPrompt(event, params)) + ipcMain.handle(IpcChannel.Mcp_ListResources, (event, server) => getMcpInstance().listResources(event, server)) + ipcMain.handle(IpcChannel.Mcp_GetResource, (event, params) => getMcpInstance().getResource(event, params)) + ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, () => getMcpInstance().getInstallInfo()) ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name)) ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name)) diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 237e709deb..5ea91343f5 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -68,20 +68,18 @@ function withCache( } class McpService { + private static instance: McpService | null = null private clients: Map = new Map() + private pendingClients: Map> = new Map() - private getServerKey(server: MCPServer): string { - return JSON.stringify({ - baseUrl: server.baseUrl, - command: server.command, - args: server.args, - registryUrl: server.registryUrl, - env: server.env, - id: server.id - }) + public static getInstance(): McpService { + if (!McpService.instance) { + McpService.instance = new McpService() + } + return McpService.instance } - constructor() { + private constructor() { this.initClient = this.initClient.bind(this) this.listTools = this.listTools.bind(this) this.callTool = this.callTool.bind(this) @@ -96,9 +94,26 @@ class McpService { this.cleanup = this.cleanup.bind(this) } + private getServerKey(server: MCPServer): string { + return JSON.stringify({ + baseUrl: server.baseUrl, + command: server.command, + args: server.args, + registryUrl: server.registryUrl, + env: server.env, + id: server.id + }) + } + async initClient(server: MCPServer): Promise { const serverKey = this.getServerKey(server) + // If there's a pending initialization, wait for it + const pendingClient = this.pendingClients.get(serverKey) + if (pendingClient) { + return pendingClient + } + // Check if we already have a client for this server configuration const existingClient = this.clients.get(serverKey) if (existingClient) { @@ -113,209 +128,226 @@ class McpService { } else { return existingClient } - } catch (error) { - Logger.error(`[MCP] Error pinging server ${server.name}:`, error) + } catch (error: any) { + Logger.error(`[MCP] Error pinging server ${server.name}:`, error?.message) this.clients.delete(serverKey) } } - // Create new client instance for each connection - const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} }) - const args = [...(server.args || [])] + // Create a promise for the initialization process + const initPromise = (async () => { + try { + // Create new client instance for each connection + const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} }) - // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport - const authProvider = new McpOAuthClientProvider({ - serverUrlHash: crypto - .createHash('md5') - .update(server.baseUrl || '') - .digest('hex') - }) + const args = [...(server.args || [])] - const initTransport = async (): Promise< - StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport - > => { - // Create appropriate transport based on configuration - if (server.type === 'inMemory') { - Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() - // start the in-memory server with the given name and environment variables - const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {}) - try { - await inMemoryServer.connect(serverTransport) - Logger.info(`[MCP] In-memory server started: ${server.name}`) - } catch (error: Error | any) { - Logger.error(`[MCP] Error starting in-memory server: ${error}`) - throw new Error(`Failed to start in-memory server: ${error.message}`) - } - // set the client transport to the client - return clientTransport - } else if (server.baseUrl) { - if (server.type === 'streamableHttp') { - const options: StreamableHTTPClientTransportOptions = { - requestInit: { - headers: server.headers || {} - }, - authProvider - } - return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) - } else if (server.type === 'sse') { - const options: SSEClientTransportOptions = { - eventSourceInit: { - fetch: async (url, init) => { - const headers = { ...(server.headers || {}), ...(init?.headers || {}) } + // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport + const authProvider = new McpOAuthClientProvider({ + serverUrlHash: crypto + .createHash('md5') + .update(server.baseUrl || '') + .digest('hex') + }) - // Get tokens from authProvider to make sure using the latest tokens - if (authProvider && typeof authProvider.tokens === 'function') { - try { - const tokens = await authProvider.tokens() - if (tokens && tokens.access_token) { - headers['Authorization'] = `Bearer ${tokens.access_token}` + const initTransport = async (): Promise< + StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport + > => { + // Create appropriate transport based on configuration + if (server.type === 'inMemory') { + Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() + // start the in-memory server with the given name and environment variables + const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {}) + try { + await inMemoryServer.connect(serverTransport) + Logger.info(`[MCP] In-memory server started: ${server.name}`) + } catch (error: Error | any) { + Logger.error(`[MCP] Error starting in-memory server: ${error}`) + throw new Error(`Failed to start in-memory server: ${error.message}`) + } + // set the client transport to the client + return clientTransport + } else if (server.baseUrl) { + if (server.type === 'streamableHttp') { + const options: StreamableHTTPClientTransportOptions = { + requestInit: { + headers: server.headers || {} + }, + authProvider + } + return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) + } else if (server.type === 'sse') { + const options: SSEClientTransportOptions = { + eventSourceInit: { + fetch: async (url, init) => { + const headers = { ...(server.headers || {}), ...(init?.headers || {}) } + + // Get tokens from authProvider to make sure using the latest tokens + if (authProvider && typeof authProvider.tokens === 'function') { + try { + const tokens = await authProvider.tokens() + if (tokens && tokens.access_token) { + headers['Authorization'] = `Bearer ${tokens.access_token}` + } + } catch (error) { + Logger.error('Failed to fetch tokens:', error) + } } - } catch (error) { - Logger.error('Failed to fetch tokens:', error) + + return fetch(url, { ...init, headers }) } + }, + requestInit: { + headers: server.headers || {} + }, + authProvider + } + return new SSEClientTransport(new URL(server.baseUrl!), options) + } else { + throw new Error('Invalid server type') + } + } else if (server.command) { + let cmd = server.command + + if (server.command === 'npx') { + cmd = await getBinaryPath('bun') + Logger.info(`[MCP] Using command: ${cmd}`) + + // add -x to args if args exist + if (args && args.length > 0) { + if (!args.includes('-y')) { + args.unshift('-y') + } + if (!args.includes('x')) { + args.unshift('x') + } + } + if (server.registryUrl) { + server.env = { + ...server.env, + NPM_CONFIG_REGISTRY: server.registryUrl } - return fetch(url, { ...init, headers }) + // if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory + if (server.name.includes('mcp-auto-install')) { + const binPath = await getBinaryPath() + makeSureDirExists(binPath) + server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json') + } + } + } else if (server.command === 'uvx' || server.command === 'uv') { + cmd = await getBinaryPath(server.command) + if (server.registryUrl) { + server.env = { + ...server.env, + UV_DEFAULT_INDEX: server.registryUrl, + PIP_INDEX_URL: server.registryUrl + } } - }, - requestInit: { - headers: server.headers || {} - }, - authProvider - } - return new SSEClientTransport(new URL(server.baseUrl!), options) - } else { - throw new Error('Invalid server type') - } - } else if (server.command) { - let cmd = server.command - - if (server.command === 'npx') { - cmd = await getBinaryPath('bun') - Logger.info(`[MCP] Using command: ${cmd}`) - - // add -x to args if args exist - if (args && args.length > 0) { - if (!args.includes('-y')) { - args.unshift('-y') - } - if (!args.includes('x')) { - args.unshift('x') - } - } - if (server.registryUrl) { - server.env = { - ...server.env, - NPM_CONFIG_REGISTRY: server.registryUrl } - // if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory - if (server.name.includes('mcp-auto-install')) { - const binPath = await getBinaryPath() - makeSureDirExists(binPath) - server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json') - } - } - } else if (server.command === 'uvx' || server.command === 'uv') { - cmd = await getBinaryPath(server.command) - if (server.registryUrl) { - server.env = { - ...server.env, - UV_DEFAULT_INDEX: server.registryUrl, - PIP_INDEX_URL: server.registryUrl - } + Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) + // Logger.info(`[MCP] Environment variables for server:`, server.env) + const loginShellEnv = await this.getLoginShellEnv() + const stdioTransport = new StdioClientTransport({ + command: cmd, + args, + env: { + ...loginShellEnv, + ...server.env + }, + stderr: 'pipe' + }) + stdioTransport.stderr?.on('data', (data) => + Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString()) + ) + return stdioTransport + } else { + throw new Error('Either baseUrl or command must be provided') } } - Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) - // Logger.info(`[MCP] Environment variables for server:`, server.env) - const loginShellEnv = await this.getLoginShellEnv() - const stdioTransport = new StdioClientTransport({ - command: cmd, - args, - env: { - ...loginShellEnv, - ...server.env - }, - stderr: 'pipe' - }) - stdioTransport.stderr?.on('data', (data) => - Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString()) - ) - return stdioTransport - } else { - throw new Error('Either baseUrl or command must be provided') - } - } + const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { + Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`) + // Create an event emitter for the OAuth callback + const events = new EventEmitter() - const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { - Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`) - // Create an event emitter for the OAuth callback - const events = new EventEmitter() + // Create a callback server + const callbackServer = new CallBackServer({ + port: authProvider.config.callbackPort, + path: authProvider.config.callbackPath || '/oauth/callback', + events + }) - // Create a callback server - const callbackServer = new CallBackServer({ - port: authProvider.config.callbackPort, - path: authProvider.config.callbackPath || '/oauth/callback', - events - }) + // Set a timeout to close the callback server + const timeoutId = setTimeout(() => { + Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`) + callbackServer.close() + }, 300000) // 5 minutes timeout - // Set a timeout to close the callback server - const timeoutId = setTimeout(() => { - Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`) - callbackServer.close() - }, 300000) // 5 minutes timeout + try { + // Wait for the authorization code + const authCode = await callbackServer.waitForAuthCode() + Logger.info(`[MCP] Received auth code: ${authCode}`) - try { - // Wait for the authorization code - const authCode = await callbackServer.waitForAuthCode() - Logger.info(`[MCP] Received auth code: ${authCode}`) + // Complete the OAuth flow + await transport.finishAuth(authCode) - // Complete the OAuth flow - await transport.finishAuth(authCode) + Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`) - Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`) + const newTransport = await initTransport() + // Try to connect again + await client.connect(newTransport) - const newTransport = await initTransport() - // Try to connect again - await client.connect(newTransport) + Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`) + } catch (oauthError) { + Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError) + throw new Error( + `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` + ) + } finally { + // Clear the timeout and close the callback server + clearTimeout(timeoutId) + callbackServer.close() + } + } - Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`) - } catch (oauthError) { - Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError) - throw new Error( - `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` - ) + try { + const transport = await initTransport() + try { + await client.connect(transport) + } catch (error: Error | any) { + if ( + error instanceof Error && + (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized')) + ) { + Logger.info(`[MCP] Authentication required for server: ${server.name}`) + await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport) + } else { + throw error + } + } + + // Store the new client in the cache + this.clients.set(serverKey, client) + + Logger.info(`[MCP] Activated server: ${server.name}`) + return client + } catch (error: any) { + Logger.error(`[MCP] Error activating server ${server.name}:`, error?.message) + throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) + } } finally { - // Clear the timeout and close the callback server - clearTimeout(timeoutId) - callbackServer.close() + // Clean up the pending promise when done + this.pendingClients.delete(serverKey) } - } + })() - try { - const transport = await initTransport() - try { - await client.connect(transport) - } catch (error: Error | any) { - if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) { - Logger.info(`[MCP] Authentication required for server: ${server.name}`) - await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport) - } else { - throw error - } - } + // Store the pending promise + this.pendingClients.set(serverKey, initPromise) - // Store the new client in the cache - this.clients.set(serverKey, client) - - Logger.info(`[MCP] Activated server: ${server.name}`) - return client - } catch (error: any) { - Logger.error(`[MCP] Error activating server ${server.name}:`, error) - throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) - } + return initPromise } async closeClient(serverKey: string) { @@ -357,8 +389,8 @@ class McpService { for (const [key] of this.clients) { try { await this.closeClient(key) - } catch (error) { - Logger.error(`[MCP] Failed to close client: ${error}`) + } catch (error: any) { + Logger.error(`[MCP] Failed to close client: ${error?.message}`) } } } @@ -379,8 +411,8 @@ class McpService { serverTools.push(serverTool) }) return serverTools - } catch (error) { - Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error) + } catch (error: any) { + Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error?.message) return [] } } @@ -439,8 +471,8 @@ class McpService { * List prompts available on an MCP server */ private async listPromptsImpl(server: MCPServer): Promise { - Logger.info(`[MCP] Listing prompts for server: ${server.name}`) const client = await this.initClient(server) + Logger.info(`[MCP] Listing prompts for server: ${server.name}`) try { const { prompts } = await client.listPrompts() return prompts.map((prompt: any) => ({ @@ -449,8 +481,11 @@ class McpService { serverId: server.id, serverName: server.name })) - } catch (error) { - Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error) + } catch (error: any) { + // -32601 is the code for the method not found + if (error?.code !== -32601) { + Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error?.message) + } return [] } } @@ -508,8 +543,8 @@ class McpService { * List resources available on an MCP server (implementation) */ private async listResourcesImpl(server: MCPServer): Promise { - Logger.info(`[MCP] Listing resources for server: ${server.name}`) const client = await this.initClient(server) + Logger.info(`[MCP] Listing resources for server: ${server.name}`) try { const result = await client.listResources() const resources = result.resources || [] @@ -519,8 +554,11 @@ class McpService { serverName: server.name })) return serverResources - } catch (error) { - Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error) + } catch (error: any) { + // -32601 is the code for the method not found + if (error?.code !== -32601) { + Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error?.message) + } return [] } } @@ -563,7 +601,7 @@ class McpService { contents: contents } } catch (error: Error | any) { - Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error) + Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error.message) throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`) } } @@ -602,5 +640,13 @@ class McpService { }) } -const mcpService = new McpService() -export default mcpService +let mcpInstance: ReturnType | null = null + +export const getMcpInstance = () => { + if (!mcpInstance) { + mcpInstance = McpService.getInstance() + } + return mcpInstance +} + +export default McpService.getInstance diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts index eb4e2f104a..f033cc82bf 100644 --- a/src/main/services/WindowService.ts +++ b/src/main/services/WindowService.ts @@ -446,7 +446,8 @@ export class WindowService { preload: join(__dirname, '../preload/index.js'), sandbox: false, webSecurity: false, - webviewTag: true + webviewTag: true, + backgroundThrottling: false } }) diff --git a/src/main/services/mcp/shell-env.ts b/src/main/services/mcp/shell-env.ts index 54cc21280f..9901417024 100644 --- a/src/main/services/mcp/shell-env.ts +++ b/src/main/services/mcp/shell-env.ts @@ -47,7 +47,7 @@ function getLoginShellEnvironment(): Promise> { commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command } - Logger.log(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) + Logger.log(`[ShellEnv] Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) const child = spawn(shellPath, commandArgs, { cwd: homeDirectory, // Run the command in the user's home directory diff --git a/src/renderer/src/hooks/useMCPServers.ts b/src/renderer/src/hooks/useMCPServers.ts index 90d6e6fec4..49fde29f60 100644 --- a/src/renderer/src/hooks/useMCPServers.ts +++ b/src/renderer/src/hooks/useMCPServers.ts @@ -1,8 +1,8 @@ +import { createSelector } from '@reduxjs/toolkit' import store, { useAppDispatch, useAppSelector } from '@renderer/store' import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp' import { MCPServer } from '@renderer/types' import { IpcChannel } from '@shared/IpcChannel' -import { useMemo } from 'react' const ipcRenderer = window.electron.ipcRenderer @@ -14,9 +14,14 @@ ipcRenderer.on(IpcChannel.Mcp_AddServer, (_event, server: MCPServer) => { store.dispatch(addMCPServer(server)) }) +const selectMcpServers = (state) => state.mcp.servers +const selectActiveMcpServers = createSelector([selectMcpServers], (servers) => + servers.filter((server) => server.isActive) +) + export const useMCPServers = () => { - const mcpServers = useAppSelector((state) => state.mcp.servers) - const activedMcpServers = useMemo(() => mcpServers.filter((server) => server.isActive), [mcpServers]) + const mcpServers = useAppSelector(selectMcpServers) + const activedMcpServers = useAppSelector(selectActiveMcpServers) const dispatch = useAppDispatch() return { diff --git a/src/renderer/src/pages/home/Inputbar/MCPToolsButton.tsx b/src/renderer/src/pages/home/Inputbar/MCPToolsButton.tsx index 60eff1cd23..4aa461c9e9 100644 --- a/src/renderer/src/pages/home/Inputbar/MCPToolsButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/MCPToolsButton.tsx @@ -3,6 +3,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant' import { useMCPServers } from '@renderer/hooks/useMCPServers' import { EventEmitter } from '@renderer/services/EventService' import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types' +import { delay, runAsyncFunction } from '@renderer/utils' import { Form, Input, Tooltip } from 'antd' import { Plus, SquareTerminal } from 'lucide-react' import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react' @@ -109,6 +110,11 @@ const extractPromptContent = (response: any): string | null => { return null } +// Add static variable before component definition +let isFirstResourcesListCall = true +let isFirstPromptListCall = true +const initMcpDelay = 3 + const MCPToolsButton: FC = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => { const { activedMcpServers } = useMCPServers() const { t } = useTranslation() @@ -308,6 +314,11 @@ const MCPToolsButton: FC = ({ ref, setInputValue, resizeTextArea, Toolbar const promptList = useMemo(async () => { const prompts: MCPPrompt[] = [] + if (isFirstPromptListCall) { + await delay(initMcpDelay) + isFirstPromptListCall = false + } + for (const server of activedMcpServers) { const serverPrompts = await window.api.mcp.listPrompts(server) prompts.push(...serverPrompts) @@ -319,7 +330,8 @@ const MCPToolsButton: FC = ({ ref, setInputValue, resizeTextArea, Toolbar icon: , action: () => handlePromptSelect(prompt as MCPPromptWithArgs) })) - }, [handlePromptSelect, activedMcpServers]) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [activedMcpServers]) const openPromptList = useCallback(async () => { const prompts = await promptList @@ -380,33 +392,42 @@ const MCPToolsButton: FC = ({ ref, setInputValue, resizeTextArea, Toolbar const [resourcesList, setResourcesList] = useState([]) useEffect(() => { - let isMounted = true + runAsyncFunction(async () => { + let isMounted = true - const fetchResources = async () => { - const resources: MCPResource[] = [] - for (const server of activedMcpServers) { - const serverResources = await window.api.mcp.listResources(server) - resources.push(...serverResources) + const fetchResources = async () => { + const resources: MCPResource[] = [] + + for (const server of activedMcpServers) { + const serverResources = await window.api.mcp.listResources(server) + resources.push(...serverResources) + } + + if (isMounted) { + setResourcesList( + resources.map((resource) => ({ + label: resource.name, + description: resource.description, + icon: , + action: () => handleResourceSelect(resource) + })) + ) + } } - if (isMounted) { - setResourcesList( - resources.map((resource) => ({ - label: resource.name, - description: resource.description, - icon: , - action: () => handleResourceSelect(resource) - })) - ) + // Avoid mcp following the software startup, affecting the startup speed + if (isFirstResourcesListCall) { + await delay(initMcpDelay) + isFirstResourcesListCall = false + fetchResources() } - } - fetchResources() - - return () => { - isMounted = false - } - }, [activedMcpServers, handleResourceSelect]) + return () => { + isMounted = false + } + }) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [activedMcpServers]) const openResourcesList = useCallback(async () => { const resources = resourcesList From bdbb937403cfd90f8cf73178c106b27d0bc1d604 Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Sun, 18 May 2025 21:04:36 +0800 Subject: [PATCH 12/17] refactor: update spinner handling and improve initialization timing * Modified the Content Security Policy to include 'unsafe-inline' for script-src. * Changed the spinner display style from 'none' to 'flex' for better visibility. * Removed the initSpinner function and directly initialized the spinner in useAppInit. * Added console timing for initialization to track performance. --- src/renderer/index.html | 7 +++++-- src/renderer/src/hooks/useAppInit.ts | 6 +++++- src/renderer/src/init.ts | 8 -------- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/renderer/index.html b/src/renderer/index.html index eebceeac66..c8832dc573 100644 --- a/src/renderer/index.html +++ b/src/renderer/index.html @@ -5,7 +5,7 @@ + content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' 'unsafe-inline' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" /> Cherry Studio