diff --git a/package.json b/package.json index d550355d04..c6eac03d52 100644 --- a/package.json +++ b/package.json @@ -213,7 +213,7 @@ "styled-components": "^6.1.11", "tar": "^7.4.3", "tiny-pinyin": "^1.3.2", - "tokenx": "^0.4.1", + "tokenx": "^1.1.0", "typescript": "^5.6.2", "uuid": "^10.0.0", "vite": "6.2.6", diff --git a/src/main/loader/noteLoader.ts b/src/main/loader/noteLoader.ts new file mode 100644 index 0000000000..693f5f3c0a --- /dev/null +++ b/src/main/loader/noteLoader.ts @@ -0,0 +1,44 @@ +import { BaseLoader } from '@cherrystudio/embedjs-interfaces' +import { cleanString } from '@cherrystudio/embedjs-utils' +import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters' +import md5 from 'md5' + +export class NoteLoader extends BaseLoader<{ type: 'NoteLoader' }> { + private readonly text: string + private readonly sourceUrl?: string + + constructor({ + text, + sourceUrl, + chunkSize, + chunkOverlap + }: { + text: string + sourceUrl?: string + chunkSize?: number + chunkOverlap?: number + }) { + super(`NoteLoader_${md5(text + (sourceUrl || ''))}`, { text, sourceUrl }, chunkSize ?? 2000, chunkOverlap ?? 
0) + this.text = text + this.sourceUrl = sourceUrl + } + + override async *getUnfilteredChunks() { + const chunker = new RecursiveCharacterTextSplitter({ + chunkSize: this.chunkSize, + chunkOverlap: this.chunkOverlap + }) + + const chunks = await chunker.splitText(cleanString(this.text)) + + for (const chunk of chunks) { + yield { + pageContent: chunk, + metadata: { + type: 'NoteLoader' as const, + source: this.sourceUrl || 'note' + } + } + } + } +} diff --git a/src/main/services/KnowledgeService.ts b/src/main/services/KnowledgeService.ts index 62b2bba08f..d2d381c598 100644 --- a/src/main/services/KnowledgeService.ts +++ b/src/main/services/KnowledgeService.ts @@ -16,13 +16,14 @@ import * as fs from 'node:fs' import path from 'node:path' -import { RAGApplication, RAGApplicationBuilder, TextLoader } from '@cherrystudio/embedjs' +import { RAGApplication, RAGApplicationBuilder } from '@cherrystudio/embedjs' import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import { LibSqlDb } from '@cherrystudio/embedjs-libsql' import { SitemapLoader } from '@cherrystudio/embedjs-loader-sitemap' import { WebLoader } from '@cherrystudio/embedjs-loader-web' import Embeddings from '@main/embeddings/Embeddings' import { addFileLoader } from '@main/loader' +import { NoteLoader } from '@main/loader/noteLoader' import Reranker from '@main/reranker/Reranker' import { windowService } from '@main/services/WindowService' import { getDataPath } from '@main/utils' @@ -143,7 +144,7 @@ class KnowledgeService { this.getRagApplication(base) } - public reset = async (_: Electron.IpcMainInvokeEvent, { base }: { base: KnowledgeBaseParams }): Promise => { + public reset = async (_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise => { const ragApplication = await this.getRagApplication(base) await ragApplication.reset() } @@ -333,6 +334,7 @@ class KnowledgeService { ): LoaderTask { const { base, item, forceReload } = options const content = item.content as string + 
const sourceUrl = (item as any).sourceUrl const encoder = new TextEncoder() const contentBytes = encoder.encode(content) @@ -342,7 +344,12 @@ class KnowledgeService { state: LoaderTaskItemState.PENDING, task: () => { const loaderReturn = ragApplication.addLoader( - new TextLoader({ text: content, chunkSize: base.chunkSize, chunkOverlap: base.chunkOverlap }), + new NoteLoader({ + text: content, + sourceUrl, + chunkSize: base.chunkSize, + chunkOverlap: base.chunkOverlap + }), forceReload ) as Promise diff --git a/src/renderer/src/components/Spinner.tsx b/src/renderer/src/components/Spinner.tsx index 74408fc50d..5495115056 100644 --- a/src/renderer/src/components/Spinner.tsx +++ b/src/renderer/src/components/Spinner.tsx @@ -1,6 +1,5 @@ import { Search } from 'lucide-react' import { motion } from 'motion/react' -import { useTranslation } from 'react-i18next' import styled from 'styled-components' interface Props { @@ -18,7 +17,6 @@ const spinnerVariants = { } export default function Spinner({ text }: Props) { - const { t } = useTranslation() return ( - {t(text)} + {text} ) } diff --git a/src/renderer/src/config/constant.ts b/src/renderer/src/config/constant.ts index 5a50daa9d1..1a8a0d0fee 100644 --- a/src/renderer/src/config/constant.ts +++ b/src/renderer/src/config/constant.ts @@ -3,6 +3,7 @@ export const DEFAULT_CONTEXTCOUNT = 5 export const DEFAULT_MAX_TOKENS = 4096 export const DEFAULT_KNOWLEDGE_DOCUMENT_COUNT = 6 export const DEFAULT_KNOWLEDGE_THRESHOLD = 0.0 +export const DEFAULT_WEBSEARCH_RAG_DOCUMENT_COUNT = 1 export const platform = window.electron?.process?.platform export const isMac = platform === 'darwin' diff --git a/src/renderer/src/hooks/useWebSearchProviders.ts b/src/renderer/src/hooks/useWebSearchProviders.ts index b018d2a65c..f5c1dda78c 100644 --- a/src/renderer/src/hooks/useWebSearchProviders.ts +++ b/src/renderer/src/hooks/useWebSearchProviders.ts @@ -1,9 +1,12 @@ import { useAppDispatch, useAppSelector } from '@renderer/store' import { 
addSubscribeSource as _addSubscribeSource, + type CompressionConfig, removeSubscribeSource as _removeSubscribeSource, + setCompressionConfig, setDefaultProvider as _setDefaultProvider, setSubscribeSources as _setSubscribeSources, + updateCompressionConfig, updateSubscribeBlacklist as _updateSubscribeBlacklist, updateWebSearchProvider, updateWebSearchProviders @@ -90,3 +93,14 @@ export const useBlacklist = () => { setSubscribeSources } } + +export const useWebSearchSettings = () => { + const state = useAppSelector((state) => state.websearch) + const dispatch = useAppDispatch() + + return { + ...state, + setCompressionConfig: (config: CompressionConfig) => dispatch(setCompressionConfig(config)), + updateCompressionConfig: (config: Partial) => dispatch(updateCompressionConfig(config)) + } +} diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index b36b931f25..ecb7182753 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -703,6 +703,13 @@ "success.siyuan.export": "Successfully exported to Siyuan Note", "warn.yuque.exporting": "Exporting to Yuque, please do not request export repeatedly!", "warn.siyuan.exporting": "Exporting to Siyuan Note, please do not request export repeatedly!", + "websearch": { + "rag": "Executing RAG...", + "rag_complete": "Keeping {{countAfter}} out of {{countBefore}} results...", + "rag_failed": "RAG failed, returning empty results...", + "cutoff": "Truncating search content...", + "fetch_complete": "Completed {{count}} searches..." 
+ }, "download.success": "Download successfully", "download.failed": "Download failed" }, @@ -776,6 +783,7 @@ "dimensions": "Dimensions {{dimensions}}", "edit": "Edit Model", "embedding": "Embedding", + "embedding_dimensions": "Embedding Dimensions", "embedding_model": "Embedding Model", "embedding_model_tooltip": "Add in Settings->Model Provider->Manage", "function_calling": "Function Calling", @@ -1845,8 +1853,33 @@ "overwrite_tooltip": "Force use search service instead of LLM", "apikey": "API key", "free": "Free", - "content_limit": "Content length limit", - "content_limit_tooltip": "Limit the content length of the search results; content that exceeds the limit will be truncated." + "compression": { + "title": "Search Result Compression", + "method": "Compression Method", + "method.none": "None", + "method.cutoff": "Cutoff", + "cutoff.limit": "Cutoff Limit", + "cutoff.limit.placeholder": "Enter length", + "cutoff.limit.tooltip": "Limit the content length of search results, content exceeding the limit will be truncated (e.g., 2000 characters)", + "cutoff.unit.char": "Char", + "cutoff.unit.token": "Token", + "method.rag": "RAG", + "rag.document_count": "Document Count", + "rag.document_count.default": "Default", + "rag.document_count.tooltip": "Expected number of documents to extract from each search result, the actual total number of extracted documents is this value multiplied by the number of search results.", + "rag.embedding_dimensions.auto_get": "Auto Get Dimensions", + "rag.embedding_dimensions.placeholder": "Leave empty", + "rag.embedding_dimensions.tooltip": "If left blank, the dimensions parameter will not be passed", + "info": { + "dimensions_auto_success": "Dimensions auto-obtained successfully, dimensions: {{dimensions}}" + }, + "error": { + "embedding_model_required": "Please select an embedding model first", + "dimensions_auto_failed": "Failed to auto-obtain dimensions", + "provider_not_found": "Provider not found", + "rag_failed": "RAG failed" + } 
+ } }, "quickPhrase": { "title": "Quick Phrases", diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index a43b02177c..4b0d34eeb2 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -702,6 +702,13 @@ "warn.yuque.exporting": "語雀にエクスポート中です。重複してエクスポートしないでください!", "warn.siyuan.exporting": "思源ノートにエクスポート中です。重複してエクスポートしないでください!", "error.yuque.no_config": "語雀のAPIアドレスまたはトークンが設定されていません", + "websearch": { + "rag": "RAGを実行中...", + "rag_complete": "{{countBefore}}個の結果から{{countAfter}}個を保持...", + "rag_failed": "RAGが失敗しました。空の結果を返します...", + "cutoff": "検索内容を切り詰めています...", + "fetch_complete": "{{count}}回の検索を完了しました..." + }, "download.success": "ダウンロードに成功しました", "download.failed": "ダウンロードに失敗しました", "error.fetchTopicName": "トピック名の取得に失敗しました" @@ -776,6 +783,7 @@ "dimensions": "{{dimensions}} 次元", "edit": "モデルを編集", "embedding": "埋め込み", + "embedding_dimensions": "埋め込み次元", "embedding_model": "埋め込み模型", "embedding_model_tooltip": "設定->モデルサービス->管理で追加", "function_calling": "関数呼び出し", @@ -1826,8 +1834,33 @@ "overwrite_tooltip": "大規模言語モデルではなく、サービス検索を使用する", "apikey": "API キー", "free": "無料", - "content_limit": "内容の長さ制限", - "content_limit_tooltip": "検索結果の内容長を制限し、制限を超える内容は切り捨てられます。" + "compression": { + "title": "検索結果の圧縮", + "method": "圧縮方法", + "method.none": "圧縮しない", + "method.cutoff": "切り捨て", + "cutoff.limit": "切り捨て長", + "cutoff.limit.placeholder": "長さを入力", + "cutoff.limit.tooltip": "検索結果の内容長を制限し、制限を超える内容は切り捨てられます(例:2000文字)", + "cutoff.unit.char": "文字", + "cutoff.unit.token": "トークン", + "method.rag": "RAG", + "rag.document_count": "文書数", + "rag.document_count.default": "デフォルト", + "rag.document_count.tooltip": "単一の検索結果から抽出する文書数。実際に抽出される文書数は、この値に検索結果数を乗じたものです。", + "rag.embedding_dimensions.auto_get": "次元を自動取得", + "rag.embedding_dimensions.placeholder": "次元を設定しない", + "rag.embedding_dimensions.tooltip": "空の場合、dimensions パラメーターは渡されません", + "info": { + "dimensions_auto_success": "次元が自動取得されました。次元: {{dimensions}}" + 
}, + "error": { + "embedding_model_required": "まず埋め込みモデルを選択してください", + "dimensions_auto_failed": "次元の自動取得に失敗しました", + "provider_not_found": "プロバイダーが見つかりません", + "rag_failed": "RAG に失敗しました" + } + } }, "general.auto_check_update.title": "自動更新", "general.early_access.title": "早期アクセス", diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index 3eea3431f4..d7f20e297b 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -702,6 +702,13 @@ "success.siyuan.export": "Успешный экспорт в Siyuan", "warn.yuque.exporting": "Экспортируется в Yuque, пожалуйста, не отправляйте повторные запросы!", "warn.siyuan.exporting": "Экспортируется в Siyuan, пожалуйста, не отправляйте повторные запросы!", + "websearch": { + "rag": "Выполнение RAG...", + "rag_complete": "Сохранено {{countAfter}} из {{countBefore}} результатов...", + "rag_failed": "RAG не удалось, возвращается пустой результат...", + "cutoff": "Обрезка содержимого поиска...", + "fetch_complete": "Завершено {{count}} поисков..." + }, "download.success": "Скачано успешно", "download.failed": "Скачивание не удалось", "error.fetchTopicName": "Не удалось назвать топик" @@ -776,6 +783,7 @@ "dimensions": "{{dimensions}} мер", "edit": "Редактировать модель", "embedding": "Встраиваемые", + "embedding_dimensions": "Встраиваемые размерности", "embedding_model": "Встраиваемые модели", "embedding_model_tooltip": "Добавьте в настройки->модель сервиса->управление", "function_calling": "Вызов функции", @@ -1826,8 +1834,33 @@ "overwrite_tooltip": "Использовать провайдера поиска вместо LLM", "apikey": "API ключ", "free": "Бесплатно", - "content_limit": "Ограничение длины текста", - "content_limit_tooltip": "Ограничьте длину содержимого результатов поиска, контент, превышающий ограничение, будет обрезан." 
+ "compression": { + "title": "Сжатие результатов поиска", + "method": "Метод сжатия", + "method.none": "Не сжимать", + "method.cutoff": "Обрезка", + "cutoff.limit": "Лимит обрезки", + "cutoff.limit.placeholder": "Введите длину", + "cutoff.limit.tooltip": "Ограничьте длину содержимого результатов поиска, контент, превышающий ограничение, будет обрезан (например, 2000 символов)", + "cutoff.unit.char": "Символы", + "cutoff.unit.token": "Токены", + "method.rag": "RAG", + "rag.document_count": "Количество документов", + "rag.document_count.default": "По умолчанию", + "rag.document_count.tooltip": "Ожидаемое количество документов, которые будут извлечены из каждого результата поиска. Фактическое количество извлеченных документов равно этому значению, умноженному на количество результатов поиска.", + "rag.embedding_dimensions.auto_get": "Автоматически получить размерности", + "rag.embedding_dimensions.placeholder": "Не устанавливать размерности", + "rag.embedding_dimensions.tooltip": "Если оставить пустым, параметр dimensions не будет передан", + "info": { + "dimensions_auto_success": "Размерности успешно получены, размерности: {{dimensions}}" + }, + "error": { + "embedding_model_required": "Пожалуйста, сначала выберите модель встраивания", + "dimensions_auto_failed": "Не удалось получить размерности", + "provider_not_found": "Поставщик не найден", + "rag_failed": "RAG не удалось" + } + } }, "general.auto_check_update.title": "Автоматическое обновление", "general.early_access.title": "Ранний доступ", diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index 7bb2f25608..6b662d38ee 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -703,6 +703,13 @@ "success.siyuan.export": "导出到思源笔记成功", "warn.yuque.exporting": "正在导出语雀, 请勿重复请求导出!", "warn.siyuan.exporting": "正在导出到思源笔记,请勿重复请求导出!", + "websearch": { + "rag": "正在执行 RAG...", + "rag_complete": "保留 {{countBefore}} 个结果中的 {{countAfter}} 
个...", + "rag_failed": "RAG 失败,返回空结果...", + "cutoff": "正在截断搜索内容...", + "fetch_complete": "已完成 {{count}} 次搜索..." + }, "download.success": "下载成功", "download.failed": "下载失败" }, @@ -776,6 +783,7 @@ "dimensions": "{{dimensions}} 维", "edit": "编辑模型", "embedding": "嵌入", + "embedding_dimensions": "嵌入维度", "embedding_model": "嵌入模型", "embedding_model_tooltip": "在设置->模型服务中点击管理按钮添加", "function_calling": "函数调用", @@ -1845,8 +1853,33 @@ "title": "网络搜索", "apikey": "API 密钥", "free": "免费", - "content_limit": "内容长度限制", - "content_limit_tooltip": "限制搜索结果的内容长度, 超过限制的内容将被截断" + "compression": { + "title": "搜索结果压缩", + "method": "压缩方法", + "method.none": "不压缩", + "method.cutoff": "截断", + "cutoff.limit": "截断长度", + "cutoff.limit.placeholder": "输入长度", + "cutoff.limit.tooltip": "限制搜索结果的内容长度, 超过限制的内容将被截断(例如 2000 字符)", + "cutoff.unit.char": "字符", + "cutoff.unit.token": "Token", + "method.rag": "RAG", + "rag.document_count": "文档数量", + "rag.document_count.default": "默认", + "rag.document_count.tooltip": "预期从单个搜索结果中提取的文档数量,实际提取的总数量是这个值乘以搜索结果数量。", + "rag.embedding_dimensions.auto_get": "自动获取维度", + "rag.embedding_dimensions.placeholder": "不设置维度", + "rag.embedding_dimensions.tooltip": "留空则不传递 dimensions 参数", + "info": { + "dimensions_auto_success": "维度自动获取成功,维度为 {{dimensions}}" + }, + "error": { + "embedding_model_required": "请先选择嵌入模型", + "dimensions_auto_failed": "维度自动获取失败", + "provider_not_found": "未找到服务商", + "rag_failed": "RAG 失败" + } + } }, "quickPhrase": { "title": "快捷短语", diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index 949afc9639..44d99d4da9 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -703,6 +703,13 @@ "success.siyuan.export": "導出到思源筆記成功", "warn.yuque.exporting": "正在導出語雀,請勿重複請求導出!", "warn.siyuan.exporting": "正在導出到思源筆記,請勿重複請求導出!", + "websearch": { + "rag": "正在執行 RAG...", + "rag_complete": "保留 {{countBefore}} 個結果中的 {{countAfter}} 個...", + "rag_failed": "RAG 失敗,返回空結果...", + "cutoff": 
"正在截斷搜尋內容...", + "fetch_complete": "已完成 {{count}} 次搜尋..." + }, "download.success": "下載成功", "download.failed": "下載失敗" }, @@ -776,6 +783,7 @@ "dimensions": "{{dimensions}} 維", "edit": "編輯模型", "embedding": "嵌入", + "embedding_dimensions": "嵌入維度", "embedding_model": "嵌入模型", "embedding_model_tooltip": "在設定->模型服務中點選管理按鈕新增", "function_calling": "函數調用", @@ -1829,8 +1837,33 @@ "overwrite_tooltip": "強制使用搜尋服務商而不是大語言模型進行搜尋", "apikey": "API 金鑰", "free": "免費", - "content_limit": "內容長度限制", - "content_limit_tooltip": "限制搜尋結果的內容長度,超過限制的內容將被截斷" + "compression": { + "title": "搜尋結果壓縮", + "method": "壓縮方法", + "method.none": "不壓縮", + "method.cutoff": "截斷", + "cutoff.limit": "截斷長度", + "cutoff.limit.placeholder": "輸入長度", + "cutoff.limit.tooltip": "限制搜尋結果的內容長度,超過限制的內容將被截斷(例如 2000 字符)", + "cutoff.unit.char": "字符", + "cutoff.unit.token": "Token", + "method.rag": "RAG", + "rag.document_count": "文檔數量", + "rag.document_count.default": "預設", + "rag.document_count.tooltip": "預期從單個搜尋結果中提取的文檔數量,實際提取的總數量是這個值乘以搜尋結果數量。", + "rag.embedding_dimensions.auto_get": "自動獲取維度", + "rag.embedding_dimensions.placeholder": "不設置維度", + "rag.embedding_dimensions.tooltip": "留空則不傳遞 dimensions 參數", + "info": { + "dimensions_auto_success": "維度自動獲取成功,維度為 {{dimensions}}" + }, + "error": { + "embedding_model_required": "請先選擇嵌入模型", + "dimensions_auto_failed": "維度自動獲取失敗", + "provider_not_found": "未找到服務商", + "rag_failed": "RAG 失敗" + } + } }, "general.auto_check_update.title": "自動更新", "general.early_access.title": "搶先體驗", diff --git a/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx index fdd4640d00..f86dd6de16 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/CitationBlock.tsx @@ -5,13 +5,19 @@ import { selectFormattedCitationsByBlockId } from '@renderer/store/messageBlock' import { WebSearchSource } from '@renderer/types' import { type CitationMessageBlock, MessageBlockStatus } 
from '@renderer/types/newMessage' import React, { useMemo } from 'react' +import { useTranslation } from 'react-i18next' import { useSelector } from 'react-redux' import styled from 'styled-components' import CitationsList from '../CitationsList' function CitationBlock({ block }: { block: CitationMessageBlock }) { + const { t } = useTranslation() const formattedCitations = useSelector((state: RootState) => selectFormattedCitationsByBlockId(state, block.id)) + const { websearch } = useSelector((state: RootState) => state.runtime) + const message = useSelector((state: RootState) => state.messages.entities[block.messageId]) + const userMessageId = message?.askId || block.messageId // 如果没有 askId 则回退到 messageId + const hasGeminiBlock = block.response?.source === WebSearchSource.GEMINI const hasCitations = useMemo(() => { return ( @@ -21,8 +27,32 @@ function CitationBlock({ block }: { block: CitationMessageBlock }) { ) }, [formattedCitations, block.knowledge, hasGeminiBlock]) + const getWebSearchStatusText = (requestId: string) => { + const status = websearch.activeSearches[requestId] ?? { phase: 'default' } + + switch (status.phase) { + case 'fetch_complete': + return t('message.websearch.fetch_complete', { + count: status.countAfter ?? 0 + }) + case 'rag': + return t('message.websearch.rag') + case 'rag_complete': + return t('message.websearch.rag_complete', { + countBefore: status.countBefore ?? 0, + countAfter: status.countAfter ?? 
0 + }) + case 'rag_failed': + return t('message.websearch.rag_failed') + case 'cutoff': + return t('message.websearch.cutoff') + default: + return t('message.searching') + } + } + if (block.status === MessageBlockStatus.PROCESSING) { - return + return } if (!hasCitations) { diff --git a/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx b/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx index 21be188ffc..fc12ba5aec 100644 --- a/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx +++ b/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx @@ -41,7 +41,8 @@ const PopupContainer: React.FC = ({ base, resolve }) => { const searchResults = await searchKnowledgeBase(value, base) setResults(searchResults) } catch (error) { - console.error('Search failed:', error) + console.error(`Failed to search knowledge base ${base.name}:`, error) + setResults([]) } finally { setLoading(false) } diff --git a/src/renderer/src/pages/settings/WebSearchSettings/BasicSettings.tsx b/src/renderer/src/pages/settings/WebSearchSettings/BasicSettings.tsx index 098c007ca8..2a85df4b7f 100644 --- a/src/renderer/src/pages/settings/WebSearchSettings/BasicSettings.tsx +++ b/src/renderer/src/pages/settings/WebSearchSettings/BasicSettings.tsx @@ -1,18 +1,16 @@ import { useTheme } from '@renderer/context/ThemeProvider' -import { useAppDispatch, useAppSelector } from '@renderer/store' -import { setContentLimit, setMaxResult, setSearchWithTime } from '@renderer/store/websearch' -import { Input, Slider, Switch, Tooltip } from 'antd' +import { useWebSearchSettings } from '@renderer/hooks/useWebSearchProviders' +import { useAppDispatch } from '@renderer/store' +import { setMaxResult, setSearchWithTime } from '@renderer/store/websearch' +import { Slider, Switch } from 'antd' import { t } from 'i18next' -import { Info } from 'lucide-react' import { FC } from 'react' import { SettingDivider, SettingGroup, SettingRow, SettingRowTitle, 
SettingTitle } from '..' const BasicSettings: FC = () => { const { theme } = useTheme() - const searchWithTime = useAppSelector((state) => state.websearch.searchWithTime) - const maxResults = useAppSelector((state) => state.websearch.maxResults) - const contentLimit = useAppSelector((state) => state.websearch.contentLimit) + const { searchWithTime, maxResults } = useWebSearchSettings() const dispatch = useAppDispatch() @@ -38,28 +36,6 @@ const BasicSettings: FC = () => { onChangeComplete={(value) => dispatch(setMaxResult(value))} /> - - - - {t('settings.websearch.content_limit')} - - - - - { - const value = e.target.value - if (value === '') { - dispatch(setContentLimit(undefined)) - } else if (!isNaN(Number(value)) && Number(value) > 0) { - dispatch(setContentLimit(Number(value))) - } - }} - /> - ) diff --git a/src/renderer/src/pages/settings/WebSearchSettings/CompressionSettings/CutoffSettings.tsx b/src/renderer/src/pages/settings/WebSearchSettings/CompressionSettings/CutoffSettings.tsx new file mode 100644 index 0000000000..add4024598 --- /dev/null +++ b/src/renderer/src/pages/settings/WebSearchSettings/CompressionSettings/CutoffSettings.tsx @@ -0,0 +1,60 @@ +import { useWebSearchSettings } from '@renderer/hooks/useWebSearchProviders' +import { SettingRow, SettingRowTitle } from '@renderer/pages/settings' +import { Input, Select, Space, Tooltip } from 'antd' +import { ChevronDown, Info } from 'lucide-react' +import { useTranslation } from 'react-i18next' + +const INPUT_BOX_WIDTH = '200px' + +const CutoffSettings = () => { + const { t } = useTranslation() + const { compressionConfig, updateCompressionConfig } = useWebSearchSettings() + + const handleCutoffLimitChange = (value: number | null) => { + updateCompressionConfig({ cutoffLimit: value || undefined }) + } + + const handleCutoffUnitChange = (unit: 'char' | 'token') => { + updateCompressionConfig({ cutoffUnit: unit }) + } + + const unitOptions = [ + { value: 'char', label: 
t('settings.websearch.compression.cutoff.unit.char') }, + { value: 'token', label: t('settings.websearch.compression.cutoff.unit.token') } + ] + + return ( + + + {t('settings.websearch.compression.cutoff.limit')} + + + + + + { + const value = e.target.value + if (value === '') { + handleCutoffLimitChange(null) + } else if (!isNaN(Number(value)) && Number(value) > 0) { + handleCutoffLimitChange(Number(value)) + } + }} + /> + } + /> + + + + + + {t('models.embedding_dimensions')} + + + + +
+ + +
+
+ + + + {t('models.rerank_model')} + } + /> + + + + {compressionConfig?.method === 'cutoff' && } + {compressionConfig?.method === 'rag' && } + + ) +} + +export default CompressionSettings diff --git a/src/renderer/src/pages/settings/WebSearchSettings/index.tsx b/src/renderer/src/pages/settings/WebSearchSettings/index.tsx index 6fab13d388..ecf2b8375b 100644 --- a/src/renderer/src/pages/settings/WebSearchSettings/index.tsx +++ b/src/renderer/src/pages/settings/WebSearchSettings/index.tsx @@ -9,6 +9,7 @@ import { useTranslation } from 'react-i18next' import { SettingContainer, SettingDivider, SettingGroup, SettingRow, SettingRowTitle, SettingTitle } from '..' import BasicSettings from './BasicSettings' import BlacklistSettings from './BlacklistSettings' +import CompressionSettings from './CompressionSettings' import WebSearchProviderSetting from './WebSearchProviderSetting' const WebSearchSettings: FC = () => { @@ -56,6 +57,7 @@ const WebSearchSettings: FC = () => { )} + ) diff --git a/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts b/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts index 6d2bc2401e..1a3d53d87c 100644 --- a/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/BochaProvider.ts @@ -26,15 +26,13 @@ export default class BochaProvider extends BaseWebSearchProvider { Authorization: `Bearer ${this.apiKey}` } - const contentLimit = websearch.contentLimit - const params: BochaSearchParams = { query, count: websearch.maxResults, exclude: websearch.excludeDomains.join(','), freshness: websearch.searchWithTime ? 'oneDay' : 'noLimit', - summary: false, - page: contentLimit ? 
Math.ceil(contentLimit / websearch.maxResults) : 1 + summary: true, + page: 1 } const response = await fetch(`${this.apiHost}/v1/web-search`, { @@ -58,7 +56,8 @@ export default class BochaProvider extends BaseWebSearchProvider { query: resp.data.queryContext.originalQuery, results: resp.data.webPages.value.map((result) => ({ title: result.name, - content: result.snippet, + // 优先使用 summary(更详细),如果没有则使用 snippet + content: result.summary || result.snippet || '', url: result.url })) } diff --git a/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts b/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts index 8f65449b05..7aee19609f 100644 --- a/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/ExaProvider.ts @@ -35,14 +35,9 @@ export default class ExaProvider extends BaseWebSearchProvider { return { query: response.autopromptString, results: response.results.slice(0, websearch.maxResults).map((result) => { - let content = result.text || '' - if (websearch.contentLimit && content.length > websearch.contentLimit) { - content = content.slice(0, websearch.contentLimit) + '...' 
- } - return { title: result.title || 'No title', - content: content, + content: result.text || '', url: result.url || '' } }) diff --git a/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts b/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts index 8f171dd3e5..8a09b76016 100644 --- a/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/LocalSearchProvider.ts @@ -55,11 +55,7 @@ export default class LocalSearchProvider extends BaseWebSearchProvider { // Fetch content for each URL concurrently const fetchPromises = validItems.map(async (item) => { // Logger.log(`Fetching content for ${item.url}...`) - const result = await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser, httpOptions) - if (websearch.contentLimit && result.content.length > websearch.contentLimit) { - result.content = result.content.slice(0, websearch.contentLimit) + '...' - } - return result + return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser, httpOptions) }) // Wait for all fetches to complete diff --git a/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts b/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts index 926da8a248..82b95142f6 100644 --- a/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/SearxngProvider.ts @@ -122,11 +122,7 @@ export default class SearxngProvider extends BaseWebSearchProvider { // Fetch content for each URL concurrently const fetchPromises = validItems.map(async (item) => { // Logger.log(`Fetching content for ${item.url}...`) - const result = await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser) - if (websearch.contentLimit && result.content.length > websearch.contentLimit) { - result.content = result.content.slice(0, websearch.contentLimit) + '...' 
- } - return result + return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser) }) // Wait for all fetches to complete diff --git a/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts b/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts index e38b2661d9..225bce308f 100644 --- a/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts +++ b/src/renderer/src/providers/WebSearchProvider/TavilyProvider.ts @@ -31,14 +31,9 @@ export default class TavilyProvider extends BaseWebSearchProvider { return { query: result.query, results: result.results.slice(0, websearch.maxResults).map((result) => { - let content = result.content || '' - if (websearch.contentLimit && content.length > websearch.contentLimit) { - content = content.slice(0, websearch.contentLimit) + '...' - } - return { title: result.title || 'No title', - content: content, + content: result.content || '', url: result.url || '' } }) diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index fb2c073109..4704a8bfd3 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -157,8 +157,13 @@ async function fetchExternalTool( try { // Use the consolidated processWebsearch function WebSearchService.createAbortSignal(lastUserMessage.id) + const webSearchResponse = await WebSearchService.processWebsearch( + webSearchProvider!, + extractResults, + lastUserMessage.id + ) return { - results: await WebSearchService.processWebsearch(webSearchProvider!, extractResults), + results: webSearchResponse, source: WebSearchSource.WEBSEARCH } } catch (error) { diff --git a/src/renderer/src/services/KnowledgeService.ts b/src/renderer/src/services/KnowledgeService.ts index da7b939161..707e4df0b8 100644 --- a/src/renderer/src/services/KnowledgeService.ts +++ b/src/renderer/src/services/KnowledgeService.ts @@ -130,7 +130,7 @@ export const searchKnowledgeBase = async ( ) } catch (error) { 
Logger.error(`Error searching knowledge base ${base.name}:`, error) - return [] + throw error } } diff --git a/src/renderer/src/services/WebSearchService.ts b/src/renderer/src/services/WebSearchService.ts index cf773a7512..efb726a0aa 100644 --- a/src/renderer/src/services/WebSearchService.ts +++ b/src/renderer/src/services/WebSearchService.ts @@ -1,13 +1,38 @@ +import { DEFAULT_WEBSEARCH_RAG_DOCUMENT_COUNT } from '@renderer/config/constant' import Logger from '@renderer/config/logger' +import i18n from '@renderer/i18n' import WebSearchEngineProvider from '@renderer/providers/WebSearchProvider' import store from '@renderer/store' -import { WebSearchState } from '@renderer/store/websearch' -import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types' -import { hasObjectKey } from '@renderer/utils' +import { setWebSearchStatus } from '@renderer/store/runtime' +import { CompressionConfig, WebSearchState } from '@renderer/store/websearch' +import { + KnowledgeBase, + KnowledgeItem, + KnowledgeReference, + WebSearchProvider, + WebSearchProviderResponse, + WebSearchProviderResult, + WebSearchStatus +} from '@renderer/types' +import { hasObjectKey, uuid } from '@renderer/utils' import { addAbortController } from '@renderer/utils/abortController' +import { formatErrorMessage } from '@renderer/utils/error' import { ExtractResults } from '@renderer/utils/extract' import { fetchWebContents } from '@renderer/utils/fetch' +import { consolidateReferencesByUrl, selectReferences } from '@renderer/utils/websearch' import dayjs from 'dayjs' +import { LRUCache } from 'lru-cache' + +import { getKnowledgeBaseParams } from './KnowledgeService' +import { getKnowledgeSourceUrl, searchKnowledgeBase } from './KnowledgeService' + +interface RequestState { + signal: AbortSignal | null + searchBase?: KnowledgeBase + isPaused: boolean + createdAt: number +} + /** * 提供网络搜索相关功能的服务类 */ @@ -19,12 +44,47 @@ class WebSearchService { isPaused = false - createAbortSignal(key: string) { 
/**
 * Returns the state tracked for a single request, creating and caching a fresh one on first use.
 * @param requestId Request ID (usually the message ID)
 */
private getRequestState(requestId: string): RequestState {
  const cached = this.requestStates.get(requestId)
  if (cached) {
    return cached
  }

  // First time this request is seen: register an empty, un-paused state
  const fresh: RequestState = { signal: null, isPaused: false, createdAt: Date.now() }
  this.requestStates.set(requestId, fresh)
  return fresh
}
/**
 * Checks whether an existing search knowledge base was built with the same
 * embedding model, rerank model and embedding dimensions as the given config.
 */
private isConfigMatched(base: KnowledgeBase, config: CompressionConfig): boolean {
  if (base.model.id !== config.embeddingModel?.id) {
    return false
  }
  if (base.rerankModel?.id !== config.rerankModel?.id) {
    return false
  }
  return base.dimensions === config.embeddingDimensions
}
/**
 * Compresses search results with RAG.
 * - Adds every raw search result to the per-request knowledge base in one batch
 * - Retrieves the chunks relevant to the questions from that knowledge base
 * - Maps the retrieved chunks back to the raw results through their sourceUrl
 *
 * @param questions Questions to retrieve for
 * @param rawResults Raw search results
 * @param config Compression configuration
 * @param requestId Request ID
 * @returns The compressed search results
 */
private async compressWithSearchBase(
  questions: string[],
  rawResults: WebSearchProviderResult[],
  config: CompressionConfig,
  requestId: string
): Promise<WebSearchProviderResult[]> {
  // Retrieval budget: expected documents per raw result times the number of raw results
  const totalDocumentCount = rawResults.length * (config.documentCount ?? DEFAULT_WEBSEARCH_RAG_DOCUMENT_COUNT)

  const searchBase = await this.ensureSearchBase(config, totalDocumentCount, requestId)

  // 1. Empty the knowledge base
  await window.api.knowledgeBase.reset(getKnowledgeBaseParams(searchBase))

  // 2. Add all search results to the knowledge base at once and wait for every add to finish
  await Promise.all(
    rawResults.map(async (result) => {
      const item: KnowledgeItem & { sourceUrl?: string } = {
        id: uuid(),
        type: 'note',
        content: result.content,
        sourceUrl: result.url, // sourceUrl is what maps a retrieved chunk back to its raw result
        created_at: Date.now(),
        updated_at: Date.now(),
        processingStatus: 'pending'
      }

      await window.api.knowledgeBase.add({
        base: getKnowledgeBaseParams(searchBase),
        item
      })
    })
  )

  // 3. Query the knowledge base with every question to get the compressed candidates
  const references = await this.querySearchBase(questions, searchBase)

  // 4. Pick references with a round-robin strategy
  const selectedReferences = selectReferences(rawResults, references, totalDocumentCount)

  Logger.log('[WebSearchService] With RAG, the number of search results:', {
    raw: rawResults.length,
    retrieved: references.length,
    selected: selectedReferences.length
  })

  // 5. Group by sourceUrl and merge chunks that come from the same source
  return consolidateReferencesByUrl(rawResults, selectedReferences)
}
/**
 * Core entry point for a web search request. Publishes runtime status for the UI while it runs.
 *
 * Steps:
 * - Validate the input and handle edge cases
 * - Handle the special "summarize" request
 * - Run all search queries in parallel
 * - Aggregate the results; any rejected search makes the whole call reject
 * - Apply the configured result compression (RAG or cutoff)
 * - Return the final search response
 *
 * Once searching has started, the status of the request is always set back to `default`
 * before this method returns or throws.
 *
 * @param webSearchProvider - Web search provider to use
 * @param extractResults - Extraction result holding the search questions and links
 * @param requestId - Unique request identifier used for status tracking and resource management
 *
 * @returns Response object containing the search results
 * @throws The rejection reason of the first failed search; errors thrown by cutoff compression
 */
public async processWebsearch(
  webSearchProvider: WebSearchProvider,
  extractResults: ExtractResults,
  requestId: string
): Promise<WebSearchProviderResponse> {
  // Start from a clean status for this request
  await this.setWebSearchStatus(requestId, { phase: 'default' })

  // Make sure websearch and its questions are usable
  const questions = extractResults.websearch?.question
  if (!questions || questions.length === 0) {
    Logger.log('[processWebsearch] No valid question found in extractResults.websearch')
    return { results: [] }
  }

  // Prefer the request-specific signal, fall back to the global one
  const signal = this.getRequestState(requestId).signal || this.signal
  const links = extractResults.websearch?.links

  // Summarize request: fetch the linked pages instead of searching
  if (questions[0] === 'summarize' && links && links.length > 0) {
    const contents = await fetchWebContents(links, undefined, undefined, { signal })
    return { query: 'summaries', results: contents }
  }

  try {
    const searchResults = await Promise.allSettled(
      questions.map((q) => this.search(webSearchProvider, q, { signal }))
    )

    // Report how many searches completed successfully
    const successfulSearchCount = searchResults.filter((result) => result.status === 'fulfilled').length
    if (successfulSearchCount > 1) {
      await this.setWebSearchStatus(
        requestId,
        {
          phase: 'fetch_complete',
          countAfter: successfulSearchCount
        },
        1000
      )
    }

    let finalResults: WebSearchProviderResult[] = []
    for (const outcome of searchResults) {
      if (outcome.status === 'rejected') {
        // One failed search fails the whole request
        throw outcome.reason
      }
      if (outcome.value.results) {
        finalResults.push(...outcome.value.results)
      }
    }

    // Nothing found: return an empty result right away
    if (finalResults.length === 0) {
      return {
        query: questions.join(' | '),
        results: []
      }
    }

    const { compressionConfig } = this.getWebSearchState()

    if (compressionConfig?.method === 'rag' && requestId) {
      // RAG compression
      await this.setWebSearchStatus(requestId, { phase: 'rag' }, 500)

      const originalCount = finalResults.length

      try {
        finalResults = await this.compressWithSearchBase(questions, finalResults, compressionConfig, requestId)
        await this.setWebSearchStatus(
          requestId,
          {
            phase: 'rag_complete',
            countBefore: originalCount,
            countAfter: finalResults.length
          },
          1000
        )
      } catch (error) {
        // A RAG failure is reported to the user and degrades to empty results instead of rejecting
        Logger.warn('[WebSearchService] RAG compression failed, will return empty results:', error)
        window.message.error({
          key: 'websearch-rag-failed',
          duration: 10,
          content: `${i18n.t('settings.websearch.compression.error.rag_failed')}: ${formatErrorMessage(error)}`
        })

        finalResults = []
        await this.setWebSearchStatus(requestId, { phase: 'rag_failed' }, 1000)
      }
    } else if (compressionConfig?.method === 'cutoff' && compressionConfig.cutoffLimit) {
      // Cutoff compression
      await this.setWebSearchStatus(requestId, { phase: 'cutoff' }, 500)
      finalResults = await this.compressWithCutoff(finalResults, compressionConfig)
    }

    return {
      query: questions.join(' | '),
      results: finalResults
    }
  } finally {
    // Reset the status on every exit path, including a rejected search or a compression error,
    // so the request never stays stuck in an intermediate phase
    await this.setWebSearchStatus(requestId, { phase: 'default' })
  }
}
// Records the status of one web search request in the runtime store.
// A status whose phase is 'default' removes the request's entry from activeSearches;
// any other phase stores the status under the request ID.
setWebSearchStatus: (state, action: PayloadAction<{ requestId: string; status: WebSearchStatus }>) => {
  const { requestId, status } = action.payload
  if (status.phase === 'default') {
    delete state.websearch.activeSearches[requestId]
    // Stop here: falling through would re-insert the entry that was just removed
    return
  }
  state.websearch.activeSearches[requestId] = status
}
interface CompressionConfig { + method: 'none' | 'cutoff' | 'rag' + cutoffLimit?: number + cutoffUnit?: 'char' | 'token' + embeddingModel?: Model + embeddingDimensions?: number // undefined表示自动获取 + documentCount?: number // 每个搜索结果的文档数量(只是预期值) + rerankModel?: Model +} + export interface WebSearchState { // 默认搜索提供商的ID /** @deprecated 支持在快捷菜单中自选搜索供应商,所以这个不再适用 */ @@ -24,12 +34,13 @@ export interface WebSearchState { // 是否覆盖服务商搜索 /** @deprecated 支持在快捷菜单中自选搜索供应商,所以这个不再适用 */ overwrite: boolean - contentLimit?: number + // 搜索结果压缩 + compressionConfig?: CompressionConfig // 具体供应商的配置 providerConfig: Record } -const initialState: WebSearchState = { +export const initialState: WebSearchState = { defaultProvider: 'local-bing', providers: [ { @@ -78,6 +89,10 @@ const initialState: WebSearchState = { excludeDomains: [], subscribeSources: [], overwrite: false, + compressionConfig: { + method: 'none', + cutoffUnit: 'char' + }, providerConfig: {} } @@ -150,8 +165,14 @@ const websearchSlice = createSlice({ state.providers.push(action.payload) } }, - setContentLimit: (state, action: PayloadAction) => { - state.contentLimit = action.payload + setCompressionConfig: (state, action: PayloadAction) => { + state.compressionConfig = action.payload + }, + updateCompressionConfig: (state, action: PayloadAction>) => { + state.compressionConfig = { + ...state.compressionConfig, + ...action.payload + } as CompressionConfig }, setProviderConfig: (state, action: PayloadAction>) => { state.providerConfig = action.payload @@ -176,7 +197,8 @@ export const { setSubscribeSources, setOverwrite, addWebSearchProvider, - setContentLimit, + setCompressionConfig, + updateCompressionConfig, setProviderConfig, updateProviderConfig } = websearchSlice.actions diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 340119fa31..3b4cc5cdc3 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -500,7 +500,6 @@ export type WebSearchProvider = { url?: string 
basicAuthUsername?: string basicAuthPassword?: string - contentLimit?: number usingBrowser?: boolean } @@ -542,6 +541,14 @@ export type WebSearchResponse = { source: WebSearchSource } +export type WebSearchPhase = 'default' | 'fetch_complete' | 'rag' | 'rag_complete' | 'rag_failed' | 'cutoff' + +export type WebSearchStatus = { + phase: WebSearchPhase + countBefore?: number + countAfter?: number +} + export type KnowledgeReference = { id: number content: string diff --git a/src/renderer/src/utils/__tests__/websearch.test.ts b/src/renderer/src/utils/__tests__/websearch.test.ts new file mode 100644 index 0000000000..2f807d111e --- /dev/null +++ b/src/renderer/src/utils/__tests__/websearch.test.ts @@ -0,0 +1,226 @@ +import { KnowledgeReference, WebSearchProviderResult } from '@renderer/types' +import { describe, expect, it } from 'vitest' + +import { consolidateReferencesByUrl, selectReferences } from '../websearch' + +describe('websearch', () => { + describe('consolidateReferencesByUrl', () => { + const createMockRawResult = (url: string, title: string): WebSearchProviderResult => ({ + title, + url, + content: `Original content for ${title}` + }) + + const createMockReference = (sourceUrl: string, content: string, id: number = 1): KnowledgeReference => ({ + id, + sourceUrl, + content, + type: 'url' + }) + + it('should consolidate single reference to matching raw result', () => { + // 基本功能:单个引用与原始结果匹配 + const rawResults = [createMockRawResult('https://example.com', 'Example Title')] + const references = [createMockReference('https://example.com', 'Retrieved content')] + + const result = consolidateReferencesByUrl(rawResults, references) + + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + title: 'Example Title', + url: 'https://example.com', + content: 'Retrieved content' + }) + }) + + it('should consolidate multiple references from same source URL', () => { + // 多个片段合并到同一个URL + const rawResults = [createMockRawResult('https://example.com', 'Example 
Title')] + const references = [ + createMockReference('https://example.com', 'First content', 1), + createMockReference('https://example.com', 'Second content', 2) + ] + + const result = consolidateReferencesByUrl(rawResults, references) + + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + title: 'Example Title', + url: 'https://example.com', + content: 'First content\n\n---\n\nSecond content' + }) + }) + + it('should consolidate references from multiple source URLs', () => { + // 多个不同URL的引用 + const rawResults = [ + createMockRawResult('https://example.com', 'Example Title'), + createMockRawResult('https://test.com', 'Test Title') + ] + const references = [ + createMockReference('https://example.com', 'Example content', 1), + createMockReference('https://test.com', 'Test content', 2) + ] + + const result = consolidateReferencesByUrl(rawResults, references) + + expect(result).toHaveLength(2) + // 结果顺序可能不确定,使用 toContainEqual + expect(result).toContainEqual({ + title: 'Example Title', + url: 'https://example.com', + content: 'Example content' + }) + expect(result).toContainEqual({ + title: 'Test Title', + url: 'https://test.com', + content: 'Test content' + }) + }) + + it('should use custom separator for multiple references', () => { + // 自定义分隔符 + const rawResults = [createMockRawResult('https://example.com', 'Example Title')] + const references = [ + createMockReference('https://example.com', 'First content', 1), + createMockReference('https://example.com', 'Second content', 2) + ] + + const result = consolidateReferencesByUrl(rawResults, references, ' | ') + + expect(result).toHaveLength(1) + expect(result[0].content).toBe('First content | Second content') + }) + + it('should ignore references with no matching raw result', () => { + // 无匹配的引用 + const rawResults = [createMockRawResult('https://example.com', 'Example Title')] + const references = [ + createMockReference('https://example.com', 'Matching content', 1), + 
createMockReference('https://nonexistent.com', 'Non-matching content', 2) + ] + + const result = consolidateReferencesByUrl(rawResults, references) + + expect(result).toHaveLength(1) + expect(result[0]).toEqual({ + title: 'Example Title', + url: 'https://example.com', + content: 'Matching content' + }) + }) + + it('should return empty array when no references match raw results', () => { + // 完全无匹配的情况 + const rawResults = [createMockRawResult('https://example.com', 'Example Title')] + const references = [createMockReference('https://nonexistent.com', 'Non-matching content', 1)] + + const result = consolidateReferencesByUrl(rawResults, references) + + expect(result).toHaveLength(0) + }) + + it('should handle empty inputs', () => { + // 边界条件:空输入 + expect(consolidateReferencesByUrl([], [])).toEqual([]) + + const rawResults = [createMockRawResult('https://example.com', 'Example Title')] + expect(consolidateReferencesByUrl(rawResults, [])).toEqual([]) + + const references = [createMockReference('https://example.com', 'Content', 1)] + expect(consolidateReferencesByUrl([], references)).toEqual([]) + }) + + it('should preserve original result metadata', () => { + // 验证原始结果的元数据保持不变 + const rawResults = [createMockRawResult('https://example.com', 'Complex Title with Special Characters & Symbols')] + const references = [createMockReference('https://example.com', 'New content', 1)] + + const result = consolidateReferencesByUrl(rawResults, references) + + expect(result[0].title).toBe('Complex Title with Special Characters & Symbols') + expect(result[0].url).toBe('https://example.com') + }) + }) + + describe('selectReferences', () => { + const createMockRawResult = (url: string, title: string): WebSearchProviderResult => ({ + title, + url, + content: `Original content for ${title}` + }) + + const createMockReference = (sourceUrl: string, content: string, id: number = 1): KnowledgeReference => ({ + id, + sourceUrl, + content, + type: 'url' + }) + + it('should select references 
using round robin strategy', () => { + const rawResults = [ + createMockRawResult('https://a.com', 'A'), + createMockRawResult('https://b.com', 'B'), + createMockRawResult('https://c.com', 'C') + ] + + const references = [ + createMockReference('https://a.com', 'A1', 1), + createMockReference('https://a.com', 'A2', 2), + createMockReference('https://b.com', 'B1', 3), + createMockReference('https://c.com', 'C1', 4), + createMockReference('https://c.com', 'C2', 5) + ] + + const result = selectReferences(rawResults, references, 4) + + expect(result).toHaveLength(4) + // 按照 rawResults 顺序轮询:A1, B1, C1, A2 + expect(result[0].content).toBe('A1') + expect(result[1].content).toBe('B1') + expect(result[2].content).toBe('C1') + expect(result[3].content).toBe('A2') + }) + + it('should handle maxRefs larger than available references', () => { + const rawResults = [createMockRawResult('https://a.com', 'A')] + const references = [createMockReference('https://a.com', 'A1', 1)] + + const result = selectReferences(rawResults, references, 10) + + expect(result).toHaveLength(1) + expect(result[0].content).toBe('A1') + }) + + it('should return empty array for edge cases', () => { + const rawResults = [createMockRawResult('https://a.com', 'A')] + const references = [createMockReference('https://a.com', 'A1', 1)] + + // maxRefs is 0 + expect(selectReferences(rawResults, references, 0)).toEqual([]) + + // empty references + expect(selectReferences(rawResults, [], 5)).toEqual([]) + + // no matching URLs + const nonMatchingRefs = [createMockReference('https://different.com', 'Content', 1)] + expect(selectReferences(rawResults, nonMatchingRefs, 5)).toEqual([]) + }) + + it('should preserve rawResults order in round robin', () => { + // rawResults 的顺序应该影响轮询顺序 + const rawResults = [ + createMockRawResult('https://z.com', 'Z'), // 应该第一个被选择 + createMockRawResult('https://a.com', 'A') // 应该第二个被选择 + ] + + const references = [createMockReference('https://a.com', 'A1', 1), 
/**
 * Consolidates retrieved knowledge chunks into search results, one per source URL.
 *
 * Chunks whose sourceUrl matches the URL of a raw result are grouped by that URL and
 * their contents are joined with `separator`. Chunks with no matching raw result are ignored.
 * Results come out in the order in which each URL first appears in `references`.
 *
 * @param rawResults Raw search results; they supply the title and URL of each output entry
 * @param references Chunks retrieved from the knowledge base
 * @param separator Separator placed between merged chunks, defaults to '\n\n---\n\n'
 * @returns The merged search results
 */
export function consolidateReferencesByUrl(
  rawResults: WebSearchProviderResult[],
  references: KnowledgeReference[],
  separator: string = '\n\n---\n\n'
): WebSearchProviderResult[] {
  // URL -> raw result lookup (for duplicated URLs the last raw result wins)
  const rawByUrl = new Map(rawResults.map((raw) => [raw.url, raw] as const))

  // URL -> matched raw result plus the chunk contents collected for it, in arrival order
  const groups = new Map<string, { original: WebSearchProviderResult; contents: string[] }>()

  for (const reference of references) {
    const original = rawByUrl.get(reference.sourceUrl)
    if (!original) continue

    const group = groups.get(reference.sourceUrl)
    if (group) {
      group.contents.push(reference.content)
    } else {
      groups.set(reference.sourceUrl, { original, contents: [reference.content] })
    }
  }

  return Array.from(groups.values(), ({ original, contents }) => ({
    title: original.title,
    url: original.url,
    content: contents.join(separator)
  }))
}
/**
 * Picks up to `maxRefs` references with a round-robin strategy.
 *
 * References are grouped by sourceUrl, and the groups are visited in the order their URL
 * appears in `rawResults`. Each round takes the next reference of every group that still
 * has one, so every source gets a turn before any source gets a second one.
 * References whose sourceUrl is not the URL of any raw result are never selected.
 * Neither input array is modified.
 *
 * @param rawResults Raw search results; they define the visiting order of the sources
 * @param references Candidate references; within one source their relative order is kept
 * @param maxRefs Maximum number of references to select
 * @returns The selected references in round-robin order
 */
export function selectReferences(
  rawResults: WebSearchProviderResult[],
  references: KnowledgeReference[],
  maxRefs: number
): KnowledgeReference[] {
  if (maxRefs <= 0 || references.length === 0) {
    return []
  }

  // URL -> position in rawResults (for duplicated URLs the last position wins)
  const urlToIndex = new Map<string, number>()
  rawResults.forEach((result, index) => {
    urlToIndex.set(result.url, index)
  })

  // Group the references by sourceUrl, keeping only sources that exist in rawResults
  const groupsByUrl = new Map<string, KnowledgeReference[]>()
  for (const ref of references) {
    if (!urlToIndex.has(ref.sourceUrl)) continue

    const group = groupsByUrl.get(ref.sourceUrl)
    if (group) {
      group.push(ref)
    } else {
      groupsByUrl.set(ref.sourceUrl, [ref])
    }
  }

  // Order the groups by the position of their URL in rawResults
  const orderedGroups = Array.from(groupsByUrl.entries())
    .sort(([urlA], [urlB]) => (urlToIndex.get(urlA) ?? 0) - (urlToIndex.get(urlB) ?? 0))
    .map(([, group]) => group)

  // Round r takes the r-th reference of every group that still has one
  const selected: KnowledgeReference[] = []
  for (let round = 0; selected.length < maxRefs; round++) {
    let pickedInRound = false

    for (const group of orderedGroups) {
      const candidate = group[round]
      if (candidate === undefined) continue // this source is exhausted

      selected.push(candidate)
      pickedInRound = true
      if (selected.length >= maxRefs) {
        return selected
      }
    }

    // Every group is exhausted
    if (!pickedInRound) break
  }

  return selected
}