diff --git a/package.json b/package.json index c84bc483d4..a651fe00de 100644 --- a/package.json +++ b/package.json @@ -74,15 +74,23 @@ "@libsql/win32-x64-msvc": "^0.4.7", "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch", "@strongtz/win32-arm64-msvc": "^0.4.7", + "cheerio": "^1.1.2", + "faiss-node": "^0.5.1", "graceful-fs": "^4.2.11", + "html-to-text": "^9.0.5", + "htmlparser2": "^10.0.0", "jsdom": "26.1.0", "node-stream-zip": "^1.15.0", "officeparser": "^4.2.0", "os-proxy-config": "^1.1.2", + "pdf-parse": "^1.1.1", + "react-player": "^3.3.1", + "react-youtube": "^10.1.0", "selection-hook": "^1.0.11", "sharp": "^0.34.3", "tesseract.js": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch", - "turndown": "7.2.0" + "turndown": "7.2.0", + "youtubei.js": "^15.0.1" }, "devDependencies": { "@agentic/exa": "^7.3.3", @@ -127,8 +135,10 @@ "@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch", "@hello-pangea/dnd": "^18.0.1", "@kangfenmao/keyv-storage": "^0.1.0", - "@langchain/community": "^0.3.36", + "@langchain/community": "^0.3.50", + "@langchain/core": "^0.3.68", "@langchain/ollama": "^0.2.1", + "@langchain/openai": "^0.6.7", "@mistralai/mistralai": "^1.7.5", "@modelcontextprotocol/sdk": "^1.17.0", "@mozilla/readability": "^0.6.0", @@ -171,9 +181,11 @@ "@types/cli-progress": "^3", "@types/fs-extra": "^11", "@types/he": "^1", + "@types/html-to-text": "^9", "@types/lodash": "^4.17.5", "@types/markdown-it": "^14", "@types/md5": "^2.3.5", + "@types/mime-types": "^3", "@types/node": "^22.17.1", "@types/pako": "^1.0.2", "@types/react": "^19.0.12", @@ -253,6 +265,7 @@ "markdown-it": "^14.1.0", "mermaid": "^11.10.1", "mime": "^4.0.4", + "mime-types": "^3.0.1", "motion": "^12.10.5", "notion-helper": "^1.3.22", "npx-scope-finder": "^1.2.0", diff --git a/packages/shared/config/types.ts b/packages/shared/config/types.ts 
index d46717b47e..8fb11cd8b3 100644 --- a/packages/shared/config/types.ts +++ b/packages/shared/config/types.ts @@ -7,7 +7,7 @@ export type LoaderReturn = { loaderType: string status?: ProcessingStatus message?: string - messageSource?: 'preprocess' | 'embedding' + messageSource?: 'preprocess' | 'embedding' | 'validation' } export type FileChangeEventType = 'add' | 'change' | 'unlink' | 'addDir' | 'unlinkDir' diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 8b22fee49c..f989a3d808 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -24,7 +24,7 @@ import DxtService from './services/DxtService' import { ExportService } from './services/ExportService' import { fileStorage as fileManager } from './services/FileStorage' import FileService from './services/FileSystemService' -import KnowledgeService from './services/KnowledgeService' +import KnowledgeService from './services/knowledge/KnowledgeService' import mcpService from './services/MCPService' import MemoryService from './services/memory/MemoryService' import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceService' @@ -524,7 +524,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { } }) - // knowledge base ipcMain.handle(IpcChannel.KnowledgeBase_Create, KnowledgeService.create.bind(KnowledgeService)) ipcMain.handle(IpcChannel.KnowledgeBase_Reset, KnowledgeService.reset.bind(KnowledgeService)) ipcMain.handle(IpcChannel.KnowledgeBase_Delete, KnowledgeService.delete.bind(KnowledgeService)) diff --git a/src/main/knowledge/embeddings/Embeddings.ts b/src/main/knowledge/embedjs/embeddings/Embeddings.ts similarity index 100% rename from src/main/knowledge/embeddings/Embeddings.ts rename to src/main/knowledge/embedjs/embeddings/Embeddings.ts diff --git a/src/main/knowledge/embeddings/EmbeddingsFactory.ts b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts similarity index 100% rename from src/main/knowledge/embeddings/EmbeddingsFactory.ts rename to 
src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts diff --git a/src/main/knowledge/embeddings/VoyageEmbeddings.ts b/src/main/knowledge/embedjs/embeddings/VoyageEmbeddings.ts similarity index 100% rename from src/main/knowledge/embeddings/VoyageEmbeddings.ts rename to src/main/knowledge/embedjs/embeddings/VoyageEmbeddings.ts diff --git a/src/main/knowledge/loader/draftsExportLoader.ts b/src/main/knowledge/embedjs/loader/draftsExportLoader.ts similarity index 100% rename from src/main/knowledge/loader/draftsExportLoader.ts rename to src/main/knowledge/embedjs/loader/draftsExportLoader.ts diff --git a/src/main/knowledge/loader/epubLoader.ts b/src/main/knowledge/embedjs/loader/epubLoader.ts similarity index 100% rename from src/main/knowledge/loader/epubLoader.ts rename to src/main/knowledge/embedjs/loader/epubLoader.ts diff --git a/src/main/knowledge/loader/index.ts b/src/main/knowledge/embedjs/loader/index.ts similarity index 100% rename from src/main/knowledge/loader/index.ts rename to src/main/knowledge/embedjs/loader/index.ts diff --git a/src/main/knowledge/loader/noteLoader.ts b/src/main/knowledge/embedjs/loader/noteLoader.ts similarity index 100% rename from src/main/knowledge/loader/noteLoader.ts rename to src/main/knowledge/embedjs/loader/noteLoader.ts diff --git a/src/main/knowledge/loader/odLoader.ts b/src/main/knowledge/embedjs/loader/odLoader.ts similarity index 100% rename from src/main/knowledge/loader/odLoader.ts rename to src/main/knowledge/embedjs/loader/odLoader.ts diff --git a/src/main/knowledge/langchain/embeddings/EmbeddingsFactory.ts b/src/main/knowledge/langchain/embeddings/EmbeddingsFactory.ts new file mode 100644 index 0000000000..d2879aa349 --- /dev/null +++ b/src/main/knowledge/langchain/embeddings/EmbeddingsFactory.ts @@ -0,0 +1,63 @@ +import { VoyageEmbeddings } from '@langchain/community/embeddings/voyage' +import type { Embeddings } from '@langchain/core/embeddings' +import { OllamaEmbeddings } from '@langchain/ollama' +import { 
/**
 * Selects and constructs the LangChain embeddings client that matches a
 * knowledge base's embedding provider configuration.
 */
export default class EmbeddingsFactory {
  /**
   * Creates an embeddings client for the given API client configuration.
   *
   * Selection order: Ollama provider, VoyageAI provider, Jina model names,
   * Azure OpenAI (when `apiVersion` is set), then the OpenAI client as the
   * fallback for everything else.
   *
   * @param embedApiClient Provider id, model, credentials and endpoint to embed with.
   * @param dimensions Optional output dimensionality, forwarded to the clients that accept it.
   * @returns The provider-specific `Embeddings` implementation.
   */
  static create({ embedApiClient, dimensions }: { embedApiClient: ApiClient; dimensions?: number }): Embeddings {
    // Number of texts sent per request by the clients that support batching.
    const batchSize = 10
    const { model, provider, apiKey, apiVersion, baseURL } = embedApiClient

    if (provider === SystemProviderIds.ollama) {
      // The configured URL may point at the OpenAI-compatible `/v1` path, while the
      // Ollama client expects the server root. Strip `v1/` (as before) and also a
      // trailing `/v1` that has no slash after it.
      const baseUrl = baseURL.replace('v1/', '').replace(/\/v1$/, '')

      return new OllamaEmbeddings({
        model,
        baseUrl,
        // The key must be passed through the `headers` option. Spreading the header
        // object into the top-level options (the previous behaviour) silently
        // dropped it, so authenticated Ollama endpoints rejected every request.
        ...(apiKey ? { headers: { Authorization: `Bearer ${apiKey}` } } : {})
      })
    }

    if (provider === SystemProviderIds.voyageai) {
      return new VoyageEmbeddings({
        modelName: model,
        apiKey,
        outputDimension: dimensions,
        batchSize
      })
    }

    // Jina is detected by model name rather than by provider id.
    if (isJinaEmbeddingsModel(model)) {
      return new JinaEmbeddings({
        model,
        apiKey,
        batchSize,
        dimensions,
        baseUrl: baseURL
      })
    }

    // An explicit API version marks the client as an Azure OpenAI deployment.
    if (apiVersion !== undefined) {
      return new AzureOpenAIEmbeddings({
        azureOpenAIApiKey: apiKey,
        azureOpenAIApiVersion: apiVersion,
        azureOpenAIApiDeploymentName: model,
        azureOpenAIEndpoint: baseURL,
        dimensions,
        batchSize
      })
    }

    return new OpenAIEmbeddings({
      model,
      apiKey,
      dimensions,
      batchSize,
      configuration: { baseURL }
    })
  }
}
/** Model names that are routed to the Jina embeddings client. */
const jinaModelSchema = z.union([
  z.literal('jina-clip-v2'),
  z.literal('jina-embeddings-v3'),
  z.literal('jina-colbert-v2'),
  z.literal('jina-clip-v1'),
  z.literal('jina-colbert-v1-en'),
  z.literal('jina-embeddings-v2-base-es'),
  z.literal('jina-embeddings-v2-base-code'),
  z.literal('jina-embeddings-v2-base-de'),
  z.literal('jina-embeddings-v2-base-zh'),
  z.literal('jina-embeddings-v2-base-en')
])

type JinaModel = z.infer<typeof jinaModelSchema>

/**
 * Type guard: true when `model` is one of the names listed in `jinaModelSchema`.
 */
export const isJinaEmbeddingsModel = (model: string): model is JinaModel => {
  return jinaModelSchema.safeParse(model).success
}

interface JinaEmbeddingsParams extends EmbeddingsParams {
  /** Model name to use */
  model: JinaModel

  /** Base URL of the API; `embeddings` is appended to it by the constructor. */
  baseUrl?: string

  /**
   * Timeout to use when making requests to Jina.
   */
  timeout?: number

  /**
   * The maximum number of documents to embed in a single request.
   */
  batchSize?: number

  /**
   * Whether to strip new lines from the input text.
   */
  stripNewLines?: boolean

  /**
   * The dimensions of the embedding.
   */
  dimensions?: number

  /**
   * Scales the embedding so its Euclidean (L2) norm becomes 1, preserving direction.
   * Useful when downstream involves dot-product, classification or visualization.
   */
  normalized?: boolean
}

/** A single multimodal input: either text or an image, never both. */
type JinaMultiModelInput =
  | {
      text: string
      image?: never
    }
  | {
      image: string
      text?: never
    }

type JinaEmbeddingsInput = string | JinaMultiModelInput

interface EmbeddingCreateParams {
  model: JinaEmbeddingsParams['model']

  /**
   * Inputs can be plain strings or JinaMultiModelInput objects; use the object
   * form to embed an image.
   */
  input: JinaEmbeddingsInput[]
  dimensions: number
  task?: 'retrieval.query' | 'retrieval.passage'
}

interface EmbeddingResponse {
  model: string
  object: string
  usage: {
    total_tokens: number
    prompt_tokens: number
  }
  data: {
    object: string
    index: number
    embedding: number[]
  }[]
}

interface EmbeddingErrorResponse {
  detail: string
}

/**
 * Embeddings client that posts batches of inputs to a Jina-style
 * `/embeddings` HTTP endpoint.
 */
export class JinaEmbeddings extends Embeddings implements JinaEmbeddingsParams {
  model: JinaEmbeddingsParams['model'] = 'jina-clip-v2'

  batchSize = 24

  baseUrl = 'https://api.jina.ai/v1/embeddings'

  stripNewLines = true

  dimensions = 1024

  apiKey: string

  /**
   * @param fields Optional overrides. `apiKey` falls back to the `JINA_API_KEY`
   * and then `JINA_AUTH_TOKEN` environment variables.
   * @throws Error when no API key can be resolved.
   */
  constructor(
    fields?: Partial<JinaEmbeddingsParams> & {
      apiKey?: string
    }
  ) {
    const fieldsWithDefaults = { maxConcurrency: 2, ...fields }
    super(fieldsWithDefaults)

    const apiKey =
      fieldsWithDefaults.apiKey || getEnvironmentVariable('JINA_API_KEY') || getEnvironmentVariable('JINA_AUTH_TOKEN')

    if (!apiKey) throw new Error('Jina API key not found')

    this.apiKey = apiKey
    if (fieldsWithDefaults.baseUrl) {
      // Join with exactly one slash. Plain concatenation turned
      // `https://host/v1` into `https://host/v1embeddings`.
      const trimmed = fieldsWithDefaults.baseUrl.replace(/\/+$/, '')
      this.baseUrl = trimmed.endsWith('/embeddings') ? trimmed : `${trimmed}/embeddings`
    }
    this.model = fieldsWithDefaults.model ?? this.model
    this.dimensions = fieldsWithDefaults.dimensions ?? this.dimensions
    this.batchSize = fieldsWithDefaults.batchSize ?? this.batchSize
    this.stripNewLines = fieldsWithDefaults.stripNewLines ?? this.stripNewLines
  }

  /**
   * Replaces newlines with spaces in string inputs and in the `text` of object
   * inputs when `stripNewLines` is enabled; image inputs pass through untouched.
   */
  private doStripNewLines(input: JinaEmbeddingsInput[]): JinaEmbeddingsInput[] {
    if (!this.stripNewLines) {
      return input
    }
    return input.map((item) => {
      if (typeof item === 'string') {
        return item.replace(/\n/g, ' ')
      }
      if (item.text) {
        return { text: item.text.replace(/\n/g, ' ') }
      }
      return item
    })
  }

  /**
   * Embeds the inputs in batches of `batchSize`, issuing the batch requests
   * concurrently, and returns one vector per input in input order.
   */
  async embedDocuments(input: JinaEmbeddingsInput[]): Promise<number[][]> {
    const batches = chunkArray(this.doStripNewLines(input), this.batchSize)
    const batchResponses = await Promise.all(batches.map((batch) => this.embeddingWithRetry(this.getParams(batch))))
    // Each response is validated to hold exactly one vector per input, so
    // flattening keeps vectors aligned with the original inputs.
    return batchResponses.flat()
  }

  /** Embeds a single query input using the `retrieval.query` task. */
  async embedQuery(input: JinaEmbeddingsInput): Promise<number[]> {
    const params = this.getParams(this.doStripNewLines([input]), true)
    const embeddings = await this.embeddingWithRetry(params)
    return embeddings[0]
  }

  /**
   * Builds the request body. Queries use `retrieval.query`; documents use
   * `retrieval.passage`, except for `jina-clip-v2`, where no task is sent.
   */
  private getParams(input: JinaEmbeddingsInput[], query?: boolean): EmbeddingCreateParams {
    return {
      model: this.model,
      input,
      dimensions: this.dimensions,
      task: query ? 'retrieval.query' : this.model === 'jina-clip-v2' ? undefined : 'retrieval.passage'
    }
  }

  /**
   * Posts one embedding request and returns the vectors from the response.
   * Despite the name, no retry is performed here.
   *
   * @throws Error when the body is not JSON, carries a `detail` error, has a
   * non-success status, or does not contain one vector per input.
   */
  private async embeddingWithRetry(body: EmbeddingCreateParams): Promise<number[][]> {
    const response = await fetch(this.baseUrl, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        Authorization: `Bearer ${this.apiKey}`
      },
      body: JSON.stringify(body)
    })

    let payload: Partial<EmbeddingResponse> | null
    try {
      payload = (await response.json()) as Partial<EmbeddingResponse> | null
    } catch {
      // Non-JSON body (e.g. an HTML error page from a proxy).
      throw new Error(`Jina embeddings request failed with status ${response.status}`)
    }

    const detail: unknown = (payload as Partial<EmbeddingErrorResponse> | null)?.detail
    if (detail) {
      // `detail` is not always a string; serialise it instead of printing "[object Object]".
      throw new Error(typeof detail === 'string' ? detail : JSON.stringify(detail))
    }

    const data = payload?.data
    if (!response.ok || !Array.isArray(data)) {
      throw new Error(`Jina embeddings request failed with status ${response.status}`)
    }
    if (data.length !== body.input.length) {
      throw new Error(`Jina embeddings response returned ${data.length} vectors for ${body.input.length} inputs`)
    }

    return data.map(({ embedding }) => embedding)
  }
}
/** Section being accumulated while walking the markdown-it token stream. */
type MarkdownSection = {
  heading?: string
  level?: number
  content: string
  startLine?: number
}

/**
 * Document loader that splits a Markdown file into one `Document` per
 * heading-delimited section.
 */
export class MarkdownLoader extends BaseDocumentLoader {
  private path: string
  private md: MarkdownIt

  /**
   * @param path Path of the Markdown file to load.
   */
  constructor(path: string) {
    super()
    this.path = path
    this.md = new MarkdownIt()
  }

  /**
   * Reads the file (encoding is detected by `readTextFileWithAutoEncoding`)
   * and returns its sections as documents.
   */
  public async load(): Promise<Document[]> {
    const content = await readTextFileWithAutoEncoding(this.path)
    return this.parseMarkdown(content)
  }

  /**
   * Walks the markdown-it tokens and starts a new section at every heading.
   *
   * The heading text is stored in metadata (`heading`, `level`, `startLine`)
   * and is not part of `pageContent`. Sections whose trimmed content is empty
   * are not emitted. Content before the first heading is labelled
   * 'Introduction' with level 0.
   */
  private parseMarkdown(content: string): Document[] {
    const tokens = this.md.parse(content, {})
    const documents: Document[] = []
    let currentSection: MarkdownSection = { content: '' }

    // Emits the section collected so far, if it has any non-blank content.
    const flushSection = (): void => {
      const pageContent = currentSection.content.trim()
      if (!pageContent) {
        return
      }
      documents.push(
        new Document({
          pageContent,
          metadata: {
            source: this.path,
            heading: currentSection.heading || 'Introduction',
            level: currentSection.level || 0,
            startLine: currentSection.startLine || 0
          }
        })
      )
    }

    let i = 0
    while (i < tokens.length) {
      const token = tokens[i]

      if (token.type === 'heading_open') {
        flushSection()

        currentSection = {
          // The inline token right after heading_open carries the heading text.
          heading: tokens[i + 1]?.content || '',
          // Tag is h1..h6; take the number after the leading "h".
          level: parseInt(token.tag.slice(1), 10),
          content: '',
          startLine: token.map?.[0] || 0
        }

        // Skip heading_open, inline, heading_close tokens
        i += 3
        continue
      }

      // Add token content to current section
      if (token.content) {
        currentSection.content += token.content
      }

      // Add newlines for block tokens
      if (token.block && token.type !== 'heading_close') {
        currentSection.content += '\n'
      }

      i++
    }

    // Add the last section
    flushSection()

    return documents
  }
}
/**
 * Document loader that wraps an in-memory note as LangChain documents.
 */
export class NoteLoader extends BaseDocumentLoader {
  private text: string
  private sourceUrl?: string

  /**
   * @param _text Raw note content.
   * @param _sourceUrl Optional source identifier stored in metadata; 'note' is used when absent.
   */
  constructor(
    public _text: string,
    public _sourceUrl?: string
  ) {
    super()
    this.text = _text
    this.sourceUrl = _sourceUrl
  }

  /**
   * Splits the raw text into page contents. This implementation returns the
   * raw text as a single element.
   * @param raw The raw text to be parsed.
   * @returns A promise that resolves to an array containing the raw text as its only element.
   */
  protected async parse(raw: string): Promise<string[]> {
    return [raw]
  }

  /**
   * Builds one `Document` per parsed entry. When more than one entry is
   * produced, each document's metadata also carries its 1-based `line`.
   * @throws Error when `parse` yields a non-string entry.
   */
  public async load(): Promise<Document[]> {
    const metadata = { source: this.sourceUrl || 'note' }
    const parsed = await this.parse(this.text)

    // Validate everything up front so no documents are built from bad input.
    parsed.forEach((pageContent, i) => {
      if (typeof pageContent !== 'string') {
        throw new Error(`Expected string, at position ${i} got ${typeof pageContent}`)
      }
    })

    const isSingle = parsed.length === 1
    return parsed.map(
      (pageContent, i) =>
        new Document({
          pageContent,
          metadata: isSingle ? metadata : { ...metadata, line: i + 1 }
        })
    )
  }
}
/**
 * A document loader for loading data from YouTube videos. It uses the
 * youtubei.js library to fetch the transcript and video metadata.
 * @example
 * ```typescript
 * const loader = new YoutubeLoader({
 *   videoId: "VIDEO_ID",
 *   language: "en",
 *   addVideoInfo: true,
 *   transcriptFormat: "srt" // return the transcript in SRT format
 * });
 * const docs = await loader.load();
 * console.log(docs[0].pageContent);
 * ```
 */
export class YoutubeLoader extends BaseDocumentLoader {
  private videoId: string
  private language?: string
  private addVideoInfo: boolean
  // Output format of the transcript text.
  private transcriptFormat: 'text' | 'srt'

  constructor(config: YoutubeConfig) {
    super()
    this.videoId = config.videoId
    this.language = config?.language
    this.addVideoInfo = config?.addVideoInfo ?? false
    // Defaults to 'text' to stay backward compatible.
    this.transcriptFormat = config?.transcriptFormat ?? 'text'
  }

  /**
   * Extracts the videoId from a YouTube video URL.
   *
   * Recognises `youtu.be/`, `/v/`, `/u/x/`, `/embed/`, `/shorts/`, `/live/`
   * and a `v=` query parameter in any position.
   * @param url The URL of the YouTube video.
   * @returns The 11-character videoId of the YouTube video.
   * @throws Error when no 11-character id can be found.
   */
  private static getVideoID(url: string): string {
    // The dot in "youtu.be" is escaped so it no longer matches arbitrary characters.
    const match = url.match(/.*(?:youtu\.be\/|v\/|u\/\w\/|embed\/|shorts\/|live\/|[?&]v=)([^#&?]*).*/)
    if (match !== null && match[1].length === 11) {
      return match[1]
    }
    throw new Error('Failed to get youtube video id from the url')
  }

  /**
   * Creates a new instance of the YoutubeLoader class from a YouTube video
   * URL.
   * @param url The URL of the YouTube video.
   * @param config Optional configuration options for the YoutubeLoader instance, excluding the videoId.
   * @returns A new instance of the YoutubeLoader class.
   */
  static createFromUrl(url: string, config?: Omit<YoutubeConfig, 'videoId'>): YoutubeLoader {
    const videoId = YoutubeLoader.getVideoID(url)
    return new YoutubeLoader({ ...config, videoId })
  }

  /**
   * Converts milliseconds to an SRT timestamp (HH:MM:SS,mmm).
   *
   * Non-finite or negative input is treated as 0 and fractional input is
   * rounded, so the result always matches the SRT timestamp pattern.
   * @param ms Offset in milliseconds.
   * @returns The formatted timestamp.
   */
  private static formatTimestamp(ms: number): string {
    const wholeMs = Number.isFinite(ms) ? Math.max(0, Math.round(ms)) : 0
    const totalSeconds = Math.floor(wholeMs / 1000)
    const pad = (value: number, width: number): string => value.toString().padStart(width, '0')

    const hours = pad(Math.floor(totalSeconds / 3600), 2)
    const minutes = pad(Math.floor((totalSeconds % 3600) / 60), 2)
    const seconds = pad(totalSeconds % 60, 2)
    const milliseconds = pad(wholeMs % 1000, 3)
    return `${hours}:${minutes}:${seconds},${milliseconds}`
  }

  /**
   * Loads the transcript and video metadata from the specified YouTube
   * video. It can return the transcript as plain text or in SRT format.
   * @returns An array holding a single Document with the transcript.
   * @throws Error when the video info or transcript cannot be retrieved.
   */
  async load(): Promise<Document[]> {
    const metadata: VideoMetadata = {
      source: this.videoId
    }

    try {
      const youtube = await Innertube.create({
        lang: this.language,
        retrieve_player: false
      })

      const info = await youtube.getInfo(this.videoId)
      const transcriptData = await info.getTranscript()

      if (!transcriptData.transcript.content?.body?.initial_segments) {
        throw new Error('Transcript segments not found in the response.')
      }

      const segments = transcriptData.transcript.content.body.initial_segments

      let pageContent: string

      if (this.transcriptFormat === 'srt') {
        // One SRT cue per segment: index, time range, text.
        pageContent = segments
          .map((segment, index) => {
            const srtIndex = index + 1
            const startTime = YoutubeLoader.formatTimestamp(Number(segment.start_ms))
            const endTime = YoutubeLoader.formatTimestamp(Number(segment.end_ms))
            const text = segment.snippet?.text || ''

            return `${srtIndex}\n${startTime} --> ${endTime}\n${text}`
          })
          .join('\n\n') // SRT cues are separated by a blank line
      } else {
        // Plain text: concatenate the segment texts.
        pageContent = segments.map((segment) => segment.snippet?.text || '').join(' ')
      }

      if (this.addVideoInfo) {
        const basicInfo = info.basic_info
        metadata.description = basicInfo.short_description
        metadata.title = basicInfo.title
        metadata.view_count = basicInfo.view_count
        metadata.author = basicInfo.author
      }

      const document = new Document({
        pageContent,
        metadata
      })

      return [document]
    } catch (e: unknown) {
      // Thrown values are not guaranteed to be Error instances.
      const reason = e instanceof Error ? e.message : String(e)
      throw new Error(`Failed to get YouTube video transcription: ${reason}`)
    }
  }
}
const logger = loggerService.withContext('KnowledgeService File Loader')

type LoaderInstance =
  | TextLoader
  | PDFLoader
  | PPTXLoader
  | DocxLoader
  | JSONLoader
  | EPubLoader
  | CheerioWebBaseLoader
  | YoutubeLoader
  | SitemapLoader
  | NoteLoader
  | MarkdownLoader

/**
 * Returns copies of the documents with `type` added to each document's metadata.
 *
 * The copies are real `Document` instances (plain object spreads were returned
 * before), matching how `addVideoLoader` builds its documents.
 */
function formatDocument(docs: Document[], type: string): Document[] {
  return docs.map(
    (doc) =>
      new Document({
        ...doc,
        metadata: {
          ...doc.metadata,
          type
        }
      })
  )
}

/**
 * Shared document pipeline: tag documents with the loader type, split them
 * into chunks, and add the chunks to the vector store under fresh UUIDs.
 *
 * @param base Knowledge base settings; `chunkSize` and `chunkOverlap` are passed to the splitter.
 * @param vectorStore Store that receives the chunks.
 * @param docs Documents produced by a loader.
 * @param loaderType Label written to metadata and returned in the result.
 * @param splitterType Optional splitter type forwarded to `SplitterFactory`.
 * @returns Number of chunks added and their generated ids.
 */
async function processDocuments(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  docs: Document[],
  loaderType: string,
  splitterType?: string
): Promise<LoaderReturn> {
  const splitter = SplitterFactory.create({
    chunkSize: base.chunkSize,
    chunkOverlap: base.chunkOverlap,
    ...(splitterType && { type: splitterType })
  })

  const chunks = await splitter.splitDocuments(formatDocument(docs, loaderType))
  const ids = chunks.map(() => randomUUID())

  // Nothing to index: skip the vector store call entirely.
  if (chunks.length > 0) {
    await vectorStore.addDocuments(chunks, { ids })
  }

  return {
    entriesAdded: chunks.length,
    uniqueId: ids[0] || '',
    uniqueIds: ids,
    loaderType
  }
}

/**
 * Runs a loader and feeds its documents through `processDocuments`.
 *
 * Failures are logged and reported as an empty result rather than thrown, so
 * one bad source does not abort the caller.
 *
 * @param identifier Path or URL of the source, used only in the error log.
 */
async function executeLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  loaderInstance: LoaderInstance,
  loaderType: string,
  identifier: string,
  splitterType?: string
): Promise<LoaderReturn> {
  try {
    const docs = await loaderInstance.load()
    return await processDocuments(base, vectorStore, docs, loaderType, splitterType)
  } catch (error) {
    logger.error(`Error loading or processing ${identifier} with loader ${loaderType}: ${error}`)
    return {
      entriesAdded: 0,
      uniqueId: '',
      uniqueIds: [],
      loaderType
    }
  }
}
/**
 * Maps a lower-cased file extension to a loader factory and the loader type
 * label recorded for documents from that file.
 */
const FILE_LOADER_MAP: Record<string, { create: (path: string) => LoaderInstance; type: string }> = {
  '.pdf': { create: (path) => new PDFLoader(path), type: 'pdf' },
  '.txt': { create: (path) => new TextLoader(path), type: 'text' },
  '.pptx': { create: (path) => new PPTXLoader(path), type: 'pptx' },
  '.docx': { create: (path) => new DocxLoader(path), type: 'docx' },
  // Legacy .doc files need DocxLoader's `doc` mode; its default mode parses .docx only.
  // NOTE(review): the `doc` mode relies on an optional parser dependency — confirm it is installed.
  '.doc': { create: (path) => new DocxLoader(path, { type: 'doc' }), type: 'doc' },
  '.json': { create: (path) => new JSONLoader(path), type: 'json' },
  '.epub': { create: (path) => new EPubLoader(path), type: 'epub' },
  '.md': { create: (path) => new MarkdownLoader(path), type: 'markdown' }
}

/**
 * Loads a local file into the vector store, choosing the loader by extension.
 * Unknown extensions are read as plain text and labelled with the extension
 * (without the dot), or 'unknown' when there is none.
 */
export async function addFileLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  file: FileMetadata
): Promise<LoaderReturn> {
  const fileExt = file.ext.toLowerCase()
  const loaderConfig = FILE_LOADER_MAP[fileExt]

  if (loaderConfig) {
    return executeLoader(base, vectorStore, loaderConfig.create(file.path), loaderConfig.type, file.path)
  }

  // Fall back to the plain-text loader.
  const fallbackType = fileExt.replace('.', '') || 'unknown'
  return executeLoader(base, vectorStore, new TextLoader(file.path), fallbackType, file.path)
}

/**
 * Loads a web resource into the vector store.
 *
 * 'normal' URLs are fetched with the Cheerio loader. 'youtube' URLs are loaded
 * as an SRT transcript and split with the 'srt' splitter. Any other source
 * yields an empty result.
 */
export async function addWebLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  url: string,
  source: UrlSource
): Promise<LoaderReturn> {
  let loaderInstance: CheerioWebBaseLoader | YoutubeLoader | undefined
  let splitterType: string | undefined

  switch (source) {
    case 'normal':
      loaderInstance = new CheerioWebBaseLoader(url)
      break
    case 'youtube':
      loaderInstance = YoutubeLoader.createFromUrl(url, {
        addVideoInfo: true,
        transcriptFormat: 'srt',
        language: 'zh'
      })
      splitterType = 'srt'
      break
  }

  if (!loaderInstance) {
    return {
      entriesAdded: 0,
      uniqueId: '',
      uniqueIds: [],
      loaderType: source
    }
  }

  return executeLoader(base, vectorStore, loaderInstance, source, url, splitterType)
}

/** Loads every page listed in the sitemap at `url` into the vector store. */
export async function addSitemapLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  url: string
): Promise<LoaderReturn> {
  return executeLoader(base, vectorStore, new SitemapLoader(url), 'sitemap', url)
}
/** Loads note text into the vector store, recording `sourceUrl` as its source. */
export async function addNoteLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  content: string,
  sourceUrl: string
): Promise<LoaderReturn> {
  return executeLoader(base, vectorStore, new NoteLoader(content, sourceUrl), 'note', sourceUrl)
}

/**
 * Indexes a video through its subtitle file.
 *
 * `files` must contain one text file (the subtitles) and one video file. The
 * subtitle text is split with the 'srt' splitter and every document carries
 * the video's path and original name in `metadata.video`.
 *
 * @returns The load result, or an empty result when either file is missing or
 * loading fails (the reason is logged).
 */
export async function addVideoLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  files: FileMetadata[]
): Promise<LoaderReturn> {
  const srtFile = files.find((f) => f.type === FileTypes.TEXT)
  const videoFile = files.find((f) => f.type === FileTypes.VIDEO)

  const emptyResult: LoaderReturn = {
    entriesAdded: 0,
    uniqueId: '',
    uniqueIds: [],
    loaderType: 'video'
  }

  if (!srtFile || !videoFile) {
    // Leave a trace so an incomplete pair is not skipped silently.
    logger.warn('Video loader requires both a subtitle file and a video file; nothing was added.')
    return emptyResult
  }

  try {
    const subtitleDocs = await new TextLoader(srtFile.path).load()
    const video = {
      path: videoFile.path,
      name: videoFile.origin_name
    }

    const docsWithVideoMeta = subtitleDocs.map(
      (doc) =>
        new Document({
          ...doc,
          metadata: {
            ...doc.metadata,
            video
          }
        })
    )

    return await processDocuments(base, vectorStore, docsWithVideoMeta, 'video', 'srt')
  } catch (error) {
    logger.error(`Error loading or processing file ${srtFile.path} with loader video: ${error}`)
    return emptyResult
  }
}
export class RetrieverFactory {
  /**
   * Creates a LangChain retriever from the knowledge base settings.
   *
   * - `bm25`: keyword retriever built from `documents`.
   * - `hybrid` (default): ensemble of BM25 and vector search. The configured
   *   weight is given to BM25 and the remainder to vector search. Falls back
   *   to vector search alone when no documents are available.
   * - `vector` or anything else: vector store retriever.
   *
   * @param base Knowledge base settings (`retriever.mode`, `retriever.weight`, `documentCount`).
   * @param vectorStore An initialised vector store.
   * @param documents Documents used to build the BM25 retriever.
   * @returns The configured retriever.
   * @throws Error when `bm25` mode is requested without any documents.
   */
  public createRetriever(base: KnowledgeBaseParams, vectorStore: FaissStore, documents: Document[]): BaseRetriever {
    const retrieverType = base.retriever?.mode ?? 'hybrid'
    const configuredWeight = base.retriever?.weight ?? 0.5
    // Keep the weight inside [0, 1] so neither ensemble weight can go negative.
    const retrieverWeight = Number.isFinite(configuredWeight) ? Math.min(1, Math.max(0, configuredWeight)) : 0.5
    const searchK = base.documentCount ?? 5

    logger.info(`Creating retriever of type: ${retrieverType} with k=${searchK}`)

    switch (retrieverType) {
      case 'bm25':
        if (documents.length === 0) {
          throw new Error('BM25Retriever requires documents, but none were provided or found.')
        }
        logger.info('Create BM25 Retriever')
        return BM25Retriever.fromDocuments(documents, { k: searchK })

      case 'hybrid': {
        const vectorstoreRetriever = vectorStore.asRetriever(searchK)

        if (documents.length === 0) {
          logger.warn('No documents provided for BM25 part of hybrid search. Falling back to vector search only.')
          return vectorstoreRetriever
        }

        const bm25Retriever = BM25Retriever.fromDocuments(documents, { k: searchK })

        logger.info('Create Hybrid Retriever')
        return new EnsembleRetriever({
          retrievers: [bm25Retriever, vectorstoreRetriever],
          weights: [retrieverWeight, 1 - retrieverWeight]
        })
      }

      case 'vector':
      default:
        logger.info('Create Vector Retriever')
        return vectorStore.asRetriever(searchK)
    }
  }
}
seconds +} + +// 辅助函数:将 SRT 时间戳字符串 (HH:MM:SS,ms) 转换为秒 +function srtTimeToSeconds(time: string): number { + const parts = time.split(':') + const secondsAndMs = parts[2].split(',') + const hours = parseInt(parts[0], 10) + const minutes = parseInt(parts[1], 10) + const seconds = parseInt(secondsAndMs[0], 10) + const milliseconds = parseInt(secondsAndMs[1], 10) + + return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000 +} + +export class SrtSplitter extends TextSplitter { + constructor(fields?: Partial) { + // 传入 chunkSize 和 chunkOverlap + super(fields) + } + splitText(): Promise { + throw new Error('Method not implemented.') + } + + // 核心方法:重写 splitDocuments 来实现自定义逻辑 + async splitDocuments(documents: Document[]): Promise { + const allChunks: Document[] = [] + + for (const doc of documents) { + // 1. 解析 SRT 内容 + const segments = this.parseSrt(doc.pageContent) + if (segments.length === 0) continue + + // 2. 将字幕片段组合成块 + const chunks = this.mergeSegmentsIntoChunks(segments, doc.metadata) + allChunks.push(...chunks) + } + + return allChunks + } + + // 辅助方法:解析整个 SRT 字符串 + private parseSrt(srt: string): SrtSegment[] { + const segments: SrtSegment[] = [] + const blocks = srt.trim().split(/\n\n/) + + for (const block of blocks) { + const lines = block.split('\n') + if (lines.length < 3) continue + + const timeMatch = lines[1].match(/(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})/) + if (!timeMatch) continue + + const startTime = srtTimeToSeconds(timeMatch[1]) + const endTime = srtTimeToSeconds(timeMatch[2]) + const text = lines.slice(2).join(' ').trim() + + segments.push({ text, startTime, endTime }) + } + + return segments + } + + // 辅助方法:将解析后的片段合并成每 5 段一个块 + private mergeSegmentsIntoChunks(segments: SrtSegment[], baseMetadata: Record): Document[] { + const chunks: Document[] = [] + let currentChunkText = '' + let currentChunkStartTime = 0 + let currentChunkEndTime = 0 + let segmentCount = 0 + + for (const segment of segments) { + if (segmentCount === 0) 
{ + currentChunkStartTime = segment.startTime + } + + currentChunkText += (currentChunkText ? ' ' : '') + segment.text + currentChunkEndTime = segment.endTime + segmentCount++ + + // 当累积到 5 段时,创建一个新的 Document + if (segmentCount === 5) { + const metadata: Record = { + ...baseMetadata, + startTime: currentChunkStartTime, + endTime: currentChunkEndTime + } + if (baseMetadata.source_url) { + metadata.source_url_with_timestamp = `${baseMetadata.source_url}?t=${Math.floor(currentChunkStartTime)}s` + } + chunks.push( + new Document({ + pageContent: currentChunkText, + metadata + }) + ) + + // 重置计数器和临时变量 + currentChunkText = '' + currentChunkStartTime = 0 + currentChunkEndTime = 0 + segmentCount = 0 + } + } + + // 如果还有剩余的片段,创建最后一个 Document + if (segmentCount > 0) { + const metadata: Record = { + ...baseMetadata, + startTime: currentChunkStartTime, + endTime: currentChunkEndTime + } + if (baseMetadata.source_url) { + metadata.source_url_with_timestamp = `${baseMetadata.source_url}?t=${Math.floor(currentChunkStartTime)}s` + } + chunks.push( + new Document({ + pageContent: currentChunkText, + metadata + }) + ) + } + + return chunks + } +} diff --git a/src/main/knowledge/langchain/splitter/index.ts b/src/main/knowledge/langchain/splitter/index.ts new file mode 100644 index 0000000000..62ca1c9e90 --- /dev/null +++ b/src/main/knowledge/langchain/splitter/index.ts @@ -0,0 +1,31 @@ +import { RecursiveCharacterTextSplitter, TextSplitter } from '@langchain/textsplitters' + +import { SrtSplitter } from './SrtSplitter' + +export type SplitterConfig = { + chunkSize?: number + chunkOverlap?: number + type?: 'recursive' | 'srt' | string +} +export class SplitterFactory { + /** + * Creates a TextSplitter instance based on the provided configuration. + * @param config - The configuration object specifying the splitter type and its parameters. + * @returns An instance of a TextSplitter, or null if no splitting is required. 
+ */ + public static create(config: SplitterConfig): TextSplitter { + switch (config.type) { + case 'srt': + return new SrtSplitter({ + chunkSize: config.chunkSize, + chunkOverlap: config.chunkOverlap + }) + case 'recursive': + default: + return new RecursiveCharacterTextSplitter({ + chunkSize: config.chunkSize, + chunkOverlap: config.chunkOverlap + }) + } + } +} diff --git a/src/main/knowledge/preprocess/PreprocessingService.ts b/src/main/knowledge/preprocess/PreprocessingService.ts new file mode 100644 index 0000000000..1c10529be9 --- /dev/null +++ b/src/main/knowledge/preprocess/PreprocessingService.ts @@ -0,0 +1,63 @@ +import PreprocessProvider from '@main/knowledge/preprocess/PreprocessProvider' +import { loggerService } from '@main/services/LoggerService' +import { windowService } from '@main/services/WindowService' +import type { FileMetadata, KnowledgeBaseParams, KnowledgeItem } from '@types' + +const logger = loggerService.withContext('PreprocessingService') + +class PreprocessingService { + public async preprocessFile( + file: FileMetadata, + base: KnowledgeBaseParams, + item: KnowledgeItem, + userId: string + ): Promise { + let fileToProcess: FileMetadata = file + // Check if preprocessing is configured and applicable (e.g., for PDFs) + if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') { + try { + const provider = new PreprocessProvider(base.preprocessProvider.provider, userId) + + // Check if file has already been preprocessed + const alreadyProcessed = await provider.checkIfAlreadyProcessed(file) + if (alreadyProcessed) { + logger.debug(`File already preprocessed, using cached result: ${file.path}`) + return alreadyProcessed + } + + // Execute preprocessing + logger.debug(`Starting preprocess for scanned PDF: ${file.path}`) + const { processedFile, quota } = await provider.parseFile(item.id, file) + fileToProcess = processedFile + + // Notify the UI + const mainWindow = windowService.getMainWindow() + 
mainWindow?.webContents.send('file-preprocess-finished', { + itemId: item.id, + quota: quota + }) + } catch (err) { + logger.error(`Preprocessing failed: ${err}`) + // If preprocessing fails, re-throw the error to be handled by the caller + throw new Error(`Preprocessing failed: ${err}`) + } + } + + return fileToProcess + } + + public async checkQuota(base: KnowledgeBaseParams, userId: string): Promise { + try { + if (base.preprocessProvider && base.preprocessProvider.type === 'preprocess') { + const provider = new PreprocessProvider(base.preprocessProvider.provider, userId) + return await provider.checkQuota() + } + throw new Error('No preprocess provider configured') + } catch (err) { + logger.error(`Failed to check quota: ${err}`) + throw new Error(`Failed to check quota: ${err}`) + } + } +} + +export const preprocessingService = new PreprocessingService() diff --git a/src/main/knowledge/reranker/BaseReranker.ts b/src/main/knowledge/reranker/BaseReranker.ts index c3ac979d25..9483cb3d4e 100644 --- a/src/main/knowledge/reranker/BaseReranker.ts +++ b/src/main/knowledge/reranker/BaseReranker.ts @@ -1,101 +1,46 @@ -import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' -import { KnowledgeBaseParams } from '@types' +import { DEFAULT_DOCUMENT_COUNT, DEFAULT_RELEVANT_SCORE } from '@main/utils/knowledge' +import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types' + +import { MultiModalDocument, RerankStrategy } from './strategies/RerankStrategy' +import { StrategyFactory } from './strategies/StrategyFactory' export default abstract class BaseReranker { protected base: KnowledgeBaseParams + protected strategy: RerankStrategy constructor(base: KnowledgeBaseParams) { if (!base.rerankApiClient) { throw new Error('Rerank model is required') } this.base = base + this.strategy = StrategyFactory.createStrategy(base.rerankApiClient.provider) } - - abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise - - /** - * Get Rerank Request 
Url - */ - protected getRerankUrl() { - if (this.base.rerankApiClient?.provider === 'bailian') { - return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank' - } - - let baseURL = this.base.rerankApiClient?.baseURL - - if (baseURL && baseURL.endsWith('/')) { - // `/` 结尾强制使用rerankBaseURL - return `${baseURL}rerank` - } - - if (baseURL && !baseURL.endsWith('/v1')) { - baseURL = `${baseURL}/v1` - } - - return `${baseURL}/rerank` + abstract rerank(query: string, searchResults: KnowledgeSearchResult[]): Promise + protected getRerankUrl(): string { + return this.strategy.buildUrl(this.base.rerankApiClient?.baseURL) } - - /** - * Get Rerank Request Body - */ - protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) { - const provider = this.base.rerankApiClient?.provider - const documents = searchResults.map((doc) => doc.pageContent) - const topN = this.base.documentCount - - if (provider === 'voyageai') { - return { - model: this.base.rerankApiClient?.model, - query, - documents, - top_k: topN - } - } else if (provider === 'bailian') { - return { - model: this.base.rerankApiClient?.model, - input: { - query, - documents - }, - parameters: { - top_n: topN - } - } - } else if (provider?.includes('tei')) { - return { - query, - texts: documents, - return_text: true - } - } else { - return { - model: this.base.rerankApiClient?.model, - query, - documents, - top_n: topN - } - } + protected getRerankRequestBody(query: string, searchResults: KnowledgeSearchResult[]) { + const documents = this.buildDocuments(searchResults) + const topN = this.base.documentCount ?? 
DEFAULT_DOCUMENT_COUNT + const model = this.base.rerankApiClient?.model + return this.strategy.buildRequestBody(query, documents, topN, model) } + private buildDocuments(searchResults: KnowledgeSearchResult[]): MultiModalDocument[] { + return searchResults.map((doc) => { + const document: MultiModalDocument = {} - /** - * Extract Rerank Result - */ + // 检查是否是图片类型,添加图片内容 + if (doc.metadata?.type === 'image') { + document.image = doc.pageContent + } else { + document.text = doc.pageContent + } + + return document + }) + } protected extractRerankResult(data: any) { - const provider = this.base.rerankApiClient?.provider - if (provider === 'bailian') { - return data.output.results - } else if (provider === 'voyageai') { - return data.data - } else if (provider?.includes('tei')) { - return data.map((item: any) => { - return { - index: item.index, - relevance_score: item.score - } - }) - } else { - return data.results - } + return this.strategy.extractResults(data) } /** @@ -105,35 +50,30 @@ export default abstract class BaseReranker { * @protected */ protected getRerankResult( - searchResults: ExtractChunkData[], - rerankResults: Array<{ - index: number - relevance_score: number - }> + searchResults: KnowledgeSearchResult[], + rerankResults: Array<{ index: number; relevance_score: number }> ) { - const resultMap = new Map(rerankResults.map((result) => [result.index, result.relevance_score || 0])) + const resultMap = new Map( + rerankResults.map((result) => [result.index, result.relevance_score || DEFAULT_RELEVANT_SCORE]) + ) - return searchResults - .map((doc: ExtractChunkData, index: number) => { + const returenResults = searchResults + .map((doc: KnowledgeSearchResult, index: number) => { const score = resultMap.get(index) if (score === undefined) return undefined - - return { - ...doc, - score - } + return { ...doc, score } }) - .filter((doc): doc is ExtractChunkData => doc !== undefined) + .filter((doc): doc is KnowledgeSearchResult => doc !== undefined) .sort((a, b) 
=> b.score - a.score) - } + return returenResults + } public defaultHeaders() { return { Authorization: `Bearer ${this.base.rerankApiClient?.apiKey}`, 'Content-Type': 'application/json' } } - protected formatErrorMessage(url: string, error: any, requestBody: any) { const errorDetails = { url: url, diff --git a/src/main/knowledge/reranker/GeneralReranker.ts b/src/main/knowledge/reranker/GeneralReranker.ts index 5a0e240a9d..e4b3503606 100644 --- a/src/main/knowledge/reranker/GeneralReranker.ts +++ b/src/main/knowledge/reranker/GeneralReranker.ts @@ -1,19 +1,14 @@ -import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' -import { KnowledgeBaseParams } from '@types' +import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types' import { net } from 'electron' import BaseReranker from './BaseReranker' - export default class GeneralReranker extends BaseReranker { constructor(base: KnowledgeBaseParams) { super(base) } - - public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise => { + public rerank = async (query: string, searchResults: KnowledgeSearchResult[]): Promise => { const url = this.getRerankUrl() - const requestBody = this.getRerankRequestBody(query, searchResults) - try { const response = await net.fetch(url, { method: 'POST', diff --git a/src/main/knowledge/reranker/Reranker.ts b/src/main/knowledge/reranker/Reranker.ts index d42376ea20..59de4b0470 100644 --- a/src/main/knowledge/reranker/Reranker.ts +++ b/src/main/knowledge/reranker/Reranker.ts @@ -1,5 +1,4 @@ -import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' -import { KnowledgeBaseParams } from '@types' +import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types' import GeneralReranker from './GeneralReranker' @@ -8,7 +7,7 @@ export default class Reranker { constructor(base: KnowledgeBaseParams) { this.sdk = new GeneralReranker(base) } - public async rerank(query: string, searchResults: ExtractChunkData[]): Promise { + public async 
rerank(query: string, searchResults: KnowledgeSearchResult[]): Promise { return this.sdk.rerank(query, searchResults) } } diff --git a/src/main/knowledge/reranker/strategies/BailianStrategy.ts b/src/main/knowledge/reranker/strategies/BailianStrategy.ts new file mode 100644 index 0000000000..e5932b9f40 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/BailianStrategy.ts @@ -0,0 +1,18 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class BailianStrategy implements RerankStrategy { + buildUrl(): string { + return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank' + } + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) { + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) + + return { + model, + input: { query, documents: textDocuments }, + parameters: { top_n: topN } + } + } + extractResults(data: any) { + return data.output.results + } +} diff --git a/src/main/knowledge/reranker/strategies/DefaultStrategy.ts b/src/main/knowledge/reranker/strategies/DefaultStrategy.ts new file mode 100644 index 0000000000..59ee3fb47b --- /dev/null +++ b/src/main/knowledge/reranker/strategies/DefaultStrategy.ts @@ -0,0 +1,25 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class DefaultStrategy implements RerankStrategy { + buildUrl(baseURL?: string): string { + if (baseURL && baseURL.endsWith('/')) { + return `${baseURL}rerank` + } + if (baseURL && !baseURL.endsWith('/v1')) { + baseURL = `${baseURL}/v1` + } + return `${baseURL}/rerank` + } + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) { + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) 
+ + return { + model, + query, + documents: textDocuments, + top_n: topN + } + } + extractResults(data: any) { + return data.results + } +} diff --git a/src/main/knowledge/reranker/strategies/JinaStrategy.ts b/src/main/knowledge/reranker/strategies/JinaStrategy.ts new file mode 100644 index 0000000000..200190f544 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/JinaStrategy.ts @@ -0,0 +1,33 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class JinaStrategy implements RerankStrategy { + buildUrl(baseURL?: string): string { + if (baseURL && baseURL.endsWith('/')) { + return `${baseURL}rerank` + } + if (baseURL && !baseURL.endsWith('/v1')) { + baseURL = `${baseURL}/v1` + } + return `${baseURL}/rerank` + } + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) { + if (model === 'jina-reranker-m0') { + return { + model, + query, + documents, + top_n: topN + } + } + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) 
+ + return { + model, + query, + documents: textDocuments, + top_n: topN + } + } + extractResults(data: any) { + return data.results + } +} diff --git a/src/main/knowledge/reranker/strategies/RerankStrategy.ts b/src/main/knowledge/reranker/strategies/RerankStrategy.ts new file mode 100644 index 0000000000..a23f630b7d --- /dev/null +++ b/src/main/knowledge/reranker/strategies/RerankStrategy.ts @@ -0,0 +1,9 @@ +export interface MultiModalDocument { + text?: string + image?: string +} +export interface RerankStrategy { + buildUrl(baseURL?: string): string + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string): any + extractResults(data: any): Array<{ index: number; relevance_score: number }> +} diff --git a/src/main/knowledge/reranker/strategies/StrategyFactory.ts b/src/main/knowledge/reranker/strategies/StrategyFactory.ts new file mode 100644 index 0000000000..9e04547e31 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/StrategyFactory.ts @@ -0,0 +1,25 @@ +import { BailianStrategy } from './BailianStrategy' +import { DefaultStrategy } from './DefaultStrategy' +import { JinaStrategy } from './JinaStrategy' +import { RerankStrategy } from './RerankStrategy' +import { TEIStrategy } from './TeiStrategy' +import { isTEIProvider, RERANKER_PROVIDERS } from './types' +import { VoyageAIStrategy } from './VoyageStrategy' + +export class StrategyFactory { + static createStrategy(provider?: string): RerankStrategy { + switch (provider) { + case RERANKER_PROVIDERS.VOYAGEAI: + return new VoyageAIStrategy() + case RERANKER_PROVIDERS.BAILIAN: + return new BailianStrategy() + case RERANKER_PROVIDERS.JINA: + return new JinaStrategy() + default: + if (isTEIProvider(provider)) { + return new TEIStrategy() + } + return new DefaultStrategy() + } + } +} diff --git a/src/main/knowledge/reranker/strategies/TeiStrategy.ts b/src/main/knowledge/reranker/strategies/TeiStrategy.ts new file mode 100644 index 0000000000..58f24661ba --- /dev/null 
+++ b/src/main/knowledge/reranker/strategies/TeiStrategy.ts @@ -0,0 +1,26 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class TEIStrategy implements RerankStrategy { + buildUrl(baseURL?: string): string { + if (baseURL && baseURL.endsWith('/')) { + return `${baseURL}rerank` + } + if (baseURL && !baseURL.endsWith('/v1')) { + baseURL = `${baseURL}/v1` + } + return `${baseURL}/rerank` + } + buildRequestBody(query: string, documents: MultiModalDocument[]) { + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) + return { + query, + texts: textDocuments, + return_text: true + } + } + extractResults(data: any) { + return data.map((item: any) => ({ + index: item.index, + relevance_score: item.score + })) + } +} diff --git a/src/main/knowledge/reranker/strategies/VoyageStrategy.ts b/src/main/knowledge/reranker/strategies/VoyageStrategy.ts new file mode 100644 index 0000000000..e81319f024 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/VoyageStrategy.ts @@ -0,0 +1,24 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class VoyageAIStrategy implements RerankStrategy { + buildUrl(baseURL?: string): string { + if (baseURL && baseURL.endsWith('/')) { + return `${baseURL}rerank` + } + if (baseURL && !baseURL.endsWith('/v1')) { + baseURL = `${baseURL}/v1` + } + return `${baseURL}/rerank` + } + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) { + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) 
+ return { + model, + query, + documents: textDocuments, + top_k: topN + } + } + extractResults(data: any) { + return data.data + } +} diff --git a/src/main/knowledge/reranker/strategies/types.ts b/src/main/knowledge/reranker/strategies/types.ts new file mode 100644 index 0000000000..91bfdef8f6 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/types.ts @@ -0,0 +1,19 @@ +import { objectValues } from '@types' + +export const RERANKER_PROVIDERS = { + VOYAGEAI: 'voyageai', + BAILIAN: 'bailian', + JINA: 'jina', + TEI: 'tei' +} as const + +export type RerankProvider = (typeof RERANKER_PROVIDERS)[keyof typeof RERANKER_PROVIDERS] + +export function isTEIProvider(provider?: string): boolean { + return provider?.includes(RERANKER_PROVIDERS.TEI) ?? false +} + +export function isKnownProvider(provider?: string): provider is RerankProvider { + if (!provider) return false + return objectValues(RERANKER_PROVIDERS).some((p) => p === provider) +} diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 538a0ede79..775ca12db5 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -56,6 +56,45 @@ type CallToolArgs = { server: MCPServer; name: string; args: any; callId?: strin const logger = loggerService.withContext('MCPService') +// Redact potentially sensitive fields in objects (headers, tokens, api keys) +function redactSensitive(input: any): any { + const SENSITIVE_KEYS = ['authorization', 'Authorization', 'apiKey', 'api_key', 'apikey', 'token', 'access_token'] + const MAX_STRING = 300 + + const redact = (val: any): any => { + if (val == null) return val + if (typeof val === 'string') { + return val.length > MAX_STRING ? 
`${val.slice(0, MAX_STRING)}…<${val.length - MAX_STRING} more>` : val + } + if (Array.isArray(val)) return val.map((v) => redact(v)) + if (typeof val === 'object') { + const out: Record = {} + for (const [k, v] of Object.entries(val)) { + if (SENSITIVE_KEYS.includes(k)) { + out[k] = '' + } else { + out[k] = redact(v) + } + } + return out + } + return val + } + + return redact(input) +} + +// Create a context-aware logger for a server +function getServerLogger(server: MCPServer, extra?: Record) { + const base = { + serverName: server?.name, + serverId: server?.id, + baseUrl: server?.baseUrl, + type: server?.type || (server?.command ? 'stdio' : server?.baseUrl ? 'http' : 'inmemory') + } + return loggerService.withContext('MCPService', { ...base, ...(extra || {}) }) +} + /** * Higher-order function to add caching capability to any async function * @param fn The original function to be wrapped with caching @@ -74,15 +113,17 @@ function withCache( const cacheKey = getCacheKey(...args) if (CacheService.has(cacheKey)) { - logger.debug(`${logPrefix} loaded from cache`) + logger.debug(`${logPrefix} loaded from cache`, { cacheKey }) const cachedData = CacheService.get(cacheKey) if (cachedData) { return cachedData } } + const start = Date.now() const result = await fn(...args) CacheService.set(cacheKey, result, ttl) + logger.debug(`${logPrefix} cached`, { cacheKey, ttlMs: ttl, durationMs: Date.now() - start }) return result } } @@ -128,6 +169,7 @@ class McpService { // If there's a pending initialization, wait for it const pendingClient = this.pendingClients.get(serverKey) if (pendingClient) { + getServerLogger(server).silly(`Waiting for pending client initialization`) return pendingClient } @@ -136,8 +178,11 @@ class McpService { if (existingClient) { try { // Check if the existing client is still connected - const pingResult = await existingClient.ping() - logger.debug(`Ping result for ${server.name}:`, pingResult) + const pingResult = await existingClient.ping({ + // add 
short timeout to prevent hanging + timeout: 1000 + }) + getServerLogger(server).debug(`Ping result`, { ok: !!pingResult }) // If the ping fails, remove the client from the cache // and create a new one if (!pingResult) { @@ -146,7 +191,7 @@ class McpService { return existingClient } } catch (error: any) { - logger.error(`Error pinging server ${server.name}:`, error?.message) + getServerLogger(server).error(`Error pinging server`, error as Error) this.clients.delete(serverKey) } } @@ -172,15 +217,15 @@ class McpService { > => { // Create appropriate transport based on configuration if (isBuiltinMCPServer(server) && server.name !== BuiltinMCPServerNames.mcpAutoInstall) { - logger.debug(`Using in-memory transport for server: ${server.name}`) + getServerLogger(server).debug(`Using in-memory transport`) const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() // start the in-memory server with the given name and environment variables const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {}) try { await inMemoryServer.connect(serverTransport) - logger.debug(`In-memory server started: ${server.name}`) + getServerLogger(server).debug(`In-memory server started`) } catch (error: Error | any) { - logger.error(`Error starting in-memory server: ${error}`) + getServerLogger(server).error(`Error starting in-memory server`, error as Error) throw new Error(`Failed to start in-memory server: ${error.message}`) } // set the client transport to the client @@ -193,7 +238,10 @@ class McpService { }, authProvider } - logger.debug(`StreamableHTTPClientTransport options:`, options) + // redact headers before logging + getServerLogger(server).debug(`StreamableHTTPClientTransport options`, { + options: redactSensitive(options) + }) return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) } else if (server.type === 'sse') { const options: SSEClientTransportOptions = { @@ -209,7 +257,7 @@ class McpService { 
headers['Authorization'] = `Bearer ${tokens.access_token}` } } catch (error) { - logger.error('Failed to fetch tokens:', error as Error) + getServerLogger(server).error('Failed to fetch tokens:', error as Error) } } @@ -239,15 +287,18 @@ class McpService { ...server.env, ...resolvedConfig.env } - logger.debug(`Using resolved DXT config - command: ${cmd}, args: ${args?.join(' ')}`) + getServerLogger(server).debug(`Using resolved DXT config`, { + command: cmd, + args + }) } else { - logger.warn(`Failed to resolve DXT config for ${server.name}, falling back to manifest values`) + getServerLogger(server).warn(`Failed to resolve DXT config, falling back to manifest values`) } } if (server.command === 'npx') { cmd = await getBinaryPath('bun') - logger.debug(`Using command: ${cmd}`) + getServerLogger(server).debug(`Using command`, { command: cmd }) // add -x to args if args exist if (args && args.length > 0) { @@ -282,7 +333,7 @@ class McpService { } } - logger.debug(`Starting server with command: ${cmd} ${args ? 
args.join(' ') : ''}`) + getServerLogger(server).debug(`Starting server`, { command: cmd, args }) // Logger.info(`[MCP] Environment variables for server:`, server.env) const loginShellEnv = await this.getLoginShellEnv() @@ -304,12 +355,14 @@ class McpService { // For DXT servers, set the working directory to the extracted path if (server.dxtPath) { transportOptions.cwd = server.dxtPath - logger.debug(`Setting working directory for DXT server: ${server.dxtPath}`) + getServerLogger(server).debug(`Setting working directory for DXT server`, { + cwd: server.dxtPath + }) } const stdioTransport = new StdioClientTransport(transportOptions) stdioTransport.stderr?.on('data', (data) => - logger.debug(`Stdio stderr for server: ${server.name}` + data.toString()) + getServerLogger(server).debug(`Stdio stderr`, { data: data.toString() }) ) return stdioTransport } else { @@ -318,7 +371,7 @@ class McpService { } const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { - logger.debug(`Starting OAuth flow for server: ${server.name}`) + getServerLogger(server).debug(`Starting OAuth flow`) // Create an event emitter for the OAuth callback const events = new EventEmitter() @@ -331,27 +384,27 @@ class McpService { // Set a timeout to close the callback server const timeoutId = setTimeout(() => { - logger.warn(`OAuth flow timed out for server: ${server.name}`) + getServerLogger(server).warn(`OAuth flow timed out`) callbackServer.close() }, 300000) // 5 minutes timeout try { // Wait for the authorization code const authCode = await callbackServer.waitForAuthCode() - logger.debug(`Received auth code: ${authCode}`) + getServerLogger(server).debug(`Received auth code`) // Complete the OAuth flow await transport.finishAuth(authCode) - logger.debug(`OAuth flow completed for server: ${server.name}`) + getServerLogger(server).debug(`OAuth flow completed`) const newTransport = await initTransport() // Try to connect again await 
client.connect(newTransport) - logger.debug(`Successfully authenticated with server: ${server.name}`) + getServerLogger(server).debug(`Successfully authenticated`) } catch (oauthError) { - logger.error(`OAuth authentication failed for server ${server.name}:`, oauthError as Error) + getServerLogger(server).error(`OAuth authentication failed`, oauthError as Error) throw new Error( `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` ) @@ -390,7 +443,7 @@ class McpService { logger.debug(`Activated server: ${server.name}`) return client } catch (error: any) { - logger.error(`Error activating server ${server.name}:`, error?.message) + getServerLogger(server).error(`Error activating server`, error as Error) throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) } } finally { @@ -450,9 +503,9 @@ class McpService { logger.debug(`Message from server ${server.name}:`, notification.params) }) - logger.debug(`Set up notification handlers for server: ${server.name}`) + getServerLogger(server).debug(`Set up notification handlers`) } catch (error) { - logger.error(`Failed to set up notification handlers for server ${server.name}:`, error as Error) + getServerLogger(server).error(`Failed to set up notification handlers`, error as Error) } } @@ -470,7 +523,7 @@ class McpService { CacheService.remove(`mcp:list_tool:${serverKey}`) CacheService.remove(`mcp:list_prompts:${serverKey}`) CacheService.remove(`mcp:list_resources:${serverKey}`) - logger.debug(`Cleared all caches for server: ${serverKey}`) + logger.debug(`Cleared all caches for server`, { serverKey }) } async closeClient(serverKey: string) { @@ -478,18 +531,18 @@ class McpService { if (client) { // Remove the client from the cache await client.close() - logger.debug(`Closed server: ${serverKey}`) + logger.debug(`Closed server`, { serverKey }) this.clients.delete(serverKey) // Clear all caches for this server this.clearServerCache(serverKey) } else { 
- logger.warn(`No client found for server: ${serverKey}`) + logger.warn(`No client found for server`, { serverKey }) } } async stopServer(_: Electron.IpcMainInvokeEvent, server: MCPServer) { const serverKey = this.getServerKey(server) - logger.debug(`Stopping server: ${server.name}`) + getServerLogger(server).debug(`Stopping server`) await this.closeClient(serverKey) } @@ -505,16 +558,16 @@ class McpService { try { const cleaned = this.dxtService.cleanupDxtServer(server.name) if (cleaned) { - logger.debug(`Cleaned up DXT server directory for: ${server.name}`) + getServerLogger(server).debug(`Cleaned up DXT server directory`) } } catch (error) { - logger.error(`Failed to cleanup DXT server: ${server.name}`, error as Error) + getServerLogger(server).error(`Failed to cleanup DXT server`, error as Error) } } } async restartServer(_: Electron.IpcMainInvokeEvent, server: MCPServer) { - logger.debug(`Restarting server: ${server.name}`) + getServerLogger(server).debug(`Restarting server`) const serverKey = this.getServerKey(server) await this.closeClient(serverKey) // Clear cache before restarting to ensure fresh data @@ -527,7 +580,7 @@ class McpService { try { await this.closeClient(key) } catch (error: any) { - logger.error(`Failed to close client: ${error?.message}`) + logger.error(`Failed to close client`, error as Error) } } } @@ -536,9 +589,9 @@ class McpService { * Check connectivity for an MCP server */ public async checkMcpConnectivity(_: Electron.IpcMainInvokeEvent, server: MCPServer): Promise { - logger.debug(`Checking connectivity for server: ${server.name}`) + getServerLogger(server).debug(`Checking connectivity`) try { - logger.debug(`About to call initClient for server: ${server.name}`, { hasInitClient: !!this.initClient }) + getServerLogger(server).debug(`About to call initClient`, { hasInitClient: !!this.initClient }) if (!this.initClient) { throw new Error('initClient method is not available') @@ -547,10 +600,10 @@ class McpService { const client = await 
this.initClient(server) // Attempt to list tools as a way to check connectivity await client.listTools() - logger.debug(`Connectivity check successful for server: ${server.name}`) + getServerLogger(server).debug(`Connectivity check successful`) return true } catch (error) { - logger.error(`Connectivity check failed for server: ${server.name}`, error as Error) + getServerLogger(server).error(`Connectivity check failed`, error as Error) // Close the client if connectivity check fails to ensure a clean state for the next attempt const serverKey = this.getServerKey(server) await this.closeClient(serverKey) @@ -559,9 +612,8 @@ class McpService { } private async listToolsImpl(server: MCPServer): Promise { - logger.debug(`Listing tools for server: ${server.name}`) + getServerLogger(server).debug(`Listing tools`) const client = await this.initClient(server) - logger.debug(`Client for server: ${server.name}`, client) try { const { tools } = await client.listTools() const serverTools: MCPTool[] = [] @@ -577,7 +629,7 @@ class McpService { }) return serverTools } catch (error: any) { - logger.error(`Failed to list tools for server: ${server.name}`, error?.message) + getServerLogger(server).error(`Failed to list tools`, error as Error) return [] } } @@ -614,12 +666,16 @@ class McpService { const callToolFunc = async ({ server, name, args }: CallToolArgs) => { try { - logger.debug(`Calling: ${server.name} ${name} ${JSON.stringify(args)} callId: ${toolCallId}`, server) + getServerLogger(server, { tool: name, callId: toolCallId }).debug(`Calling tool`, { + args: redactSensitive(args) + }) if (typeof args === 'string') { try { args = JSON.parse(args) } catch (e) { - logger.error('args parse error', args) + getServerLogger(server, { tool: name, callId: toolCallId }).error('args parse error', e as Error, { + args + }) } if (args === '') { args = {} @@ -628,8 +684,9 @@ class McpService { const client = await this.initClient(server) const result = await client.callTool({ name, 
arguments: args }, undefined, { onprogress: (process) => { - logger.debug(`Progress: ${process.progress / (process.total || 1)}`) - logger.debug(`Progress notification received for server: ${server.name}`, process) + getServerLogger(server, { tool: name, callId: toolCallId }).debug(`Progress`, { + ratio: process.progress / (process.total || 1) + }) const mainWindow = windowService.getMainWindow() if (mainWindow) { mainWindow.webContents.send('mcp-progress', process.progress / (process.total || 1)) @@ -644,7 +701,7 @@ class McpService { }) return result as MCPCallToolResponse } catch (error) { - logger.error(`Error calling tool ${name} on ${server.name}:`, error as Error) + getServerLogger(server, { tool: name, callId: toolCallId }).error(`Error calling tool`, error as Error) throw error } finally { this.activeToolCalls.delete(toolCallId) @@ -668,7 +725,7 @@ class McpService { */ private async listPromptsImpl(server: MCPServer): Promise { const client = await this.initClient(server) - logger.debug(`Listing prompts for server: ${server.name}`) + getServerLogger(server).debug(`Listing prompts`) try { const { prompts } = await client.listPrompts() return prompts.map((prompt: any) => ({ @@ -680,7 +737,7 @@ class McpService { } catch (error: any) { // -32601 is the code for the method not found if (error?.code !== -32601) { - logger.error(`Failed to list prompts for server: ${server.name}`, error?.message) + getServerLogger(server).error(`Failed to list prompts`, error as Error) } return [] } @@ -749,7 +806,7 @@ class McpService { } catch (error: any) { // -32601 is the code for the method not found if (error?.code !== -32601) { - logger.error(`Failed to list resources for server: ${server.name}`, error?.message) + getServerLogger(server).error(`Failed to list resources`, error as Error) } return [] } @@ -775,7 +832,7 @@ class McpService { * Get a specific resource from an MCP server (implementation) */ private async getResourceImpl(server: MCPServer, uri: string): 
Promise { - logger.debug(`Getting resource ${uri} from server: ${server.name}`) + getServerLogger(server, { uri }).debug(`Getting resource`) const client = await this.initClient(server) try { const result = await client.readResource({ uri: uri }) @@ -793,7 +850,7 @@ class McpService { contents: contents } } catch (error: Error | any) { - logger.error(`Failed to get resource ${uri} from server: ${server.name}`, error.message) + getServerLogger(server, { uri }).error(`Failed to get resource`, error as Error) throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`) } } @@ -838,10 +895,10 @@ class McpService { if (activeToolCall) { activeToolCall.abort() this.activeToolCalls.delete(callId) - logger.debug(`Aborted tool call: ${callId}`) + logger.debug(`Aborted tool call`, { callId }) return true } else { - logger.warn(`No active tool call found for callId: ${callId}`) + logger.warn(`No active tool call found for callId`, { callId }) return false } } @@ -851,22 +908,22 @@ class McpService { */ public async getServerVersion(_: Electron.IpcMainInvokeEvent, server: MCPServer): Promise { try { - logger.debug(`Getting server version for: ${server.name}`) + getServerLogger(server).debug(`Getting server version`) const client = await this.initClient(server) // Try to get server information which may include version const serverInfo = client.getServerVersion() - logger.debug(`Server info for ${server.name}:`, serverInfo) + getServerLogger(server).debug(`Server info`, redactSensitive(serverInfo)) if (serverInfo && serverInfo.version) { - logger.debug(`Server version for ${server.name}: ${serverInfo.version}`) + getServerLogger(server).debug(`Server version`, { version: serverInfo.version }) return serverInfo.version } - logger.warn(`No version information available for server: ${server.name}`) + getServerLogger(server).warn(`No version information available`) return null } catch (error: any) { - logger.error(`Failed to get server version for 
${server.name}:`, error?.message) + getServerLogger(server).error(`Failed to get server version`, error as Error) return null } } diff --git a/src/main/services/KnowledgeService.ts b/src/main/services/knowledge/EmbedJsFramework.ts similarity index 52% rename from src/main/services/KnowledgeService.ts rename to src/main/services/knowledge/EmbedJsFramework.ts index 99879390e4..64ac77434e 100644 --- a/src/main/services/KnowledgeService.ts +++ b/src/main/services/knowledge/EmbedJsFramework.ts @@ -1,111 +1,40 @@ -/** - * Knowledge Service - Manages knowledge bases using RAG (Retrieval-Augmented Generation) - * - * This service handles creation, management, and querying of knowledge bases from various sources - * including files, directories, URLs, sitemaps, and notes. - * - * Features: - * - Concurrent task processing with workload management - * - Multiple data source support - * - Vector database integration - * - * For detailed documentation, see: - * @see {@link ../../../docs/technical/KnowledgeService.md} - */ - import * as fs from 'node:fs' import path from 'node:path' import { RAGApplication, RAGApplicationBuilder } from '@cherrystudio/embedjs' -import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import { LibSqlDb } from '@cherrystudio/embedjs-libsql' import { SitemapLoader } from '@cherrystudio/embedjs-loader-sitemap' import { WebLoader } from '@cherrystudio/embedjs-loader-web' import { loggerService } from '@logger' -import Embeddings from '@main/knowledge/embeddings/Embeddings' -import { addFileLoader } from '@main/knowledge/loader' -import { NoteLoader } from '@main/knowledge/loader/noteLoader' -import PreprocessProvider from '@main/knowledge/preprocess/PreprocessProvider' -import Reranker from '@main/knowledge/reranker/Reranker' -import { fileStorage } from '@main/services/FileStorage' -import { windowService } from '@main/services/WindowService' -import { getDataPath } from '@main/utils' +import Embeddings from 
'@main/knowledge/embedjs/embeddings/Embeddings' +import { addFileLoader } from '@main/knowledge/embedjs/loader' +import { NoteLoader } from '@main/knowledge/embedjs/loader/noteLoader' +import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService' import { getAllFiles } from '@main/utils/file' -import { TraceMethod } from '@mcp-trace/trace-core' import { MB } from '@shared/config/constant' -import type { LoaderReturn } from '@shared/config/types' +import { LoaderReturn } from '@shared/config/types' import { IpcChannel } from '@shared/IpcChannel' -import { FileMetadata, KnowledgeBaseParams, KnowledgeItem } from '@types' +import { FileMetadata, KnowledgeBaseParams, KnowledgeSearchResult } from '@types' import { v4 as uuidv4 } from 'uuid' +import { windowService } from '../WindowService' +import { + IKnowledgeFramework, + KnowledgeBaseAddItemOptionsNonNullableAttribute, + LoaderDoneReturn, + LoaderTask, + LoaderTaskItem, + LoaderTaskItemState +} from './IKnowledgeFramework' + const logger = loggerService.withContext('MainKnowledgeService') -export interface KnowledgeBaseAddItemOptions { - base: KnowledgeBaseParams - item: KnowledgeItem - forceReload?: boolean - userId?: string -} - -interface KnowledgeBaseAddItemOptionsNonNullableAttribute { - base: KnowledgeBaseParams - item: KnowledgeItem - forceReload: boolean - userId: string -} - -interface EvaluateTaskWorkload { - workload: number -} - -type LoaderDoneReturn = LoaderReturn | null - -enum LoaderTaskItemState { - PENDING, - PROCESSING, - DONE -} - -interface LoaderTaskItem { - state: LoaderTaskItemState - task: () => Promise - evaluateTaskWorkload: EvaluateTaskWorkload -} - -interface LoaderTask { - loaderTasks: LoaderTaskItem[] - loaderDoneReturn: LoaderDoneReturn -} - -interface LoaderTaskOfSet { - loaderTasks: Set - loaderDoneReturn: LoaderDoneReturn -} - -interface QueueTaskItem { - taskPromise: () => Promise - resolve: () => void - evaluateTaskWorkload: EvaluateTaskWorkload -} - -const 
loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => { - return { - loaderTasks: new Set(loaderTask.loaderTasks), - loaderDoneReturn: loaderTask.loaderDoneReturn - } -} - -class KnowledgeService { - private storageDir = path.join(getDataPath(), 'KnowledgeBase') - private pendingDeleteFile = path.join(this.storageDir, 'knowledge_pending_delete.json') - // Byte based - private workload = 0 - private processingItemCount = 0 - private knowledgeItemProcessingQueueMappingPromise: Map void> = new Map() +export class EmbedJsFramework implements IKnowledgeFramework { + private storageDir: string private ragApplications: Map = new Map() + private pendingDeleteFile: string private dbInstances: Map = new Map() - private static MAXIMUM_WORKLOAD = 80 * MB - private static MAXIMUM_PROCESSING_ITEM_COUNT = 30 + private static ERROR_LOADER_RETURN: LoaderReturn = { entriesAdded: 0, uniqueId: '', @@ -114,7 +43,9 @@ class KnowledgeService { status: 'failed' } - constructor() { + constructor(storageDir: string) { + this.storageDir = storageDir + this.pendingDeleteFile = path.join(this.storageDir, 'knowledge_pending_delete.json') this.initStorageDir() this.cleanupOnStartup() } @@ -229,33 +160,28 @@ class KnowledgeService { logger.info(`Startup cleanup completed: ${deletedCount}/${pendingDeleteIds.length} knowledge bases deleted`) } - private getRagApplication = async ({ - id, - embedApiClient, - dimensions, - documentCount - }: KnowledgeBaseParams): Promise => { - if (this.ragApplications.has(id)) { - return this.ragApplications.get(id)! + private async getRagApplication(base: KnowledgeBaseParams): Promise { + if (this.ragApplications.has(base.id)) { + return this.ragApplications.get(base.id)! 
} let ragApplication: RAGApplication const embeddings = new Embeddings({ - embedApiClient, - dimensions + embedApiClient: base.embedApiClient, + dimensions: base.dimensions }) try { - const libSqlDb = new LibSqlDb({ path: path.join(this.storageDir, id) }) + const libSqlDb = new LibSqlDb({ path: path.join(this.storageDir, base.id) }) // Save database instance for later closing - this.dbInstances.set(id, libSqlDb) + this.dbInstances.set(base.id, libSqlDb) ragApplication = await new RAGApplicationBuilder() .setModel('NO_MODEL') .setEmbeddingModel(embeddings) .setVectorDatabase(libSqlDb) - .setSearchResultCount(documentCount || 30) + .setSearchResultCount(base.documentCount || 30) .build() - this.ragApplications.set(id, ragApplication) + this.ragApplications.set(base.id, ragApplication) } catch (e) { logger.error('Failed to create RAGApplication:', e as Error) throw new Error(`Failed to create RAGApplication: ${e}`) @@ -263,17 +189,14 @@ class KnowledgeService { return ragApplication } - - public create = async (_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise => { + async initialize(base: KnowledgeBaseParams): Promise { await this.getRagApplication(base) } - - public reset = async (_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise => { - const ragApplication = await this.getRagApplication(base) - await ragApplication.reset() + async reset(base: KnowledgeBaseParams): Promise { + const ragApp = await this.getRagApplication(base) + await ragApp.reset() } - - public async delete(_: Electron.IpcMainInvokeEvent, id: string): Promise { + async delete(id: string): Promise { logger.debug(`delete id: ${id}`) await this.cleanupKnowledgeResources(id) @@ -286,15 +209,41 @@ class KnowledgeService { this.pendingDeleteManager.add(id) } } - - private maximumLoad() { - return ( - this.processingItemCount >= KnowledgeService.MAXIMUM_PROCESSING_ITEM_COUNT || - this.workload >= KnowledgeService.MAXIMUM_WORKLOAD - ) + getLoaderTask(options: 
KnowledgeBaseAddItemOptionsNonNullableAttribute): LoaderTask { + const { item } = options + const getRagApplication = () => this.getRagApplication(options.base) + switch (item.type) { + case 'file': + return this.fileTask(getRagApplication, options) + case 'directory': + return this.directoryTask(getRagApplication, options) + case 'url': + return this.urlTask(getRagApplication, options) + case 'sitemap': + return this.sitemapTask(getRagApplication, options) + case 'note': + return this.noteTask(getRagApplication, options) + default: + return { + loaderTasks: [], + loaderDoneReturn: null + } + } } + + async remove(options: { uniqueIds: string[]; base: KnowledgeBaseParams }): Promise { + const ragApp = await this.getRagApplication(options.base) + for (const id of options.uniqueIds) { + await ragApp.deleteLoader(id) + } + } + async search(options: { search: string; base: KnowledgeBaseParams }): Promise { + const ragApp = await this.getRagApplication(options.base) + return await ragApp.search(options.search) + } + private fileTask( - ragApplication: RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload, userId } = options @@ -307,7 +256,8 @@ class KnowledgeService { task: async () => { try { // Add preprocessing logic - const fileToProcess: FileMetadata = await this.preprocessing(file, base, item, userId) + const ragApplication = await getRagApplication() + const fileToProcess: FileMetadata = await preprocessingService.preprocessFile(file, base, item, userId) // Use processed file for loading return addFileLoader(ragApplication, fileToProcess, base, forceReload) @@ -318,7 +268,7 @@ class KnowledgeService { .catch((e) => { logger.error(`Error in addFileLoader for ${file.name}: ${e}`) const errorResult: LoaderReturn = { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: e.message, messageSource: 'embedding' } @@ -328,7 +278,7 
@@ class KnowledgeService { } catch (e: any) { logger.error(`Preprocessing failed for ${file.name}: ${e}`) const errorResult: LoaderReturn = { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: e.message, messageSource: 'preprocess' } @@ -345,7 +295,7 @@ class KnowledgeService { return loaderTask } private directoryTask( - ragApplication: RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload } = options @@ -372,8 +322,9 @@ class KnowledgeService { for (const file of files) { loaderTasks.push({ state: LoaderTaskItemState.PENDING, - task: () => - addFileLoader(ragApplication, file, base, forceReload) + task: async () => { + const ragApplication = await getRagApplication() + return addFileLoader(ragApplication, file, base, forceReload) .then((result) => { loaderDoneReturn.entriesAdded += 1 processedFiles += 1 @@ -384,11 +335,12 @@ class KnowledgeService { .catch((err) => { logger.error('Failed to add dir loader:', err) return { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: `Failed to add dir loader: ${err.message}`, messageSource: 'embedding' } - }), + }) + }, evaluateTaskWorkload: { workload: file.size } }) } @@ -400,7 +352,7 @@ class KnowledgeService { } private urlTask( - ragApplication: RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload } = options @@ -410,7 +362,8 @@ class KnowledgeService { loaderTasks: [ { state: LoaderTaskItemState.PENDING, - task: () => { + task: async () => { + const ragApplication = await getRagApplication() const loaderReturn = ragApplication.addLoader( new WebLoader({ urlOrContent: content, @@ -434,7 +387,7 @@ class KnowledgeService { .catch((err) => { logger.error('Failed to add url loader:', err) return { - 
...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: `Failed to add url loader: ${err.message}`, messageSource: 'embedding' } @@ -449,7 +402,7 @@ class KnowledgeService { } private sitemapTask( - ragApplication: RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload } = options @@ -459,8 +412,9 @@ class KnowledgeService { loaderTasks: [ { state: LoaderTaskItemState.PENDING, - task: () => - ragApplication + task: async () => { + const ragApplication = await getRagApplication() + return ragApplication .addLoader( new SitemapLoader({ url: content, chunkSize: base.chunkSize, chunkOverlap: base.chunkOverlap }) as any, forceReload @@ -478,11 +432,12 @@ class KnowledgeService { .catch((err) => { logger.error('Failed to add sitemap loader:', err) return { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: `Failed to add sitemap loader: ${err.message}`, messageSource: 'embedding' } - }), + }) + }, evaluateTaskWorkload: { workload: 20 * MB } } ], @@ -492,7 +447,7 @@ class KnowledgeService { } private noteTask( - ragApplication: RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload } = options @@ -505,7 +460,8 @@ class KnowledgeService { loaderTasks: [ { state: LoaderTaskItemState.PENDING, - task: () => { + task: async () => { + const ragApplication = await getRagApplication() const loaderReturn = ragApplication.addLoader( new NoteLoader({ text: content, @@ -528,7 +484,7 @@ class KnowledgeService { .catch((err) => { logger.error('Failed to add note loader:', err) return { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: `Failed to add note loader: ${err.message}`, messageSource: 'embedding' } @@ -541,199 +497,4 @@ class KnowledgeService { } return 
loaderTask } - - private processingQueueHandle() { - const getSubtasksUntilMaximumLoad = (): QueueTaskItem[] => { - const queueTaskList: QueueTaskItem[] = [] - that: for (const [task, resolve] of this.knowledgeItemProcessingQueueMappingPromise) { - for (const item of task.loaderTasks) { - if (this.maximumLoad()) { - break that - } - - const { state, task: taskPromise, evaluateTaskWorkload } = item - - if (state !== LoaderTaskItemState.PENDING) { - continue - } - - const { workload } = evaluateTaskWorkload - this.workload += workload - this.processingItemCount += 1 - item.state = LoaderTaskItemState.PROCESSING - queueTaskList.push({ - taskPromise: () => - taskPromise().then(() => { - this.workload -= workload - this.processingItemCount -= 1 - task.loaderTasks.delete(item) - if (task.loaderTasks.size === 0) { - this.knowledgeItemProcessingQueueMappingPromise.delete(task) - resolve() - } - this.processingQueueHandle() - }), - resolve: () => {}, - evaluateTaskWorkload - }) - } - } - return queueTaskList - } - const subTasks = getSubtasksUntilMaximumLoad() - if (subTasks.length > 0) { - const subTaskPromises = subTasks.map(({ taskPromise }) => taskPromise()) - Promise.all(subTaskPromises).then(() => { - subTasks.forEach(({ resolve }) => resolve()) - }) - } - } - - private appendProcessingQueue(task: LoaderTask): Promise { - return new Promise((resolve) => { - this.knowledgeItemProcessingQueueMappingPromise.set(loaderTaskIntoOfSet(task), () => { - resolve(task.loaderDoneReturn!) 
- }) - }) - } - - public add = (_: Electron.IpcMainInvokeEvent, options: KnowledgeBaseAddItemOptions): Promise => { - return new Promise((resolve) => { - const { base, item, forceReload = false, userId = '' } = options - const optionsNonNullableAttribute = { base, item, forceReload, userId } - this.getRagApplication(base) - .then((ragApplication) => { - const task = (() => { - switch (item.type) { - case 'file': - return this.fileTask(ragApplication, optionsNonNullableAttribute) - case 'directory': - return this.directoryTask(ragApplication, optionsNonNullableAttribute) - case 'url': - return this.urlTask(ragApplication, optionsNonNullableAttribute) - case 'sitemap': - return this.sitemapTask(ragApplication, optionsNonNullableAttribute) - case 'note': - return this.noteTask(ragApplication, optionsNonNullableAttribute) - default: - return null - } - })() - - if (task) { - this.appendProcessingQueue(task).then(() => { - resolve(task.loaderDoneReturn!) - }) - this.processingQueueHandle() - } else { - resolve({ - ...KnowledgeService.ERROR_LOADER_RETURN, - message: 'Unsupported item type', - messageSource: 'embedding' - }) - } - }) - .catch((err) => { - logger.error('Failed to add item:', err) - resolve({ - ...KnowledgeService.ERROR_LOADER_RETURN, - message: `Failed to add item: ${err.message}`, - messageSource: 'embedding' - }) - }) - }) - } - - @TraceMethod({ spanName: 'remove', tag: 'Knowledge' }) - public async remove( - _: Electron.IpcMainInvokeEvent, - { uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams } - ): Promise { - const ragApplication = await this.getRagApplication(base) - logger.debug(`Remove Item UniqueId: ${uniqueId}`) - for (const id of uniqueIds) { - await ragApplication.deleteLoader(id) - } - } - - @TraceMethod({ spanName: 'RagSearch', tag: 'Knowledge' }) - public async search( - _: Electron.IpcMainInvokeEvent, - { search, base }: { search: string; base: KnowledgeBaseParams } - ): Promise { - const 
ragApplication = await this.getRagApplication(base) - return await ragApplication.search(search) - } - - @TraceMethod({ spanName: 'rerank', tag: 'Knowledge' }) - public async rerank( - _: Electron.IpcMainInvokeEvent, - { search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] } - ): Promise { - if (results.length === 0) { - return results - } - return await new Reranker(base).rerank(search, results) - } - - public getStorageDir = (): string => { - return this.storageDir - } - - private preprocessing = async ( - file: FileMetadata, - base: KnowledgeBaseParams, - item: KnowledgeItem, - userId: string - ): Promise => { - let fileToProcess: FileMetadata = file - if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') { - try { - const provider = new PreprocessProvider(base.preprocessProvider.provider, userId) - const filePath = fileStorage.getFilePathById(file) - // Check if file has already been preprocessed - const alreadyProcessed = await provider.checkIfAlreadyProcessed(file) - if (alreadyProcessed) { - logger.debug(`File already preprocess processed, using cached result: ${filePath}`) - return alreadyProcessed - } - - // Execute preprocessing - logger.debug(`Starting preprocess processing for scanned PDF: ${filePath}`) - const { processedFile, quota } = await provider.parseFile(item.id, file) - fileToProcess = processedFile - const mainWindow = windowService.getMainWindow() - mainWindow?.webContents.send('file-preprocess-finished', { - itemId: item.id, - quota: quota - }) - } catch (err) { - logger.error(`Preprocess processing failed: ${err}`) - // If preprocessing fails, use original file - // fileToProcess = file - throw new Error(`Preprocess processing failed: ${err}`) - } - } - - return fileToProcess - } - - public checkQuota = async ( - _: Electron.IpcMainInvokeEvent, - base: KnowledgeBaseParams, - userId: string - ): Promise => { - try { - if (base.preprocessProvider && base.preprocessProvider.type === 
'preprocess') { - const provider = new PreprocessProvider(base.preprocessProvider.provider, userId) - return await provider.checkQuota() - } - throw new Error('No preprocess provider configured') - } catch (err) { - logger.error(`Failed to check quota: ${err}`) - throw new Error(`Failed to check quota: ${err}`) - } - } } - -export default new KnowledgeService() diff --git a/src/main/services/knowledge/IKnowledgeFramework.ts b/src/main/services/knowledge/IKnowledgeFramework.ts new file mode 100644 index 0000000000..2afbbec713 --- /dev/null +++ b/src/main/services/knowledge/IKnowledgeFramework.ts @@ -0,0 +1,72 @@ +import { LoaderReturn } from '@shared/config/types' +import { KnowledgeBaseParams, KnowledgeItem, KnowledgeSearchResult } from '@types' + +export interface KnowledgeBaseAddItemOptions { + base: KnowledgeBaseParams + item: KnowledgeItem + forceReload?: boolean + userId?: string +} + +export interface KnowledgeBaseAddItemOptionsNonNullableAttribute { + base: KnowledgeBaseParams + item: KnowledgeItem + forceReload: boolean + userId: string +} + +export interface EvaluateTaskWorkload { + workload: number +} + +export type LoaderDoneReturn = LoaderReturn | null + +export enum LoaderTaskItemState { + PENDING, + PROCESSING, + DONE +} + +export interface LoaderTaskItem { + state: LoaderTaskItemState + task: () => Promise + evaluateTaskWorkload: EvaluateTaskWorkload +} + +export interface LoaderTask { + loaderTasks: LoaderTaskItem[] + loaderDoneReturn: LoaderDoneReturn +} + +export interface LoaderTaskOfSet { + loaderTasks: Set + loaderDoneReturn: LoaderDoneReturn +} + +export interface QueueTaskItem { + taskPromise: () => Promise + resolve: () => void + evaluateTaskWorkload: EvaluateTaskWorkload +} + +export const loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => { + return { + loaderTasks: new Set(loaderTask.loaderTasks), + loaderDoneReturn: loaderTask.loaderDoneReturn + } +} + +export interface IKnowledgeFramework { + /** 为给定知识库初始化框架资源 */ + 
initialize(base: KnowledgeBaseParams): Promise + /** 重置知识库,删除其所有内容 */ + reset(base: KnowledgeBaseParams): Promise + /** 删除与知识库关联的资源,包括文件 */ + delete(id: string): Promise + /** 生成用于添加条目的任务对象,由队列处理 */ + getLoaderTask(options: KnowledgeBaseAddItemOptionsNonNullableAttribute): LoaderTask + /** 从知识库中删除特定条目 */ + remove(options: { uniqueIds: string[]; base: KnowledgeBaseParams }): Promise + /** 搜索知识库 */ + search(options: { search: string; base: KnowledgeBaseParams }): Promise +} diff --git a/src/main/services/knowledge/KnowledgeFrameworkFactory.ts b/src/main/services/knowledge/KnowledgeFrameworkFactory.ts new file mode 100644 index 0000000000..cf26749564 --- /dev/null +++ b/src/main/services/knowledge/KnowledgeFrameworkFactory.ts @@ -0,0 +1,48 @@ +import path from 'node:path' + +import { KnowledgeBaseParams } from '@types' +import { app } from 'electron' + +import { EmbedJsFramework } from './EmbedJsFramework' +import { IKnowledgeFramework } from './IKnowledgeFramework' +import { LangChainFramework } from './LangChainFramework' +class KnowledgeFrameworkFactory { + private static instance: KnowledgeFrameworkFactory + private frameworks: Map = new Map() + private storageDir: string + + private constructor(storageDir: string) { + this.storageDir = storageDir + } + + public static getInstance(storageDir: string): KnowledgeFrameworkFactory { + if (!KnowledgeFrameworkFactory.instance) { + KnowledgeFrameworkFactory.instance = new KnowledgeFrameworkFactory(storageDir) + } + return KnowledgeFrameworkFactory.instance + } + + public getFramework(base: KnowledgeBaseParams): IKnowledgeFramework { + const frameworkType = base.framework || 'embedjs' // 如果未指定,默认为 embedjs + if (this.frameworks.has(frameworkType)) { + return this.frameworks.get(frameworkType)! 
+    }
+    let framework: IKnowledgeFramework
+    switch (frameworkType) {
+      case 'langchain':
+        framework = new LangChainFramework(this.storageDir)
+        break
+      case 'embedjs':
+      default:
+        framework = new EmbedJsFramework(this.storageDir)
+        break
+    }
+
+    this.frameworks.set(frameworkType, framework)
+    return framework
+  }
+}
+
+export const knowledgeFrameworkFactory = KnowledgeFrameworkFactory.getInstance(
+  path.join(app.getPath('userData'), 'Data', 'KnowledgeBase')
+)
diff --git a/src/main/services/knowledge/KnowledgeService.ts b/src/main/services/knowledge/KnowledgeService.ts
new file mode 100644
index 0000000000..f34a2b31b6
--- /dev/null
+++ b/src/main/services/knowledge/KnowledgeService.ts
@@ -0,0 +1,190 @@
+import * as fs from 'node:fs'
+import path from 'node:path'
+
+import { loggerService } from '@logger'
+import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService'
+import Reranker from '@main/knowledge/reranker/Reranker'
+import { TraceMethod } from '@mcp-trace/trace-core'
+import { MB } from '@shared/config/constant'
+import { LoaderReturn } from '@shared/config/types'
+import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types'
+import { app } from 'electron'
+
+import {
+  KnowledgeBaseAddItemOptions,
+  LoaderTask,
+  loaderTaskIntoOfSet,
+  LoaderTaskItemState,
+  LoaderTaskOfSet,
+  QueueTaskItem
+} from './IKnowledgeFramework'
+import { knowledgeFrameworkFactory } from './KnowledgeFrameworkFactory'
+
+const logger = loggerService.withContext('MainKnowledgeService')
+
+/**
+ * IPC-facing facade for knowledge-base operations.
+ *
+ * Every call resolves the framework for the given base through `knowledgeFrameworkFactory`.
+ * Item ingestion (`add`) is throttled: loader sub-tasks are only started while both the summed
+ * workload and the number of running sub-tasks are below the class limits.
+ */
+class KnowledgeService {
+  private storageDir = path.join(app.getPath('userData'), 'Data', 'KnowledgeBase')
+
+  /** Sum of `evaluateTaskWorkload.workload` over all sub-tasks currently running. */
+  private workload = 0
+  /** Number of sub-tasks currently running. */
+  private processingItemCount = 0
+  /** Queued tasks mapped to the callback that settles the matching `add()` promise. */
+  private knowledgeItemProcessingQueueMappingPromise: Map<LoaderTaskOfSet, () => void> = new Map()
+  private static MAXIMUM_WORKLOAD = 80 * MB
+  private static MAXIMUM_PROCESSING_ITEM_COUNT = 30
+  private static ERROR_LOADER_RETURN: LoaderReturn = {
+    entriesAdded: 0,
+    uniqueId: '',
+    uniqueIds: [''],
+    loaderType: '',
+    status: 'failed'
+  }
+
+  constructor() {
+    this.initStorageDir()
+  }
+
+  /** Creates the storage directory (and missing parents) when it does not exist yet. */
+  private initStorageDir = (): void => {
+    if (!fs.existsSync(this.storageDir)) {
+      fs.mkdirSync(this.storageDir, { recursive: true })
+    }
+  }
+
+  /** @returns true once either the item-count limit or the workload limit has been reached. */
+  private maximumLoad() {
+    return (
+      this.processingItemCount >= KnowledgeService.MAXIMUM_PROCESSING_ITEM_COUNT ||
+      this.workload >= KnowledgeService.MAXIMUM_WORKLOAD
+    )
+  }
+
+  /**
+   * Starts as many pending sub-tasks as the limits allow.
+   *
+   * Called after every enqueue and again whenever a sub-task settles, so the queue drains itself.
+   */
+  private processingQueueHandle() {
+    const getSubtasksUntilMaximumLoad = (): QueueTaskItem[] => {
+      const queueTaskList: QueueTaskItem[] = []
+      that: for (const [task, resolve] of this.knowledgeItemProcessingQueueMappingPromise) {
+        // A task without sub-tasks (unsupported or invalid item) has nothing to wait for.
+        // Settle it right away; otherwise its `add()` promise would never resolve.
+        if (task.loaderTasks.size === 0) {
+          this.knowledgeItemProcessingQueueMappingPromise.delete(task)
+          resolve()
+          continue
+        }
+
+        for (const item of task.loaderTasks) {
+          if (this.maximumLoad()) {
+            break that
+          }
+
+          const { state, task: taskPromise, evaluateTaskWorkload } = item
+
+          if (state !== LoaderTaskItemState.PENDING) {
+            continue
+          }
+
+          const { workload } = evaluateTaskWorkload
+          this.workload += workload
+          this.processingItemCount += 1
+          item.state = LoaderTaskItemState.PROCESSING
+          queueTaskList.push({
+            taskPromise: () =>
+              taskPromise()
+                // A rejected sub-task must still release its share of the budget below;
+                // without this the capacity leaks and the caller waits forever.
+                .catch((error: unknown) => {
+                  const reason = error instanceof Error ? error.message : String(error)
+                  logger.error(`Knowledge loader task rejected: ${reason}`)
+                })
+                .then(() => {
+                  this.workload -= workload
+                  this.processingItemCount -= 1
+                  task.loaderTasks.delete(item)
+                  if (task.loaderTasks.size === 0) {
+                    this.knowledgeItemProcessingQueueMappingPromise.delete(task)
+                    resolve()
+                  }
+                  this.processingQueueHandle()
+                }),
+            resolve: () => {},
+            evaluateTaskWorkload
+          })
+        }
+      }
+      return queueTaskList
+    }
+    const subTasks = getSubtasksUntilMaximumLoad()
+    if (subTasks.length > 0) {
+      const subTaskPromises = subTasks.map(({ taskPromise }) => taskPromise())
+      void Promise.all(subTaskPromises).then(() => {
+        subTasks.forEach(({ resolve }) => resolve())
+      })
+    }
+  }
+
+  /**
+   * Registers a task in the queue.
+   *
+   * @returns a promise settled with the task's `loaderDoneReturn` once all of its sub-tasks are
+   * done, or with a failed `LoaderReturn` when the task produced no result at all.
+   */
+  private appendProcessingQueue(task: LoaderTask): Promise<LoaderReturn> {
+    // Read before conversion: an empty list means the framework had no loader for this item.
+    const fallbackMessage =
+      task.loaderTasks.length === 0 ? 'Unsupported item type' : 'Loader task finished without reporting a result'
+    return new Promise((resolve) => {
+      this.knowledgeItemProcessingQueueMappingPromise.set(loaderTaskIntoOfSet(task), () => {
+        resolve(
+          task.loaderDoneReturn ?? {
+            ...KnowledgeService.ERROR_LOADER_RETURN,
+            message: fallbackMessage,
+            messageSource: 'embedding'
+          }
+        )
+      })
+    })
+  }
+
+  /** Initializes the storage backing `base` in the framework selected for it. */
+  public async create(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise<void> {
+    logger.info(`Creating knowledge base: ${JSON.stringify(base)}`)
+    const framework = knowledgeFrameworkFactory.getFramework(base)
+    await framework.initialize(base)
+  }
+
+  /**
+   * Resets the storage backing a knowledge base.
+   *
+   * @param payload either `{ base }` or the base itself; the preload bridge in this change set
+   * sends the bare base object, so both shapes are accepted.
+   */
+  public async reset(
+    _: Electron.IpcMainInvokeEvent,
+    payload: KnowledgeBaseParams | { base: KnowledgeBaseParams }
+  ): Promise<void> {
+    const base = 'base' in payload ? payload.base : payload
+    const framework = knowledgeFrameworkFactory.getFramework(base)
+    await framework.reset(base)
+  }
+
+  /** Deletes the storage of the knowledge base identified by `id`. */
+  public async delete(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams, id: string): Promise<void> {
+    logger.info(`Deleting knowledge base: ${JSON.stringify(base)}`)
+    const framework = knowledgeFrameworkFactory.getFramework(base)
+    await framework.delete(id)
+  }
+
+  /**
+   * Queues an item for ingestion and resolves once every sub-task of it has settled.
+   *
+   * @returns the loader result; failures are reported as a `LoaderReturn` with `status: 'failed'`.
+   */
+  public add = async (_: Electron.IpcMainInvokeEvent, options: KnowledgeBaseAddItemOptions): Promise<LoaderReturn> => {
+    logger.info(`Adding item to knowledge base: ${JSON.stringify(options)}`)
+    const { base, item, forceReload = false, userId = '' } = options
+    const framework = knowledgeFrameworkFactory.getFramework(base)
+
+    const task = framework.getLoaderTask({ base, item, forceReload, userId })
+
+    if (!task) {
+      return {
+        ...KnowledgeService.ERROR_LOADER_RETURN,
+        message: 'Unsupported item type',
+        messageSource: 'embedding'
+      }
+    }
+
+    const done = this.appendProcessingQueue(task)
+    this.processingQueueHandle()
+    return done
+  }
+
+  /** Removes the entries identified by `uniqueIds` from `base`. */
+  public async remove(
+    _: Electron.IpcMainInvokeEvent,
+    { uniqueIds, base }: { uniqueIds: string[]; base: KnowledgeBaseParams }
+  ): Promise<void> {
+    logger.info(`Removing items from knowledge base: ${JSON.stringify({ uniqueIds, base })}`)
+    const framework = knowledgeFrameworkFactory.getFramework(base)
+    await framework.remove({ uniqueIds, base })
+  }
+
+  /** Runs `search` against `base` and returns the framework's results. */
+  public async search(
+    _: Electron.IpcMainInvokeEvent,
+    { search, base }: { search: string; base: KnowledgeBaseParams }
+  ): Promise<KnowledgeSearchResult[]> {
+    logger.info(`Searching knowledge base: ${JSON.stringify({ search, base })}`)
+    const framework = knowledgeFrameworkFactory.getFramework(base)
+    return framework.search({ search, base })
+  }
+
+  /** Reranks `results` for `search`; an empty result list is returned unchanged. */
+  @TraceMethod({ spanName: 'rerank', tag: 'Knowledge' })
+  public async rerank(
+    _: Electron.IpcMainInvokeEvent,
+    { search, base, results }: { search: string; base: KnowledgeBaseParams; results: KnowledgeSearchResult[] }
+  ): Promise<KnowledgeSearchResult[]> {
+    logger.info(`Reranking knowledge base: ${JSON.stringify({ search, base, results })}`)
+    if (results.length === 0) {
+      return results
+    }
+    return await new Reranker(base).rerank(search, results)
+  }
+
+  public getStorageDir = (): string => {
+    return this.storageDir
+  }
+
+  /** Forwards the quota check for `userId` to the preprocessing service. */
+  public async checkQuota(
+    _: Electron.IpcMainInvokeEvent,
+    base: KnowledgeBaseParams,
+    userId: string
+  ): Promise<Awaited<ReturnType<typeof preprocessingService.checkQuota>>> {
+    return preprocessingService.checkQuota(base, userId)
+  }
+}
+
+export default new KnowledgeService()
diff --git a/src/main/services/knowledge/LangChainFramework.ts b/src/main/services/knowledge/LangChainFramework.ts
new file mode 100644
index 0000000000..b82242e102
--- /dev/null
+++ b/src/main/services/knowledge/LangChainFramework.ts
@@ -0,0 +1,555 @@
+import * as fs from 'node:fs'
+import path from 'node:path'
+
+import { FaissStore } from '@langchain/community/vectorstores/faiss'
+import type
{ Document } from '@langchain/core/documents'
+import { loggerService } from '@logger'
+import TextEmbeddings from '@main/knowledge/langchain/embeddings/TextEmbeddings'
+import {
+  addFileLoader,
+  addNoteLoader,
+  addSitemapLoader,
+  addVideoLoader,
+  addWebLoader
+} from '@main/knowledge/langchain/loader'
+import { RetrieverFactory } from '@main/knowledge/langchain/retriever'
+import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService'
+import { getAllFiles } from '@main/utils/file'
+import { getUrlSource } from '@main/utils/knowledge'
+import { MB } from '@shared/config/constant'
+import { LoaderReturn } from '@shared/config/types'
+import { IpcChannel } from '@shared/IpcChannel'
+import {
+  FileMetadata,
+  isKnowledgeDirectoryItem,
+  isKnowledgeFileItem,
+  isKnowledgeNoteItem,
+  isKnowledgeSitemapItem,
+  isKnowledgeUrlItem,
+  isKnowledgeVideoItem,
+  KnowledgeBaseParams,
+  KnowledgeSearchResult
+} from '@types'
+// Replaces the former `uuidv4` import from 'zod/v4': that export builds a validation schema,
+// it does not generate an identifier.
+import { randomUUID } from 'node:crypto'
+
+import { windowService } from '../WindowService'
+import {
+  IKnowledgeFramework,
+  KnowledgeBaseAddItemOptionsNonNullableAttribute,
+  LoaderDoneReturn,
+  LoaderTask,
+  LoaderTaskItem,
+  LoaderTaskItemState
+} from './IKnowledgeFramework'
+
+const logger = loggerService.withContext('LangChainFramework')
+
+/**
+ * Knowledge framework backed by a FAISS vector store persisted under `<storageDir>/<base.id>`.
+ *
+ * Every write is a load -> modify -> save cycle on that directory, so writes for the same base
+ * are serialized (see `runExclusive`); otherwise parallel tasks overwrite each other's saves.
+ */
+export class LangChainFramework implements IKnowledgeFramework {
+  private storageDir: string
+
+  /** Tail of the pending write chain per knowledge-base id. */
+  private writeQueues = new Map<string, Promise<void>>()
+
+  private static ERROR_LOADER_RETURN: LoaderReturn = {
+    entriesAdded: 0,
+    uniqueId: '',
+    uniqueIds: [''],
+    loaderType: '',
+    status: 'failed'
+  }
+
+  constructor(storageDir: string) {
+    this.storageDir = storageDir
+    this.initStorageDir()
+  }
+
+  /** Creates the storage directory (and missing parents) when it does not exist yet. */
+  private initStorageDir = (): void => {
+    if (!fs.existsSync(this.storageDir)) {
+      fs.mkdirSync(this.storageDir, { recursive: true })
+    }
+  }
+
+  /** Extracts a printable message from a caught value of unknown type. */
+  private static errorMessage(error: unknown): string {
+    return error instanceof Error ? error.message : String(error)
+  }
+
+  /** Builds a failed `LoaderReturn` carrying the error message and the stage that failed. */
+  private static failure(error: unknown, messageSource: LoaderReturn['messageSource']): LoaderReturn {
+    return {
+      ...LangChainFramework.ERROR_LOADER_RETURN,
+      message: LangChainFramework.errorMessage(error),
+      messageSource
+    }
+  }
+
+  /** Directory holding the persisted store of the given base. */
+  private getDbPath(baseId: string): string {
+    return path.join(this.storageDir, baseId)
+  }
+
+  /**
+   * Runs `work` after every previously queued write for the same base has settled.
+   *
+   * @returns the promise of `work` itself, so callers still observe its result or rejection.
+   */
+  private runExclusive<T>(baseId: string, work: () => Promise<T>): Promise<T> {
+    const previous = this.writeQueues.get(baseId) ?? Promise.resolve()
+    const current = previous.then(work)
+    // The stored tail never rejects, so one failed write does not poison the chain.
+    const tail = current.then(
+      () => undefined,
+      () => undefined
+    )
+    this.writeQueues.set(baseId, tail)
+    void tail.then(() => {
+      // Drop the entry once the chain is idle to keep the map from growing.
+      if (this.writeQueues.get(baseId) === tail) {
+        this.writeQueues.delete(baseId)
+      }
+    })
+    return current
+  }
+
+  /**
+   * Loads the store, lets `load` add documents to it, then saves it back to disk.
+   * The whole cycle runs under the per-base write lock.
+   */
+  private loadAndPersist(
+    base: KnowledgeBaseParams,
+    getVectorStore: () => Promise<FaissStore>,
+    load: (vectorStore: FaissStore) => Promise<LoaderReturn>
+  ): Promise<LoaderReturn> {
+    return this.runExclusive(base.id, async () => {
+      const vectorStore = await getVectorStore()
+      const result = await load(vectorStore)
+      await vectorStore.save(this.getDbPath(base.id))
+      return result
+    })
+  }
+
+  /**
+   * Creates the on-disk store for `base`.
+   *
+   * A placeholder document is added, saved, deleted and saved again.
+   * NOTE(review): presumably needed because the FAISS index only exists after a first insert — confirm.
+   */
+  private async createDatabase(base: KnowledgeBaseParams): Promise<void> {
+    const dbPath = this.getDbPath(base.id)
+    const embeddings = this.getEmbeddings(base)
+    const vectorStore = new FaissStore(embeddings, {})
+
+    const mockDocument: Document = {
+      pageContent: 'Create Database Document',
+      metadata: {}
+    }
+
+    await vectorStore.addDocuments([mockDocument], { ids: ['1'] })
+    await vectorStore.save(dbPath)
+    await vectorStore.delete({ ids: ['1'] })
+    await vectorStore.save(dbPath)
+  }
+
+  private getEmbeddings(base: KnowledgeBaseParams): TextEmbeddings {
+    return new TextEmbeddings({
+      embedApiClient: base.embedApiClient,
+      dimensions: base.dimensions
+    })
+  }
+
+  /** Loads a fresh in-memory copy of the store persisted for `base`. */
+  private async getVectorStore(base: KnowledgeBaseParams): Promise<FaissStore> {
+    const embeddings = this.getEmbeddings(base)
+    const vectorStore = await FaissStore.load(this.getDbPath(base.id), embeddings)
+
+    return vectorStore
+  }
+
+  async initialize(base: KnowledgeBaseParams): Promise<void> {
+    await this.createDatabase(base)
+  }
+
+  /** Removes the persisted store of `base`; nothing is recreated here. */
+  async reset(base: KnowledgeBaseParams): Promise<void> {
+    const dbPath = this.getDbPath(base.id)
+    if (fs.existsSync(dbPath)) {
+      fs.rmSync(dbPath, { recursive: true })
+    }
+  }
+
+  /** Removes the persisted store of the base identified by `id`. */
+  async delete(id: string): Promise<void> {
+    const dbPath = this.getDbPath(id)
+    if (fs.existsSync(dbPath)) {
+      fs.rmSync(dbPath, { recursive: true })
+    }
+  }
+
+  /**
+   * Builds the loader task for an item, dispatching on `item.type`.
+   *
+   * @returns a task with an empty `loaderTasks` list and `loaderDoneReturn: null` for unknown types.
+   */
+  getLoaderTask(options: KnowledgeBaseAddItemOptionsNonNullableAttribute): LoaderTask {
+    const { item } = options
+    const getStore = () => this.getVectorStore(options.base)
+    switch (item.type) {
+      case 'file':
+        return this.fileTask(getStore, options)
+      case 'directory':
+        return this.directoryTask(getStore, options)
+      case 'url':
+        return this.urlTask(getStore, options)
+      case 'sitemap':
+        return this.sitemapTask(getStore, options)
+      case 'note':
+        return this.noteTask(getStore, options)
+      case 'video':
+        return this.videoTask(getStore, options)
+      default:
+        return {
+          loaderTasks: [],
+          loaderDoneReturn: null
+        }
+    }
+  }
+
+  /** Deletes the entries with the given ids and saves the store, under the per-base write lock. */
+  async remove(options: { uniqueIds: string[]; base: KnowledgeBaseParams }): Promise<void> {
+    const { uniqueIds, base } = options
+    logger.info(`[ KnowledgeService Remove Item UniqueIds: ${uniqueIds}]`)
+
+    await this.runExclusive(base.id, async () => {
+      const vectorStore = await this.getVectorStore(base)
+      await vectorStore.delete({ ids: uniqueIds })
+      await vectorStore.save(this.getDbPath(base.id))
+    })
+  }
+
+  /**
+   * Searches `base` with the retriever produced by `RetrieverFactory`.
+   *
+   * @returns the matches with a numeric `score`; an empty array when the base has no documents
+   * or when any step fails (the error is logged).
+   */
+  async search(options: { search: string; base: KnowledgeBaseParams }): Promise<KnowledgeSearchResult[]> {
+    const { search, base } = options
+    logger.info(`search base: ${JSON.stringify(base)}`)
+
+    try {
+      const vectorStore = await this.getVectorStore(base)
+
+      // bm25 and hybrid retrieval need every stored document; reuse the store that was just
+      // loaded instead of reading the index from disk a second time.
+      const documents: Document[] = await this.getAllDocuments(base, vectorStore)
+      if (documents.length === 0) return []
+
+      const retrieverFactory = new RetrieverFactory()
+      const retriever = retrieverFactory.createRetriever(base, vectorStore, documents)
+
+      const results = await retriever.invoke(search)
+      logger.info(`Search Results: ${JSON.stringify(results)}`)
+
+      // VectorStoreRetriever and EnsembleRetriever attach the score to metadata.score;
+      // BM25Retriever does not return a score by default, so that case is handled here.
+      return results.map((item) => {
+        return {
+          pageContent: item.pageContent,
+          metadata: item.metadata,
+          // Fall back to 0 when metadata carries no numeric score.
+          score: typeof item.metadata.score === 'number' ? item.metadata.score : 0
+        }
+      })
+    } catch (error: unknown) {
+      logger.error(`Error during search in knowledge base ${base.id}: ${LangChainFramework.errorMessage(error)}`)
+      return []
+    }
+  }
+
+  /** Builds the task returned when an item does not match the loader it was routed to. */
+  private static invalidItemTask(expected: string, actual: string): LoaderTask {
+    return {
+      loaderTasks: [],
+      loaderDoneReturn: {
+        ...LangChainFramework.ERROR_LOADER_RETURN,
+        message: `Invalid item type: expected '${expected}', got '${actual}'`,
+        messageSource: 'validation'
+      }
+    }
+  }
+
+  /**
+   * Builds a task with exactly one sub-task that loads, saves and records the outcome in
+   * `loaderDoneReturn`. Any failure (opening the store, loading, saving) becomes a failed result.
+   */
+  private singleLoaderTask(spec: {
+    base: KnowledgeBaseParams
+    getVectorStore: () => Promise<FaissStore>
+    workload: number
+    load: (vectorStore: FaissStore) => Promise<LoaderReturn>
+    failureLabel: string
+    messageSource: LoaderReturn['messageSource']
+  }): LoaderTask {
+    const loaderTask: LoaderTask = {
+      loaderTasks: [
+        {
+          state: LoaderTaskItemState.PENDING,
+          task: async () => {
+            try {
+              const result = await this.loadAndPersist(spec.base, spec.getVectorStore, spec.load)
+              loaderTask.loaderDoneReturn = result
+              return result
+            } catch (e: unknown) {
+              logger.error(`${spec.failureLabel}: ${LangChainFramework.errorMessage(e)}`)
+              const errorResult = LangChainFramework.failure(e, spec.messageSource)
+              loaderTask.loaderDoneReturn = errorResult
+              return errorResult
+            }
+          },
+          evaluateTaskWorkload: { workload: spec.workload }
+        }
+      ],
+      loaderDoneReturn: null
+    }
+    return loaderTask
+  }
+
+  /** Task for a single file: preprocess first, then embed the processed file. */
+  private fileTask(
+    getVectorStore: () => Promise<FaissStore>,
+    options: KnowledgeBaseAddItemOptionsNonNullableAttribute
+  ): LoaderTask {
+    const { base, item, userId } = options
+
+    if (!isKnowledgeFileItem(item)) {
+      logger.error(`Invalid item type for fileTask: expected 'file', got '${item.type}'`)
+      return LangChainFramework.invalidItemTask('file', item.type)
+    }
+
+    const file = item.content
+
+    const loaderTask: LoaderTask = {
+      loaderTasks: [
+        {
+          state: LoaderTaskItemState.PENDING,
+          task: async () => {
+            // Tracks which stage is running so a failure is attributed correctly.
+            let stage: 'preprocess' | 'embedding' = 'preprocess'
+            try {
+              // Preprocessing happens outside the write lock; only the store update is serialized.
+              const fileToProcess: FileMetadata = await preprocessingService.preprocessFile(file, base, item, userId)
+
+              stage = 'embedding'
+              const result = await this.loadAndPersist(base, getVectorStore, (vectorStore) =>
+                addFileLoader(base, vectorStore, fileToProcess)
+              )
+              loaderTask.loaderDoneReturn = result
+              return result
+            } catch (e: unknown) {
+              const reason = LangChainFramework.errorMessage(e)
+              logger.error(
+                stage === 'preprocess'
+                  ? `Preprocessing failed for ${file.name}: ${reason}`
+                  : `Error in addFileLoader for ${file.name}: ${reason}`
+              )
+              const errorResult = LangChainFramework.failure(e, stage)
+              loaderTask.loaderDoneReturn = errorResult
+              return errorResult
+            }
+          },
+          evaluateTaskWorkload: { workload: file.size }
+        }
+      ],
+      loaderDoneReturn: null
+    }
+
+    return loaderTask
+  }
+
+  /**
+   * Task for a directory: one sub-task per file returned by `getAllFiles`.
+   * Progress is pushed to the main window after every file, successful or not.
+   */
+  private directoryTask(
+    getVectorStore: () => Promise<FaissStore>,
+    options: KnowledgeBaseAddItemOptionsNonNullableAttribute
+  ): LoaderTask {
+    const { base, item } = options
+
+    if (!isKnowledgeDirectoryItem(item)) {
+      logger.error(`Invalid item type for directoryTask: expected 'directory', got '${item.type}'`)
+      return LangChainFramework.invalidItemTask('directory', item.type)
+    }
+
+    const directory = item.content
+    const files = getAllFiles(directory)
+    const totalFiles = files.length
+    let processedFiles = 0
+
+    const sendDirectoryProcessingPercent = (totalFiles: number, processedFiles: number) => {
+      const mainWindow = windowService.getMainWindow()
+      mainWindow?.webContents.send(IpcChannel.DirectoryProcessingPercent, {
+        itemId: item.id,
+        percent: (processedFiles / totalFiles) * 100
+      })
+    }
+
+    const loaderDoneReturn: LoaderDoneReturn = {
+      entriesAdded: 0,
+      uniqueId: `DirectoryLoader_${randomUUID()}`,
+      uniqueIds: [],
+      loaderType: 'DirectoryLoader'
+    }
+    const loaderTasks: LoaderTaskItem[] = []
+    for (const file of files) {
+      loaderTasks.push({
+        state: LoaderTaskItemState.PENDING,
+        task: async () => {
+          try {
+            const result = await this.loadAndPersist(base, getVectorStore, (vectorStore) =>
+              addFileLoader(base, vectorStore, file)
+            )
+            loaderDoneReturn.entriesAdded += 1
+            loaderDoneReturn.uniqueIds.push(result.uniqueId)
+            return result
+          } catch (err: unknown) {
+            const reason = LangChainFramework.errorMessage(err)
+            logger.error(`Failed to add dir loader: ${reason}`)
+            return {
+              ...LangChainFramework.ERROR_LOADER_RETURN,
+              message: `Failed to add dir loader: ${reason}`,
+              messageSource: 'embedding'
+            }
+          } finally {
+            // Count failed files as processed too, so the reported percentage can reach 100.
+            processedFiles += 1
+            sendDirectoryProcessingPercent(totalFiles, processedFiles)
+          }
+        },
+        evaluateTaskWorkload: { workload: file.size }
+      })
+    }
+
+    return {
+      loaderTasks,
+      loaderDoneReturn
+    }
+  }
+
+  /** Task for a single web page; the loader variant is chosen by `getUrlSource`. */
+  private urlTask(
+    getVectorStore: () => Promise<FaissStore>,
+    options: KnowledgeBaseAddItemOptionsNonNullableAttribute
+  ): LoaderTask {
+    const { base, item } = options
+
+    if (!isKnowledgeUrlItem(item)) {
+      logger.error(`Invalid item type for urlTask: expected 'url', got '${item.type}'`)
+      return LangChainFramework.invalidItemTask('url', item.type)
+    }
+
+    const url = item.content
+
+    return this.singleLoaderTask({
+      base,
+      getVectorStore,
+      workload: 2 * MB,
+      load: (vectorStore) => addWebLoader(base, vectorStore, url, getUrlSource(url)),
+      failureLabel: `Error in addWebLoader for ${url}`,
+      messageSource: 'embedding'
+    })
+  }
+
+  /** Task for a sitemap URL. */
+  private sitemapTask(
+    getVectorStore: () => Promise<FaissStore>,
+    options: KnowledgeBaseAddItemOptionsNonNullableAttribute
+  ): LoaderTask {
+    const { base, item } = options
+
+    if (!isKnowledgeSitemapItem(item)) {
+      logger.error(`Invalid item type for sitemapTask: expected 'sitemap', got '${item.type}'`)
+      return LangChainFramework.invalidItemTask('sitemap', item.type)
+    }
+
+    const url = item.content
+
+    return this.singleLoaderTask({
+      base,
+      getVectorStore,
+      workload: 2 * MB,
+      load: (vectorStore) => addSitemapLoader(base, vectorStore, url),
+      failureLabel: `Error in addSitemapLoader for ${url}`,
+      messageSource: 'embedding'
+    })
+  }
+
+  /** Task for a note; the workload is the UTF-8 byte length of its content. */
+  private noteTask(
+    getVectorStore: () => Promise<FaissStore>,
+    options: KnowledgeBaseAddItemOptionsNonNullableAttribute
+  ): LoaderTask {
+    const { base, item } = options
+
+    if (!isKnowledgeNoteItem(item)) {
+      logger.error(`Invalid item type for noteTask: expected 'note', got '${item.type}'`)
+      return LangChainFramework.invalidItemTask('note', item.type)
+    }
+
+    const content = item.content
+    const sourceUrl = item.sourceUrl ?? ''
+
+    logger.info(`noteTask ${content}, ${sourceUrl}`)
+
+    const encoder = new TextEncoder()
+    const contentBytes = encoder.encode(content)
+
+    return this.singleLoaderTask({
+      base,
+      getVectorStore,
+      workload: contentBytes.length,
+      load: (vectorStore) => addNoteLoader(base, vectorStore, content, sourceUrl),
+      failureLabel: `Error in addNoteLoader for ${sourceUrl}`,
+      messageSource: 'embedding'
+    })
+  }
+
+  /** Task for a video item; the workload is the size of the first file, or 0 when there is none. */
+  private videoTask(
+    getVectorStore: () => Promise<FaissStore>,
+    options: KnowledgeBaseAddItemOptionsNonNullableAttribute
+  ): LoaderTask {
+    const { base, item } = options
+
+    if (!isKnowledgeVideoItem(item)) {
+      logger.error(`Invalid item type for videoTask: expected 'video', got '${item.type}'`)
+      return LangChainFramework.invalidItemTask('video', item.type)
+    }
+
+    const files = item.content
+    // Guarded: an empty list must not throw while the task is still being built.
+    const firstFile = files[0]
+
+    return this.singleLoaderTask({
+      base,
+      getVectorStore,
+      workload: firstFile?.size ?? 0,
+      load: (vectorStore) => addVideoLoader(base, vectorStore, files),
+      // NOTE(review): label and source say "preprocess" although the failure comes from the
+      // loader/save step; kept unchanged because consumers of messageSource are not visible here.
+      failureLabel: `Preprocessing failed for ${firstFile?.name}`,
+      messageSource: 'preprocess'
+    })
+  }
+
+  /**
+   * Returns every document held by the store's in-memory docstore.
+   *
+   * @param vectorStore an already loaded store to read from; loaded from disk when omitted.
+   * @returns the documents, or an empty array when loading or reading fails (the error is logged).
+   */
+  private async getAllDocuments(base: KnowledgeBaseParams, vectorStore?: FaissStore): Promise<Document[]> {
+    logger.info(`Fetching all documents from database for knowledge base: ${base.id}`)
+
+    try {
+      const results = (vectorStore ?? (await this.getVectorStore(base))).docstore._docs
+
+      const documents: Document[] = Array.from(results.values())
+      logger.info(`Fetched ${documents.length} documents for BM25/Hybrid retriever.`)
+      return documents
+    } catch (e) {
+      logger.error(`Could not fetch documents from database for base ${base.id}: ${e}`)
+      // Return an empty array when the store does not exist or reading it fails.
+      return []
+    }
+  }
+}
diff --git a/src/main/services/memory/MemoryService.ts b/src/main/services/memory/MemoryService.ts
index aba341391f..85f182b686 100644
--- a/src/main/services/memory/MemoryService.ts
+++ b/src/main/services/memory/MemoryService.ts
@@ -1,6 +1,6 @@
 import { Client, createClient } from '@libsql/client'
 import { loggerService } from '@logger'
-import Embeddings from '@main/knowledge/embeddings/Embeddings'
+import Embeddings from '@main/knowledge/embedjs/embeddings/Embeddings'
 import type {
   AddMemoryOptions,
   AssistantMessage,
diff --git a/src/main/services/ocr/OcrService.ts b/src/main/services/ocr/OcrService.ts
index dfd796346f..471d31edce 100644
--- a/src/main/services/ocr/OcrService.ts
+++ b/src/main/services/ocr/OcrService.ts
@@ -2,6 +2,7 @@ import { loggerService } from '@logger'
 import { isLinux } from '@main/constant'
 import { BuiltinOcrProviderIds, OcrHandler, OcrProvider, OcrResult, SupportedOcrFile } from '@types'
 
+import { ppocrService } from './builtin/PpocrService'
 import { systemOcrService } from './builtin/SystemOcrService'
 import { tesseractService } from './builtin/TesseractService'
 
@@ -36,3 +37,5 @@ export const
ocrService = new OcrService()
 
 ocrService.register(BuiltinOcrProviderIds.tesseract, tesseractService.ocr.bind(tesseractService))
 !isLinux && ocrService.register(BuiltinOcrProviderIds.system, systemOcrService.ocr.bind(systemOcrService))
+
+ocrService.register(BuiltinOcrProviderIds.paddleocr, ppocrService.ocr.bind(ppocrService))
diff --git a/src/main/services/ocr/builtin/PpocrService.ts b/src/main/services/ocr/builtin/PpocrService.ts
new file mode 100644
index 0000000000..2079f2d6b8
--- /dev/null
+++ b/src/main/services/ocr/builtin/PpocrService.ts
@@ -0,0 +1,100 @@
+import { loadOcrImage } from '@main/utils/ocr'
+import { ImageFileMetadata, isImageFileMetadata, OcrPpocrConfig, OcrResult, SupportedOcrFile } from '@types'
+import { net } from 'electron'
+import { z } from 'zod'
+
+import { OcrBaseService } from './OcrBaseService'
+
+/** Numeric codes sent in the `fileType` field of the request payload. */
+enum FileType {
+  PDF = 0,
+  Image = 1
+}
+
+// API Reference: https://www.paddleocr.ai/latest/version3.x/pipeline_usage/OCR.html#3
+interface OcrPayload {
+  file: string
+  fileType?: FileType | null
+  useDocOrientationClassify?: boolean | null
+  useDocUnwarping?: boolean | null
+  useTextlineOrientation?: boolean | null
+  textDetLimitSideLen?: number | null
+  textDetLimitType?: string | null
+  textDetThresh?: number | null
+  textDetBoxThresh?: number | null
+  textDetUnclipRatio?: number | null
+  textRecScoreThresh?: number | null
+  visualize?: boolean | null
+}
+
+/** The subset of the service response this client reads; other fields are not validated. */
+const OcrResponseSchema = z.object({
+  result: z.object({
+    ocrResults: z.array(
+      z.object({
+        prunedResult: z.object({
+          rec_texts: z.array(z.string())
+        })
+      })
+    )
+  })
+})
+
+/** OCR provider that posts a base64-encoded image to a PaddleOCR HTTP endpoint. */
+export class PpocrService extends OcrBaseService {
+  /**
+   * Recognizes the text in an image file.
+   *
+   * @throws Error when `file` is not an image or `options` is missing.
+   */
+  public ocr = async (file: SupportedOcrFile, options?: OcrPpocrConfig): Promise<OcrResult> => {
+    if (!isImageFileMetadata(file)) {
+      throw new Error('Only image files are supported currently')
+    }
+    if (!options) {
+      throw new Error('config is required')
+    }
+    return this.imageOcr(file, options)
+  }
+
+  /**
+   * Sends the image to `options.apiUrl` and joins the recognized lines with newlines.
+   *
+   * @returns the recognized text; an empty string when the service reports no results.
+   * @throws Error prefixed with `OCR service error:` on transport failure, a non-2xx status,
+   * a body that is not JSON, or a body that does not match `OcrResponseSchema`.
+   */
+  private async imageOcr(file: ImageFileMetadata, options: OcrPpocrConfig): Promise<OcrResult> {
+    if (!options.apiUrl) {
+      throw new Error('API URL is required')
+    }
+    const apiUrl = options.apiUrl
+
+    const buffer = await loadOcrImage(file)
+    const base64 = buffer.toString('base64')
+    const payload = {
+      file: base64,
+      fileType: FileType.Image,
+      useDocOrientationClassify: false,
+      useDocUnwarping: false,
+      visualize: false
+    } satisfies OcrPayload
+
+    const headers: Record<string, string> = {
+      'Content-Type': 'application/json'
+    }
+
+    if (options.accessToken) {
+      headers['Authorization'] = `token ${options.accessToken}`
+    }
+
+    let data: unknown
+    try {
+      const response = await net.fetch(apiUrl, {
+        method: 'POST',
+        headers,
+        body: JSON.stringify(payload)
+      })
+
+      if (!response.ok) {
+        const text = await response.text()
+        // The prefix is added once, in the catch below.
+        throw new Error(`${response.status} ${response.statusText} - ${text}`)
+      }
+
+      data = await response.json()
+    } catch (error: unknown) {
+      const reason = error instanceof Error ? error.message : String(error)
+      throw new Error(`OCR service error: ${reason}`)
+    }
+
+    const parsed = OcrResponseSchema.safeParse(data)
+    if (!parsed.success) {
+      throw new Error(`OCR service error: ${parsed.error.message}`)
+    }
+
+    // Collect the lines of every result instead of indexing [0], which is undefined
+    // when the service returns an empty list.
+    const recTexts = parsed.data.result.ocrResults.flatMap((ocrResult) => ocrResult.prunedResult.rec_texts)
+
+    return { text: recTexts.join('\n') }
+  }
+}
+
+export const ppocrService = new PpocrService()
diff --git a/src/main/utils/file.ts b/src/main/utils/file.ts
index 2f622d3544..97e87bf9ea 100644
--- a/src/main/utils/file.ts
+++ b/src/main/utils/file.ts
@@ -205,6 +205,19 @@ export async function readTextFileWithAutoEncoding(filePath: string): Promise {
+  const filePath = path.join(getFilesDir(), `${file.id}${file.ext}`)
+  const data = await fs.promises.readFile(filePath)
+  const base64 = data.toString('base64')
+  const ext = path.extname(filePath).slice(1) == 'jpg' ?
'jpeg' : path.extname(filePath).slice(1)
+  const mime = `image/${ext}`
+  return {
+    mime,
+    base64,
+    data: `data:${mime};base64,${base64}`
+  }
+}
+
 /**
  * 递归扫描目录,获取符合条件的文件和目录结构
  * @param dirPath 当前要扫描的路径
diff --git a/src/main/utils/knowledge.ts b/src/main/utils/knowledge.ts
new file mode 100644
index 0000000000..cce85829d0
--- /dev/null
+++ b/src/main/utils/knowledge.ts
@@ -0,0 +1,13 @@
+export const DEFAULT_DOCUMENT_COUNT = 6
+export const DEFAULT_RELEVANT_SCORE = 0
+export type UrlSource = 'normal' | 'github' | 'youtube'
+
+// Optional scheme, optional "www.", then one of the listed hosts. The lookahead requires the
+// host to end there (path, port, query, fragment or end of string), so look-alike hosts such
+// as "youtube.community.example" are not treated as YouTube.
+const youtubeRegex = /^(https?:\/\/)?(www\.)?(youtube\.com|youtu\.be|youtube\.be|yt\.be)(?=[/:?#]|$)/i
+
+/**
+ * Classifies a URL by the site it points to.
+ *
+ * @param url the URL to classify; the scheme may be omitted.
+ * @returns 'youtube' when the host is one of the recognized YouTube hosts, otherwise 'normal'.
+ */
+export function getUrlSource(url: string): UrlSource {
+  if (youtubeRegex.test(url)) {
+    return 'youtube'
+  } else {
+    return 'normal'
+  }
+}
diff --git a/src/preload/index.ts b/src/preload/index.ts
index 2244264753..0600b6b310 100644
--- a/src/preload/index.ts
+++ b/src/preload/index.ts
@@ -1,4 +1,3 @@
-import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
 import { electronAPI } from '@electron-toolkit/preload'
 import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
 import { SpanContext } from '@opentelemetry/api'
@@ -14,6 +13,7 @@ import {
   FileUploadResponse,
   KnowledgeBaseParams,
   KnowledgeItem,
+  KnowledgeSearchResult,
   MCPServer,
   MemoryConfig,
   MemoryListOptions,
@@ -166,7 +166,8 @@ const api = {
     selectFolder: (options?: OpenDialogOptions) => ipcRenderer.invoke(IpcChannel.File_SelectFolder, options),
     saveImage: (name: string, data: string) => ipcRenderer.invoke(IpcChannel.File_SaveImage, name, data),
     binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId),
-    base64Image: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId),
+    base64Image: (fileId: string): Promise<{ mime: string; base64: string; data: string }> =>
+      ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId),
     saveBase64Image: (data: string) => ipcRenderer.invoke(IpcChannel.File_SaveBase64Image, data),
     savePastedImage: (imageData: Uint8Array, extension?: string) =>
       ipcRenderer.invoke(IpcChannel.File_SavePastedImage, imageData, extension),
@@ -215,7 +216,7 @@ const api = {
     create: (base: KnowledgeBaseParams, context?: SpanContext) =>
       tracedInvoke(IpcChannel.KnowledgeBase_Create, context, base),
     reset: (base: KnowledgeBaseParams) => ipcRenderer.invoke(IpcChannel.KnowledgeBase_Reset, base),
-    delete: (id: string) => ipcRenderer.invoke(IpcChannel.KnowledgeBase_Delete, id),
+    delete: (base: KnowledgeBaseParams, id: string) => ipcRenderer.invoke(IpcChannel.KnowledgeBase_Delete, base, id),
     add: ({
       base,
       item,
@@ -232,7 +233,7 @@ const api = {
     search: ({ search, base }: { search: string; base: KnowledgeBaseParams }, context?: SpanContext) =>
       tracedInvoke(IpcChannel.KnowledgeBase_Search, context, { search, base }),
     rerank: (
-      { search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] },
+      { search, base, results }: { search: string; base: KnowledgeBaseParams; results: KnowledgeSearchResult[] },
       context?: SpanContext
     ) => tracedInvoke(IpcChannel.KnowledgeBase_Rerank, context, { search, base, results }),
     checkQuota: ({ base, userId }: { base: KnowledgeBaseParams; userId: string }) =>
diff --git a/src/renderer/index.html b/src/renderer/index.html
index 239d9c794c..7b9bfdc329 100644
--- a/src/renderer/index.html
+++ b/src/renderer/index.html
@@ -5,7 +5,7 @@
+      content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' 'unsafe-inline' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; media-src 'self' file:; frame-src * file:" />
 Cherry Studio