From cac84a8795957ed1969e5283ae63724e913cf3ab Mon Sep 17 00:00:00 2001 From: LiuVaayne <10231735+vaayne@users.noreply.github.com> Date: Thu, 4 Sep 2025 14:52:47 +0800 Subject: [PATCH 1/3] refactor(mcp): enhance MCPService logging and error handling (#9878) --- src/main/services/MCPService.ts | 169 ++++++++++++------ .../settings/MCPSettings/McpServersList.tsx | 10 +- 2 files changed, 119 insertions(+), 60 deletions(-) diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 538a0ede79..775ca12db5 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -56,6 +56,45 @@ type CallToolArgs = { server: MCPServer; name: string; args: any; callId?: strin const logger = loggerService.withContext('MCPService') +// Redact potentially sensitive fields in objects (headers, tokens, api keys) +function redactSensitive(input: any): any { + const SENSITIVE_KEYS = ['authorization', 'Authorization', 'apiKey', 'api_key', 'apikey', 'token', 'access_token'] + const MAX_STRING = 300 + + const redact = (val: any): any => { + if (val == null) return val + if (typeof val === 'string') { + return val.length > MAX_STRING ? `${val.slice(0, MAX_STRING)}…<${val.length - MAX_STRING} more>` : val + } + if (Array.isArray(val)) return val.map((v) => redact(v)) + if (typeof val === 'object') { + const out: Record = {} + for (const [k, v] of Object.entries(val)) { + if (SENSITIVE_KEYS.includes(k)) { + out[k] = '' + } else { + out[k] = redact(v) + } + } + return out + } + return val + } + + return redact(input) +} + +// Create a context-aware logger for a server +function getServerLogger(server: MCPServer, extra?: Record) { + const base = { + serverName: server?.name, + serverId: server?.id, + baseUrl: server?.baseUrl, + type: server?.type || (server?.command ? 'stdio' : server?.baseUrl ? 
'http' : 'inmemory') + } + return loggerService.withContext('MCPService', { ...base, ...(extra || {}) }) +} + /** * Higher-order function to add caching capability to any async function * @param fn The original function to be wrapped with caching @@ -74,15 +113,17 @@ function withCache( const cacheKey = getCacheKey(...args) if (CacheService.has(cacheKey)) { - logger.debug(`${logPrefix} loaded from cache`) + logger.debug(`${logPrefix} loaded from cache`, { cacheKey }) const cachedData = CacheService.get(cacheKey) if (cachedData) { return cachedData } } + const start = Date.now() const result = await fn(...args) CacheService.set(cacheKey, result, ttl) + logger.debug(`${logPrefix} cached`, { cacheKey, ttlMs: ttl, durationMs: Date.now() - start }) return result } } @@ -128,6 +169,7 @@ class McpService { // If there's a pending initialization, wait for it const pendingClient = this.pendingClients.get(serverKey) if (pendingClient) { + getServerLogger(server).silly(`Waiting for pending client initialization`) return pendingClient } @@ -136,8 +178,11 @@ class McpService { if (existingClient) { try { // Check if the existing client is still connected - const pingResult = await existingClient.ping() - logger.debug(`Ping result for ${server.name}:`, pingResult) + const pingResult = await existingClient.ping({ + // add short timeout to prevent hanging + timeout: 1000 + }) + getServerLogger(server).debug(`Ping result`, { ok: !!pingResult }) // If the ping fails, remove the client from the cache // and create a new one if (!pingResult) { @@ -146,7 +191,7 @@ class McpService { return existingClient } } catch (error: any) { - logger.error(`Error pinging server ${server.name}:`, error?.message) + getServerLogger(server).error(`Error pinging server`, error as Error) this.clients.delete(serverKey) } } @@ -172,15 +217,15 @@ class McpService { > => { // Create appropriate transport based on configuration if (isBuiltinMCPServer(server) && server.name !== 
BuiltinMCPServerNames.mcpAutoInstall) { - logger.debug(`Using in-memory transport for server: ${server.name}`) + getServerLogger(server).debug(`Using in-memory transport`) const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() // start the in-memory server with the given name and environment variables const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {}) try { await inMemoryServer.connect(serverTransport) - logger.debug(`In-memory server started: ${server.name}`) + getServerLogger(server).debug(`In-memory server started`) } catch (error: Error | any) { - logger.error(`Error starting in-memory server: ${error}`) + getServerLogger(server).error(`Error starting in-memory server`, error as Error) throw new Error(`Failed to start in-memory server: ${error.message}`) } // set the client transport to the client @@ -193,7 +238,10 @@ class McpService { }, authProvider } - logger.debug(`StreamableHTTPClientTransport options:`, options) + // redact headers before logging + getServerLogger(server).debug(`StreamableHTTPClientTransport options`, { + options: redactSensitive(options) + }) return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) } else if (server.type === 'sse') { const options: SSEClientTransportOptions = { @@ -209,7 +257,7 @@ class McpService { headers['Authorization'] = `Bearer ${tokens.access_token}` } } catch (error) { - logger.error('Failed to fetch tokens:', error as Error) + getServerLogger(server).error('Failed to fetch tokens:', error as Error) } } @@ -239,15 +287,18 @@ class McpService { ...server.env, ...resolvedConfig.env } - logger.debug(`Using resolved DXT config - command: ${cmd}, args: ${args?.join(' ')}`) + getServerLogger(server).debug(`Using resolved DXT config`, { + command: cmd, + args + }) } else { - logger.warn(`Failed to resolve DXT config for ${server.name}, falling back to manifest values`) + getServerLogger(server).warn(`Failed to resolve DXT config, falling back 
to manifest values`) } } if (server.command === 'npx') { cmd = await getBinaryPath('bun') - logger.debug(`Using command: ${cmd}`) + getServerLogger(server).debug(`Using command`, { command: cmd }) // add -x to args if args exist if (args && args.length > 0) { @@ -282,7 +333,7 @@ class McpService { } } - logger.debug(`Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) + getServerLogger(server).debug(`Starting server`, { command: cmd, args }) // Logger.info(`[MCP] Environment variables for server:`, server.env) const loginShellEnv = await this.getLoginShellEnv() @@ -304,12 +355,14 @@ class McpService { // For DXT servers, set the working directory to the extracted path if (server.dxtPath) { transportOptions.cwd = server.dxtPath - logger.debug(`Setting working directory for DXT server: ${server.dxtPath}`) + getServerLogger(server).debug(`Setting working directory for DXT server`, { + cwd: server.dxtPath + }) } const stdioTransport = new StdioClientTransport(transportOptions) stdioTransport.stderr?.on('data', (data) => - logger.debug(`Stdio stderr for server: ${server.name}` + data.toString()) + getServerLogger(server).debug(`Stdio stderr`, { data: data.toString() }) ) return stdioTransport } else { @@ -318,7 +371,7 @@ class McpService { } const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { - logger.debug(`Starting OAuth flow for server: ${server.name}`) + getServerLogger(server).debug(`Starting OAuth flow`) // Create an event emitter for the OAuth callback const events = new EventEmitter() @@ -331,27 +384,27 @@ class McpService { // Set a timeout to close the callback server const timeoutId = setTimeout(() => { - logger.warn(`OAuth flow timed out for server: ${server.name}`) + getServerLogger(server).warn(`OAuth flow timed out`) callbackServer.close() }, 300000) // 5 minutes timeout try { // Wait for the authorization code const authCode = await callbackServer.waitForAuthCode() - 
logger.debug(`Received auth code: ${authCode}`) + getServerLogger(server).debug(`Received auth code`) // Complete the OAuth flow await transport.finishAuth(authCode) - logger.debug(`OAuth flow completed for server: ${server.name}`) + getServerLogger(server).debug(`OAuth flow completed`) const newTransport = await initTransport() // Try to connect again await client.connect(newTransport) - logger.debug(`Successfully authenticated with server: ${server.name}`) + getServerLogger(server).debug(`Successfully authenticated`) } catch (oauthError) { - logger.error(`OAuth authentication failed for server ${server.name}:`, oauthError as Error) + getServerLogger(server).error(`OAuth authentication failed`, oauthError as Error) throw new Error( `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` ) @@ -390,7 +443,7 @@ class McpService { logger.debug(`Activated server: ${server.name}`) return client } catch (error: any) { - logger.error(`Error activating server ${server.name}:`, error?.message) + getServerLogger(server).error(`Error activating server`, error as Error) throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) } } finally { @@ -450,9 +503,9 @@ class McpService { logger.debug(`Message from server ${server.name}:`, notification.params) }) - logger.debug(`Set up notification handlers for server: ${server.name}`) + getServerLogger(server).debug(`Set up notification handlers`) } catch (error) { - logger.error(`Failed to set up notification handlers for server ${server.name}:`, error as Error) + getServerLogger(server).error(`Failed to set up notification handlers`, error as Error) } } @@ -470,7 +523,7 @@ class McpService { CacheService.remove(`mcp:list_tool:${serverKey}`) CacheService.remove(`mcp:list_prompts:${serverKey}`) CacheService.remove(`mcp:list_resources:${serverKey}`) - logger.debug(`Cleared all caches for server: ${serverKey}`) + logger.debug(`Cleared all caches for server`, { 
serverKey }) } async closeClient(serverKey: string) { @@ -478,18 +531,18 @@ class McpService { if (client) { // Remove the client from the cache await client.close() - logger.debug(`Closed server: ${serverKey}`) + logger.debug(`Closed server`, { serverKey }) this.clients.delete(serverKey) // Clear all caches for this server this.clearServerCache(serverKey) } else { - logger.warn(`No client found for server: ${serverKey}`) + logger.warn(`No client found for server`, { serverKey }) } } async stopServer(_: Electron.IpcMainInvokeEvent, server: MCPServer) { const serverKey = this.getServerKey(server) - logger.debug(`Stopping server: ${server.name}`) + getServerLogger(server).debug(`Stopping server`) await this.closeClient(serverKey) } @@ -505,16 +558,16 @@ class McpService { try { const cleaned = this.dxtService.cleanupDxtServer(server.name) if (cleaned) { - logger.debug(`Cleaned up DXT server directory for: ${server.name}`) + getServerLogger(server).debug(`Cleaned up DXT server directory`) } } catch (error) { - logger.error(`Failed to cleanup DXT server: ${server.name}`, error as Error) + getServerLogger(server).error(`Failed to cleanup DXT server`, error as Error) } } } async restartServer(_: Electron.IpcMainInvokeEvent, server: MCPServer) { - logger.debug(`Restarting server: ${server.name}`) + getServerLogger(server).debug(`Restarting server`) const serverKey = this.getServerKey(server) await this.closeClient(serverKey) // Clear cache before restarting to ensure fresh data @@ -527,7 +580,7 @@ class McpService { try { await this.closeClient(key) } catch (error: any) { - logger.error(`Failed to close client: ${error?.message}`) + logger.error(`Failed to close client`, error as Error) } } } @@ -536,9 +589,9 @@ class McpService { * Check connectivity for an MCP server */ public async checkMcpConnectivity(_: Electron.IpcMainInvokeEvent, server: MCPServer): Promise { - logger.debug(`Checking connectivity for server: ${server.name}`) + 
getServerLogger(server).debug(`Checking connectivity`) try { - logger.debug(`About to call initClient for server: ${server.name}`, { hasInitClient: !!this.initClient }) + getServerLogger(server).debug(`About to call initClient`, { hasInitClient: !!this.initClient }) if (!this.initClient) { throw new Error('initClient method is not available') @@ -547,10 +600,10 @@ class McpService { const client = await this.initClient(server) // Attempt to list tools as a way to check connectivity await client.listTools() - logger.debug(`Connectivity check successful for server: ${server.name}`) + getServerLogger(server).debug(`Connectivity check successful`) return true } catch (error) { - logger.error(`Connectivity check failed for server: ${server.name}`, error as Error) + getServerLogger(server).error(`Connectivity check failed`, error as Error) // Close the client if connectivity check fails to ensure a clean state for the next attempt const serverKey = this.getServerKey(server) await this.closeClient(serverKey) @@ -559,9 +612,8 @@ class McpService { } private async listToolsImpl(server: MCPServer): Promise { - logger.debug(`Listing tools for server: ${server.name}`) + getServerLogger(server).debug(`Listing tools`) const client = await this.initClient(server) - logger.debug(`Client for server: ${server.name}`, client) try { const { tools } = await client.listTools() const serverTools: MCPTool[] = [] @@ -577,7 +629,7 @@ class McpService { }) return serverTools } catch (error: any) { - logger.error(`Failed to list tools for server: ${server.name}`, error?.message) + getServerLogger(server).error(`Failed to list tools`, error as Error) return [] } } @@ -614,12 +666,16 @@ class McpService { const callToolFunc = async ({ server, name, args }: CallToolArgs) => { try { - logger.debug(`Calling: ${server.name} ${name} ${JSON.stringify(args)} callId: ${toolCallId}`, server) + getServerLogger(server, { tool: name, callId: toolCallId }).debug(`Calling tool`, { + args: 
redactSensitive(args) + }) if (typeof args === 'string') { try { args = JSON.parse(args) } catch (e) { - logger.error('args parse error', args) + getServerLogger(server, { tool: name, callId: toolCallId }).error('args parse error', e as Error, { + args + }) } if (args === '') { args = {} @@ -628,8 +684,9 @@ class McpService { const client = await this.initClient(server) const result = await client.callTool({ name, arguments: args }, undefined, { onprogress: (process) => { - logger.debug(`Progress: ${process.progress / (process.total || 1)}`) - logger.debug(`Progress notification received for server: ${server.name}`, process) + getServerLogger(server, { tool: name, callId: toolCallId }).debug(`Progress`, { + ratio: process.progress / (process.total || 1) + }) const mainWindow = windowService.getMainWindow() if (mainWindow) { mainWindow.webContents.send('mcp-progress', process.progress / (process.total || 1)) @@ -644,7 +701,7 @@ class McpService { }) return result as MCPCallToolResponse } catch (error) { - logger.error(`Error calling tool ${name} on ${server.name}:`, error as Error) + getServerLogger(server, { tool: name, callId: toolCallId }).error(`Error calling tool`, error as Error) throw error } finally { this.activeToolCalls.delete(toolCallId) @@ -668,7 +725,7 @@ class McpService { */ private async listPromptsImpl(server: MCPServer): Promise { const client = await this.initClient(server) - logger.debug(`Listing prompts for server: ${server.name}`) + getServerLogger(server).debug(`Listing prompts`) try { const { prompts } = await client.listPrompts() return prompts.map((prompt: any) => ({ @@ -680,7 +737,7 @@ class McpService { } catch (error: any) { // -32601 is the code for the method not found if (error?.code !== -32601) { - logger.error(`Failed to list prompts for server: ${server.name}`, error?.message) + getServerLogger(server).error(`Failed to list prompts`, error as Error) } return [] } @@ -749,7 +806,7 @@ class McpService { } catch (error: any) { // 
-32601 is the code for the method not found if (error?.code !== -32601) { - logger.error(`Failed to list resources for server: ${server.name}`, error?.message) + getServerLogger(server).error(`Failed to list resources`, error as Error) } return [] } @@ -775,7 +832,7 @@ class McpService { * Get a specific resource from an MCP server (implementation) */ private async getResourceImpl(server: MCPServer, uri: string): Promise { - logger.debug(`Getting resource ${uri} from server: ${server.name}`) + getServerLogger(server, { uri }).debug(`Getting resource`) const client = await this.initClient(server) try { const result = await client.readResource({ uri: uri }) @@ -793,7 +850,7 @@ class McpService { contents: contents } } catch (error: Error | any) { - logger.error(`Failed to get resource ${uri} from server: ${server.name}`, error.message) + getServerLogger(server, { uri }).error(`Failed to get resource`, error as Error) throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`) } } @@ -838,10 +895,10 @@ class McpService { if (activeToolCall) { activeToolCall.abort() this.activeToolCalls.delete(callId) - logger.debug(`Aborted tool call: ${callId}`) + logger.debug(`Aborted tool call`, { callId }) return true } else { - logger.warn(`No active tool call found for callId: ${callId}`) + logger.warn(`No active tool call found for callId`, { callId }) return false } } @@ -851,22 +908,22 @@ class McpService { */ public async getServerVersion(_: Electron.IpcMainInvokeEvent, server: MCPServer): Promise { try { - logger.debug(`Getting server version for: ${server.name}`) + getServerLogger(server).debug(`Getting server version`) const client = await this.initClient(server) // Try to get server information which may include version const serverInfo = client.getServerVersion() - logger.debug(`Server info for ${server.name}:`, serverInfo) + getServerLogger(server).debug(`Server info`, redactSensitive(serverInfo)) if (serverInfo && serverInfo.version) 
{ - logger.debug(`Server version for ${server.name}: ${serverInfo.version}`) + getServerLogger(server).debug(`Server version`, { version: serverInfo.version }) return serverInfo.version } - logger.warn(`No version information available for server: ${server.name}`) + getServerLogger(server).warn(`No version information available`) return null } catch (error: any) { - logger.error(`Failed to get server version for ${server.name}:`, error?.message) + getServerLogger(server).error(`Failed to get server version`, error as Error) return null } } diff --git a/src/renderer/src/pages/settings/MCPSettings/McpServersList.tsx b/src/renderer/src/pages/settings/MCPSettings/McpServersList.tsx index bb78b49db7..3bf7d619a1 100644 --- a/src/renderer/src/pages/settings/MCPSettings/McpServersList.tsx +++ b/src/renderer/src/pages/settings/MCPSettings/McpServersList.tsx @@ -1,3 +1,4 @@ +import { loggerService } from '@logger' import { nanoid } from '@reduxjs/toolkit' import CollapsibleSearchBar from '@renderer/components/CollapsibleSearchBar' import { Sortable, useDndReorder } from '@renderer/components/dnd' @@ -23,6 +24,8 @@ import McpMarketList from './McpMarketList' import McpServerCard from './McpServerCard' import SyncServersPopup from './SyncServersPopup' +const logger = loggerService.withContext('McpServersList') + const McpServersList: FC = () => { const { mcpServers, addMCPServer, deleteMCPServer, updateMcpServers, updateMCPServer } = useMCPServers() const { t } = useTranslation() @@ -158,12 +161,11 @@ const McpServersList: FC = () => { const handleToggleActive = async (server: MCPServer, active: boolean) => { setLoadingServerIds((prev) => new Set(prev).add(server.id)) const oldActiveState = server.isActive - + logger.silly('toggle activate', { serverId: server.id, active }) try { if (active) { - await window.api.mcp.listTools(server) // Fetch version when server is activated - fetchServerVersion({ ...server, isActive: active }) + await fetchServerVersion({ ...server, isActive: 
active }) } else { await window.api.mcp.stopServer(server) // Clear version when server is deactivated @@ -259,7 +261,7 @@ const McpServersList: FC = () => { server={server} version={serverVersions[server.id]} isLoading={loadingServerIds.has(server.id)} - onToggle={(active) => handleToggleActive(server, active)} + onToggle={async (active) => await handleToggleActive(server, active)} onDelete={() => onDeleteMcpServer(server)} onEdit={() => navigate(`/settings/mcp/settings/${encodeURIComponent(server.id)}`)} onOpenUrl={(url) => window.open(url, '_blank')} From 7de31d8cb6443ee74a4e4a638276b28d493d4e44 Mon Sep 17 00:00:00 2001 From: Lin Manhui Date: Thu, 4 Sep 2025 17:13:58 +0800 Subject: [PATCH 2/3] feat: Add PaddleOCR as a new OCR provider (#9876) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support PaddleOCR as an OCR provider * style: fix format * fix: update persistReducer version * update wrt comments * fix(ocr): 修复迁移147中OCR提供商的设置错误 将直接赋值改为使用addOcrProvider方法添加内置PaddleOCR提供商,确保正确初始化OCR服务 * Replace bare fetch with net.fetch * Use '\n' as delimiter * Optimize code wrt comments * Add tip --------- Co-authored-by: icarus --- src/main/services/ocr/OcrService.ts | 3 + src/main/services/ocr/builtin/PpocrService.ts | 100 ++++++++++++++++++ .../src/assets/images/providers/paddleocr.png | Bin 0 -> 16805 bytes src/renderer/src/config/ocr.ts | 16 ++- src/renderer/src/hooks/useOcrProvider.tsx | 3 + src/renderer/src/i18n/label.ts | 4 +- src/renderer/src/i18n/locales/en-us.json | 7 ++ src/renderer/src/i18n/locales/ja-jp.json | 7 ++ src/renderer/src/i18n/locales/ru-ru.json | 7 ++ src/renderer/src/i18n/locales/zh-cn.json | 7 ++ src/renderer/src/i18n/locales/zh-tw.json | 7 ++ src/renderer/src/i18n/translate/el-gr.json | 7 ++ src/renderer/src/i18n/translate/es-es.json | 7 ++ src/renderer/src/i18n/translate/fr-fr.json | 7 ++ src/renderer/src/i18n/translate/pt-pt.json | 7 ++ .../DocProcessSettings/OcrPpocrSettings.tsx | 83 
+++++++++++++++ .../OcrProviderSettings.tsx | 3 + src/renderer/src/store/index.ts | 2 +- src/renderer/src/store/migrate.ts | 9 ++ src/renderer/src/types/ocr.ts | 22 +++- 20 files changed, 303 insertions(+), 5 deletions(-) create mode 100644 src/main/services/ocr/builtin/PpocrService.ts create mode 100644 src/renderer/src/assets/images/providers/paddleocr.png create mode 100644 src/renderer/src/pages/settings/DocProcessSettings/OcrPpocrSettings.tsx diff --git a/src/main/services/ocr/OcrService.ts b/src/main/services/ocr/OcrService.ts index dfd796346f..471d31edce 100644 --- a/src/main/services/ocr/OcrService.ts +++ b/src/main/services/ocr/OcrService.ts @@ -2,6 +2,7 @@ import { loggerService } from '@logger' import { isLinux } from '@main/constant' import { BuiltinOcrProviderIds, OcrHandler, OcrProvider, OcrResult, SupportedOcrFile } from '@types' +import { ppocrService } from './builtin/PpocrService' import { systemOcrService } from './builtin/SystemOcrService' import { tesseractService } from './builtin/TesseractService' @@ -36,3 +37,5 @@ export const ocrService = new OcrService() ocrService.register(BuiltinOcrProviderIds.tesseract, tesseractService.ocr.bind(tesseractService)) !isLinux && ocrService.register(BuiltinOcrProviderIds.system, systemOcrService.ocr.bind(systemOcrService)) + +ocrService.register(BuiltinOcrProviderIds.paddleocr, ppocrService.ocr.bind(ppocrService)) diff --git a/src/main/services/ocr/builtin/PpocrService.ts b/src/main/services/ocr/builtin/PpocrService.ts new file mode 100644 index 0000000000..2079f2d6b8 --- /dev/null +++ b/src/main/services/ocr/builtin/PpocrService.ts @@ -0,0 +1,100 @@ +import { loadOcrImage } from '@main/utils/ocr' +import { ImageFileMetadata, isImageFileMetadata, OcrPpocrConfig, OcrResult, SupportedOcrFile } from '@types' +import { net } from 'electron' +import { z } from 'zod' + +import { OcrBaseService } from './OcrBaseService' + +enum FileType { + PDF = 0, + Image = 1 +} + +// API Reference: 
https://www.paddleocr.ai/latest/version3.x/pipeline_usage/OCR.html#3 +interface OcrPayload { + file: string + fileType?: FileType | null + useDocOrientationClassify?: boolean | null + useDocUnwarping?: boolean | null + useTextlineOrientation?: boolean | null + textDetLimitSideLen?: number | null + textDetLimitType?: string | null + textDetThresh?: number | null + textDetBoxThresh?: number | null + textDetUnclipRatio?: number | null + textRecScoreThresh?: number | null + visualize?: boolean | null +} + +const OcrResponseSchema = z.object({ + result: z.object({ + ocrResults: z.array( + z.object({ + prunedResult: z.object({ + rec_texts: z.array(z.string()) + }) + }) + ) + }) +}) + +export class PpocrService extends OcrBaseService { + public ocr = async (file: SupportedOcrFile, options?: OcrPpocrConfig): Promise => { + if (!isImageFileMetadata(file)) { + throw new Error('Only image files are supported currently') + } + if (!options) { + throw new Error('config is required') + } + return this.imageOcr(file, options) + } + + private async imageOcr(file: ImageFileMetadata, options: OcrPpocrConfig): Promise { + if (!options.apiUrl) { + throw new Error('API URL is required') + } + const apiUrl = options.apiUrl + + const buffer = await loadOcrImage(file) + const base64 = buffer.toString('base64') + const payload = { + file: base64, + fileType: FileType.Image, + useDocOrientationClassify: false, + useDocUnwarping: false, + visualize: false + } satisfies OcrPayload + + const headers: Record = { + 'Content-Type': 'application/json' + } + + if (options.accessToken) { + headers['Authorization'] = `token ${options.accessToken}` + } + + try { + const response = await net.fetch(apiUrl, { + method: 'POST', + headers, + body: JSON.stringify(payload) + }) + + if (!response.ok) { + const text = await response.text() + throw new Error(`OCR service error: ${response.status} ${response.statusText} - ${text}`) + } + + const data = await response.json() + + const validatedResponse = 
OcrResponseSchema.parse(data) + const recTexts = validatedResponse.result.ocrResults[0].prunedResult.rec_texts + + return { text: recTexts.join('\n') } + } catch (error: any) { + throw new Error(`OCR service error: ${error.message}`) + } + } +} + +export const ppocrService = new PpocrService() diff --git a/src/renderer/src/assets/images/providers/paddleocr.png b/src/renderer/src/assets/images/providers/paddleocr.png new file mode 100644 index 0000000000000000000000000000000000000000..eec88c8e026eed0bf67bb16eb9fdbeef1cd7d99f GIT binary patch literal 16805 zcmbSyb8sfnvv!;}wrwYGY-eNJwzaYIh8x?)Zfx7OtqnF#HrCzW{chd+|2I>8o@c7M zPxYCZI_f@A%8F7*@c8gxU|>ix(h{owaOyt-fcaMnlgr=z18_H0DKW5`8N#!F4Sjb?03n`EWhoVCTlZiZliv|OqlK0fQe0l<- z48j955QH>_VW&c4UoF-$jt}$>dV3+6M5v03Wbk}a0q0q-`A-^>e%x#p{}|3E zN10YjY>SmwAFQb%5{ZV?PqBv#hO^<62PA<=_I`qi8^;c}H5@B^>ZX#-1%_HHexjjp;QujzNpKH<9jdG{1ZqDhF*~wP}4y%`4B( z1IOPxIV~kC(a4-*qhEEO=hp^uMrBMKv+ST!FKzY@%@ILjkBxcyGM=@a{AH8Jy{1fm zb6WBk6jBQYx%D93GQ2Or8qbJ!zkw*b_#ESswz z{0HNHc9>R9I~~n--etygM2s!Jbr$}Y%jjF#n7Tv;T~PA0mV&~pvF9F%^xgI`ryKD)qgmB>*U>v& zVcWaK1!F3O6@O^3iIOSe7&S2`Z3>qin#n>4RVNCGul(c>@nv7$#)=QMkz0RKIQ*0> z-%(+vDx;jrR#?H2#5qKp$+GTpJGNk5G!wSqyxE4!%gZF*3bZkFWWN0@UTXyhTk8%) zXc}_~=rnx%4lN)yvZ8a6PZoa7P=~yTmPCep!AWD);TQlp--+Uv*4|P5MOO;SkhAkg zXoxs2_dr0nTd<375D>N)*4L1tdF+%>TLRZ@jv$?u1qCe!Qf<^cG@;g|Wo-V@azw~% zE8?BGpdFQ}lM}DoJfW>F$VFy{(VFuc2o9b%)HRgN)TBMRI$8ABQWb4C>gejUg2FU3 zrxOy=B^fHh1tQ)6N@ol@d8h;NUHg&^ozOTJR8G%+HcuV%Se2nuy}FJ-LtFdHD4v3k z7B&0bPy@DQp`lpYKsFm$xVZ5*@IFL1O{Xo>;6S2 zFt8WF3dpl$Q4?^GWpUK;uP@Niwb|=uF`rQF`!y8~@pp*vJFiuijh&se2=u*V*$suC znriNZx$!s9IFt{y1*h*YrLV1_wJUJ<%J@M0Pp6m{y3;Uo^1KLdvOLd4sXYC^ADMcN zofctu(?oj`KP4kWgPcq9;Qg1{10=>S=71A-e-8tx#;kq1inG~@6??E2HG76kC$yF( zzy();2e}sfTk(;}&!zbg$gQy;rTm{)LFs?)>u3<5AZEltj4WddH_IYeF_`cxKGaOu z<~q~4!TX1qxn@HOL)$)Or2X7%*40b#AsfkwC-jlah1LhElKk%f`cMmKSB0vSl}Arr znw9~9{kp()*KM3BxtV?k%d=7iNcU;+35UG99uh>jscz30EilS4#9dx`v6UiPSAm*T>Tyq*z^c*^3h( zx?0(-J7=sA1mUTNCIQV9J;=S}Ai2AE4pf;ZiminwMWWhq=@R+LqDUY+h%i~B^7$rk 
zklAgxRu2YEZ3ymla^{f0m!`5YA&-Xe%`HfZeTG6vL7$cf&n|XzMXLe)-&Zz=jE4E1 zBcS>gop2YF;vAnz?ZN`85*ok|4-`^Z3Gwc>oDUbWKJNvXPz=lGQC~HAE;vbE{O&yy z2!O>L&HPUd$l{0iFq2z>Dvz$X$dJ-AD`aYX2xdzq|%d-IN^^r_^XgDYp z%)zN<7I2duzlP@H3qlW;lkA<=TF={{!z~1(4~w-*EthJ1%vNKS{TyRH57q-+gV8vR zK^P|AUj^dvmMAg(wtOz8+YkRnx?Yt5oksLq*HD%5w&Sz?8Fz$kXJQ!IBr;3Y8h8TzTt&;hL}vxw`=tIVWPq`N>jX*ECqj%{1$;X02Gmb2 zA3)ZKeH`tLR+j>~jreL9H%8$hdq9=@xSB{u{-`Q1L2_3(1LPj6!^P#bw^<~GQ&lRH zE+3r_;`aufeCd=D&&fOAn&`v1Wis}uPb9C;tGg(SBA?riCc{L$+0adF@q25cejnPS z0aiOvJ?(d^U&M&3?Zs662yBc}wq*PV5MVj1adY~yyL50GhL6K@Px74jTOMzQDjGk0 z4apH(Qt+RhwUPs&Xy2hSVhe;y8>il!q5Y(W$lzxO_J%7 z{oKcwr&@hEE3UI=X?(k#KeRM|{>~`Nl^s>_?Lcsn+ITXSGAOWnG7sA;in~=NT6}y< zaaQ8wc>Juu==MHZx@X6^NYGC7P8~bI*&5Il`97c0n9XJzA5EoP!}NRE9i?P3h_Q`q z-KnVMl%zh4n6nJ4UflZADmqFj31)M+af|cKG)jFc2mQ85fJSLuLwr>JLqm z@}CH*u-oX>4v#T1I@h1Fh1Y{?>@E&_@Hru7M$r6(0Rqg=y0e3U} zp#I$aw%Oe4Q3wtKg7I+<0uQghZ6tkE_c6*6)xrs-=SQSXhs}WF`+f=9M+F$0{I4E! z8k!PgPg43xnEQ>sIGfJh*wM|{31$O)Y*|6DJ*yBPg~CCOEF)Lzj!>0cuT5&y_dWMU zp9`R#x@1B|_Ee4OUIb@7$I$(=BI*yZwLwP-^dsU67EVr9oz~9#YOxdxKk>UuY#F{1 zh9KT2d!fbq<;_{E69RxD(im96ia@_coA>VtKE9F{C9MOK3u_HpV-BNaZV}LIDr((@ zc~YRq!yjpE)LRd`dt?*J)dzi0OsnbC3eYh$(AKzM@(Edj6n{UMEa;p4KDtv~#Ta4yb0j&sbR~Ts zBg-9>W3>`-?wRZT05wuHC>=o00shM zSCGO1e1sxk9BHys2JO7_2oq0%jQyL2+gC+*Yb$)Rzf;{k6h!zW14}(kn-e@9K@3J< zS2S`bdt#T1b0y-B+o$x%&scW7+fSmL$ zc=_npDLJT!)x+J|h?MRrt3tZQX$Da6j0?lz{(}s*&4i1RmyDIeEjoM51u;DnLs$r# zh%ZJ~{sGNdRKeWlOw~ku9i)%s2|{AGvj;;}7vK~x-l!c^4AXTe_4{-NF+~-0AtpptrkkHH zI}EHIGg9Q!Jp;&MEaS5p&4vOevvpDM85muFUm)UgG$q)B3xP#?G%Wp3Ll0ZIY72ig zPt4BWdLdN?Y%eaQetgX{ zKwz*xd!HLrrl(40WDXXb^P~pOb#)t?Q8yA#tr{44H@qqI+e)WpaRb!I*9L*eh+qgT z)U_|u>dc$NDM)8=XAlT;-)aF7AM14>@Ifnl{Edmw6*e8-Rs0>u&X#dQrPmw)a^#3y zZbXl?W>U}Mv4IsaX2XK#=M;OA4$TX4WbZ^tjdutzORwRN)fn@jJGi(NBspEL1-!&5 z@*h^{#z^3j_}VGFT|7oj&=Jj~jBT6vGd5}*hvJ4mqga_CFm;bi`FS+uT<>>kZvTNF z!?)ix9V5i1(2%9x!9dPfgH)iYJ)*M-?Nm#6&4B!S0^DC?l`d>UQqL^%av6f;KsBIOu8k23LOpS1MJ*`xE-AE*@G0(-E1i7<;Anpoq9L2 
zXb`nv9I0K8{g;*gI?zjp?GuDK$Vk>}xeA-=HeI4`w43#9js9Ai-tVJtrpYMzX{6EF z@^$8}9~k34jX}3_CuEboieU1Mli~}f80yKw<7T9e-C8}G%?&NB2>u5S>CiUiA$n$R z{Ocg9Y~oJ>1VNh#C3P$mNE-Cy1`Cnp+#3!*UV^;vg{ZDf7aUmJ5RtbjYeHTI3B@$+VtZ35>O>iTzuW)5Dcjd z|8?taX-JA_{OZ<^P3E}z`_;m(kJYrS9ofiy>oa&k{JVaX8jtj2A1TOhM{p@<*Rpz{ z1nRFuS%!{$lE*%|!47~5zYFkI4|vzxi^;t@Dq3mdj1;y$OS(QV_L{D0 zWBM=8VgvL)SoK4@48*kwgW}~IoA@L3?#-J<22A396a$=lZQ!BaWX6J%lc31m*pEVF z`c7_f@@CpIe;YP~V(;-}Lyr9{rX0w%v-$F1;pR1HWn-ocCZQN;{zhM9XO1LnLgAiJ zkH5XRXDqv2jpsV3%uq(scOdk6j>zirJo?yW`(v*fAp|@y%!%qJHKk zlQmzbigTu0A^;xvm1-h>O|lFOKdL-GUQ>O5)l-$hbz*r`^ncMEga*d(17m65*INE~ ze6KhJS_r77K$kyl$@u32i0#9>4PtENf ztdt|1&LBOO{1mgSRWa1yD9mpq5mx|E4|-_jwN6}Qb1wF7N^kRjghdd8j0QhfLdLmF zVWr-(nt*e$jylVKS6Sn+W$$fn7vdpNM=mlvZK^$bTfA<$q8Zuj&i*E5rGxcC0xe!C zoV!NJp5TX)d`@I%JeF5H#A^41e{&9aT~+hEhip8P+)69LdlqiZ88tU?WF>nZuBak3 z5p6p2wAiEST78IjJrKUBxhZ@?Awb(hZ}wkGNXp7v*98z)Lutc8AoIV9xi!|`S>B1h z_(O8XIX7K2R37&J`1J)7U7C(BuS1eyPQ9zHqT<-p2m`j1(w(uKuF4GaeGhz!H;B>p z^m_ynh5L`MOPddybNg!Au7!O?o4J7HE1&Ol+F}>VP22cX zxHP+66=BetD^p*?KU2rq2B8m!+z@I|4q`$XjY8!{Gcjd;n_D;uMx;n;GDpgdj*+t0 zqQRs!BZM&ylWq7#h(;a$wy2r+@IzS9d0ki-xMaKSt`$4~@cem)L;(^7zepNbz$Vg% zERCcMrNz^XIwnBuqsdd0K9?;BZ&d$@0|B11-vizMH$8}B_>Nu4(^20_tyB&bhsaq?e+Mok<%bBM8B3Y;&<+(YX+vRd55SU8E4< zgW7VF^oNbB@*>#@Fm@z2_n7QG*#&N-s#HyUJ+rLdKY$dEO2Arxv*qfQ@@Qi0wA(;n zd9Gs3rqRwC48)C~+9Lo~OXAwOPA!8<8&2Jo*uPPSZUX68Dg}AWt826CVk$M^aSts+ zl87zDNcB<;JFRdzz&EWG0BJ$$P|8!6r80{4(eL>VVa;_f%KB7$nxCi`hS`#}E(xoJ zuQt}QBdG~1x)%}Vjj!%MvtD2Cc6iCA`_*jaFZ4K^vTJkc<&&eKW=U|*7Kj#}ZV$^D zEZl}SFt_Fs;_Zrv=z3fUK=MN&zay|zUpO4(&P0Mpvj$enY_w#if)=n2o(owrQnPZm zEDkobOwuiF-9XPR^H>4C8UNgqc!Ab^ErVmsVagQGy3CPtfE8{2en6>1_N-)HZn+)& z-Vii%HcJBj#&-jfPupanj$x`k~gM>4Ud9)o!9-DJR^vueF7Jh&%XSd%i zN=;}Y!GhBHniCi_rqh z(KHV76=Kw9^xEE$<-;->hih-1(-h~8c? 
znveq}>IocV}cG)-ZjS8=#fAT%Wosy39X5Hc3zOEqQLKEh(Z4&khxV@ip z{=6pRx%Dpz%~X$Z__I+zr?h&HXo4Im2f5rvutxL2lM@~YzK7|1l!AEz7i6}-{-oN*9&-VFZYL>DLX17{Y3GPU*-F5Fr z1-c3@%?=Kw02*WY0kdad0_5P$GPHy|cylpn@Mp5J0GGxU0%R>q&=6D-DWvALM;KQN z5H+X$Ua@;S^{kg{0TPN7<`60({>wkJ`vhqqg_Ies1TS1NoY>igaPOl{quk9$`9v&E zxGysmb%b;t2};1-qXJzFIaF%ZlbmrueSM*e;w%GyWq%)Mv1@8tDss*|1XyAM(fNlV zLu!W=ikOvBMJBqt~4P#!yRG|3sQJdLZr^bAXg9pJqR??XWE80IOU< zle_`^Q~dbAYYFD*gxT+?>0^6Mw6nsNuH@fZ(PcEPPqkfZ+<2YKE}yVI~MfliE?99$}wu1K%BQu7Hgz%j`Z2$ee~rU~?3Y;a2# zek7)CMFLcgx)a||LYxQ;R2l6O1{*?%5*LKA%2NoD>;f6N%Gc2{4h-2>p|hOdZ0{+_ z^hiQMh%=@rP`4!H#flj*M_BJ*oll9!JDw1ZaM-W)yYIa5z_ez3Q+$}PY#WFGo6gC! zTemW*6Z$~^L%bNZD(=ORMa&9E=SmwCno9!%Ph%WiQy&kqE3Gr_ud`xKZxd%3H_hHR z#;e`L^O%unNEAmZURJz!>YA_|sSN3FGMwouM&T}Iuu`v((v~RS9I2@-G)9hR!>;vz zpomjjU`Hh_vFPXIRedp|thQ~yyiajihmc_JtkP6+KjjU&1m7!V#9<7Lm@qze?A{&; z4A5S^lRIRU!Kr6mP4lm;v{OgG+A-0pQu894W04WG(7n1en)@zDN_RFS#W4NG z=(*n3o;AX*(SjKmBa~j~th-{{AnVdCFYrgbMd^`v=s35a=ra|)7)$gX={{&@KEM54 zCP1i-?^Zm>D0SZ57hTdr9F+DKK$4Fjwe*{NFTEoBIu*2=- zv?0lL$Ia0c2K%ruofHv>@d`PV&ZS3V*e5)_xl;SIo1G(xJN=t0N4y}Abzha#lIoCt zrb8*W;>jZT4~hQz44mYEYP|Sh2L{6CTy*fOm}il;6wNu-&#UZ2aQkST2C*OYf0uT% zQnNm`LJJpE8Do1xgErt z{uwh>vRPz19s0|(dVOJtGIP>+${>G5N8e`hkYKtWH^oktZKSk{AD?Fg8D+!yhqIYv zF>cXwI~qwiw`ND*D?HaC5jnjZvX$8>^v0Lp@1!mcX+I*k?JBl14wZ%dk$5*L%?7RP z(VhB*QqE=AW7j%#kV`AB>qZ4da?XL$aY&R+D1 zGnVd!dEc8W7fY<~-)A*$8nCic&c$pZ#o-la0V6!lcWY@{+?N84tLO(RbFKV^;9jVl zK7!QRnUqO{^v6x`Cgq%Z{Rv#5Qbx$u8LnujX+Eq5JoJo9uC0t6vYS;4s&USBg8$;g z2lPR!N-g#>CcA$;a@Q%qAz^_O>_}R=^>JxPbKWsZdM#$i1KN=8XyD8egWH1Y)xH(Y zpJUHCdwQOE>tl@VM6yL+A;At1@5;?|4B@&mmj$`v6}wyqV$|Oc+ux~;5= zf%AkcgsaA0GohT7B?IqH>sjktMXpb3j(M;q#>nuytQpczOTr0@Fd&0Af@DR-xbB)N z%Hjs&7a+AZfGN>M76u)GRJWDMWC~_iu3B?cHT_~e*OZR)0IhTZhYji@MwI|5mXgV) z()y=GRrQ#h9~@%UZN6ZxH~xf>i(&;gEg#%@!$o+I1SW&FdppjNuW@cIm}}^mSgts0 zKYS}54SQomy`Uh}BCsknAyDlyaft}n5Lnmhp6tP{;jBqY%DzT3F9Mu&KA6y z1@{Lb(B})}h~bR{7$UC>wMJ*giC1{kc<@O6Jlp)l)!N7n?{;K5Etw{H0LzU#?4f1{ 
z-&)UDCdevr5J~#o{s?t~b{7KC|&hhnQ>>QK#>s0wB4GQiT`G#i8y9OXTi10Y@kAars zjd631<>u%I&r47%wiu83rEzXl55KE~#t{Q3xnLWbTMqwtE;^QPn^X{%%i7fwZqomm zta2+Onh^VxQ$0@>Nvqd``j>9=t-|p|69!zDIL-!9+gj)^@aM&SLlej41QHxLwv+>h za1ZM56q6(PA=`1xoxgLX^)ZI~`YTW8M@JiW_0N<(?I5 zYF#V-n}8eRuk^MsAS|+Vee?OugSL6I+e$MdXG|{SkE|zpBB7;~F2>6j=-X4Yz^vW< z+Z@TugP+4yTKtB(NGe$JyBATIV`Pp=F1u*Ow{qy zVk!bvmDC=L0nykBlR=o<-@FUac3KW;Tm{PQ!%-LUA|VcS+R5XHLZFl z55j82dFYpFL5qk)zIxcPq!8%5_maZJCT_;X`1|Dp>{pR2tW<1o9^*HQnui;avU;YWT|ZaoBh9zD#@{=km_ z?pzv@%a-B)r0Vs+O+fyu<0rU%gTsTci{Js7oYxDtsy^5Bem}kv=;v6WB2w*>|H3<{ zC(otra!MtMIcbKEhn;B9=D4-`L1z~Idr1o=V8Xwl`}chW``H=~sPMjY?2 zk-bQsc%)`UC`pV1br(DzT!6SbsZ;X8Af)n*ppZ%{hldD~XRZ;{O4CNp1Lb&P2Y%%O z-A(Z;8Jl-tQe>P1Q-053LScl#a5%C&w6HM6KhJ%v3t*mbSB^*qyt+P67QbP>WZMQ9WB3ST2P_eq7+($_aW5voOkGx|yq6)8{ij{p&Adg4dH=S(?xLi2U1Ls&%|7{e4f}=7r*@pb5I6dxKI?*te z1kGNEGn8^NXHSKPJ~HOZ&Cj2rmuC~o0|}P)-e+-zdlF@TCx9q=LinEGr&Gt(wpfOO zuTm({qOY-8vR1&Y0`X| zXB$-B1KU22X!@scNy`huVOZ!dlz?6U)mlIjCO|8pa+h(w5QlZFD#t_$6(;Eoq=Ddu zE1e{dM0|pqFy4M~aVGP-mz(>sZ_2Pr0_EN=qTf6ePaK%Mmfu36_k|+^P^z{@uxU4l zRCRjDo(5!MbzZgrO1JxYlCLLwx5FZl%5#6axfEvW>8%<{6Pr(wCNn<)1_VR{NY@PP zd327n$ZHXEqI+{|R)OC}u>#?HFj~15+}B+%Gh^4|WNL)9MrvPI@b(bzzlCM{M87=%tdF-Dv|M3=wr7dZQWM7MYlQ7w%;+M7ny;GHEdieIZ{W zMD`g{1-908?38|D&rDnN7UpnM3(f5BuAwkxAXGeC?N{7HEiy5@HEipEGkDcoT=Pa6 ztm|H!cz=tA!`+A_BJP=CQGlANF{ovM#^{75A136!$NS!+3Rik@;^G>Zn1kp^zxx&TM|llv%KRCExxq5Br1L} zdG;c0daFIpREv-6lBlI6d;5fWZH4LSgQSWc1i5&^xCU1s{6b=o#i{+1pVI7!8uCL; zWNq+sH;g#MY}%yEb`U1Kc@9SYN4`FvQcA@F=Kbq0$)IiLPEp=6#?R%J?l3U~XasWF-|*hx+*GL{c+E^mxXYL?|}KYLj`DZBFXA*>V?SXE0b zv~PccfH&v$l6=WMy+;;2tR`p*ei9MV0tCX$rVmLU`Ct(ozyf$U8w6=^Tl)1Sf)jX_2V~d_Nu9is_h~R*VH7 zbr^Va3R^F3A#CIAHZb(ktIfv^nhYi7Jg5*I6)zdH9@6ZqLUo3{T$h^=Z&&v*kSdn3 znOKCS#`-vEpD05_G*5dRBxF_z>kC%j$26Z)&R0r^0F2z+xsF5$)uB ze4egcehXb26=@^*iP5jCjh6?FYM@)=fo|gvis}y41^SLBkgdg{jA24h1-&H7LX`;7 z&!qo%f5rL?tAk~Hcz;DrcWca(^Tl{-KpwdS0z@sK${;J=SGtbs(F50o$Z;GKhB{h6 
zBHkNW)IGKp>Pj?OoByWjHv55br>5mcT}>kDE5sgeNg-C~$m4UP-_ZuvggTapn;lX*(Sp8H;lzb8(Mi-JQ}?C)DLYqq7F&{= zS4rsNZ3})8h(8x(1~d8TsM}gi6aTq+cotcv#W)#Q#Z)!z|LyvVZA3Um)U@zt>Zv-) zL;QwVt>`VZqQ^JtN>U9aSgNv8D-zmVSTzlg)Y2tIwbHU>Y3c^jzzJ8 zR_g}|(ir?<&%BRRwBk#!s=mIxuwnK{VCY{Xvu42cAwHz|rI z--baPc9m^!*iKze-|8`44ww-tM_(d^I`Hjn-JLGEMXsG;3+-74QG+WmMU&~ zJZwSJtjgj9_taf_mB6(prVh~P_&Y?b?rb30+8M=y7>4xf8iWBrCp?6jdIzOPqA>_< z{@tt#{}h8>{8fZ}F>Y+Nc(iGog5u`mJ?qbpw^|M4j*m`<;i9CO?Ncg)jDnN|n)n)D zTaMk+lQcoWe+kn}sEEKmzYgq`-plADWx36M&Q`BuhqhuUN^>*~>cSoJS&b0zBCdyA z+iJy8HhoAnG+IuMCcT;%KY>X1m>HJn_3cnO*u{>OxHj*vm9@BK&c`hlaF!78_7rb@ zG}Ab;xV81TbgmrY>CL)xMmRmqHb7m3e3G%;d#LNt-1tiKY}_%y2dL*lC4rGMD}Dv1 zr}D9Qd6#HfJYNsL^XWX{;)hY=Z1eY(IqK8WpgjU6$G^O?5V7zU3~u6~$mi0tyP>Y{ z0sut&0(tb7KGfYs7i)tbGQVC@(jm^zq}L|_zx2#(7bt&2fa*J53zMq}dV%@IY>Q@i z7Rpu~xlIpJ#Z6@Quv3Z9%Pu=7h%~B0<2D3Yy(lgo3v&yZc}60a{hZi|!+s`HbOVTW ziyCRoUe(^M-58=;LkdU9n$vz6iz#(-;;3_{7KbliHt^oH;KEqG3raJldAFz;Y9KVe zJ3I+z!d(iIUTOU(kuAZqgRfUwA*TjujUFQG>4culJTM!j1DoQ(DT2MJIgw*ble(L0 z$sGN|!o*HG_^xR4b9<2U*%arYeVT*4B0 z$GB(AW^DzHZCSvSgmJ%!WdJXnpfjuX=>>!mXax0ocX-aLfCOUtI9;KheI9L8TIJ4ZkBu2)A&c@lZe_^7RN@1kUP09SOr0RG0-IW?2(` zBD<`Fzd@ShtFNg6VoF#F0I89!;gZk$726&51A!_E$mol}&1+!+?A-?_n$!e>PtBg~ zw13P)AkSjFT>f z2^uU&FO;~mI-x@}#sTXB;Rjt&o>;P?#pg|CY9a`4x1-^`P%*7-*L9Az-3wynP_u#KqQ%DY!TFYqe=%D0G=e=+|#(r3opr<`Xi0|DxAp>{# zONP)2&QqjmLQhrX^4pJbwqCLW;ZqmyVmqlKO^p?O4=nUUX<{qQ1bx7zXL9*)%T*$5 z-sFGrul(144w1(stOz=XL|B#L4J3qMEZvT36skskL*Nu$uN3zb5rZMYwQr-2QoDd| z{jqz&&5guuQx0}FgFhhfh+EVGJU6i}pWZz30i(TVG($kpG--FUQo_RiFfd23%t7sj zR(#m4XJnyIaJe-N-b4i&b+u07`}?+^{}qFZ@@`Zq)XRJ@)~`%U!#>>WZ?zqJVi70v zMVkt_Vn->LI1v^sX?sxD+e~W$XSpNdYFz))5zk1VDsVQV+-D^ud2s@DURs~B@u0aK z(haV-a{5Z^$r-PN7~>q1iX0JocUWKzf3mUURDxv|2TbD+w0S%$xTC@)p)TL(t#_c( zNWZl!k7xkM$~B_s`u71Z!#bz=t|o~0g(%>-2SJ0!SK^CXa4BV`DE-dH%V!7PrQ^bZ zBxIOtX5mu)D7cgZqj{gxF9@0IMC(--K!poDxjXo_#jJ)o%nLCT8J)+wix4UtJvGm) zE9P;Pn^pT!q5K~d6Qe@Y@B>{v#rWyd)G;4V-KS|Gy(=DG>%)FfjxDcU7NhGhCK`Uw3-Q-% 
zI8WN&9`ysd2kS(ah!j)c!9gOMLG-C*+mFca`pvEX5=o+0`SS=7B*rS0DZ$cfYn-ZKqDKH(Wqi88Rjwh_5SlH=!ckgT+KOpSK5F>tH z0<~NY72ax^n?u9Heyo=V^~I1QHb!uZCO1s5&*m3l(AD<&lO8wldE0-qs_kISg^r9+ zl!|Evm2x&(-+EY;@fk3)2;M?*54S*s)V(NMVVX2i{ppX}=d*8AfLJjz$5g&KCJeQo zmvUxmd#R#k$#sxnbdSCrohszEh{6-kSoX|i*^j0$hAjVgeC%RqC>?VTTz)uE-p^3F zH#;k@-dX7=!0j9$0*8$i^js-8y$oJgolsyewDQk(A=|y3(F6CIUI+sOUg0c+I2#-C zi0rq$qZ{%jc1}veo}+<#Z}-$Od~_ZT{bSu<#CQZ%Qnw05Edq`kho)GXq?WP!p!n;a z>j%4YG1N)#bU(~b=TMGwc+D1VTYA70V1R>jA76i*D~KC_iaZwVvwJD<@V)@yA!W&k zp5_UMMwe}Sx6n%w!8(t&P*HHGR34#ltYigc;-eiOc1rzLuLhOjY_reT6FBK-{bB&= zX#1s%`o>z<$lS%oU(%(CNi*Aro`D9^5pY)6!+(0)3%sI(l7EdA3DopG4-aIeV`sT{ z4c|mU4;L>!JL1lS_pdFrhSm$-NzreaKz~+dTR-`M=$&Y7Ora*?VfTN=3!3pF8X~ zr+f$DdGNIk?89Av@c1R)TU{e+J&8Bg^t=|GeTU%eL~5#kSGsnb#&{S>*LRM50VB}*L0N`xo%i3oN1Hy+Q7pcoo~1efd@tI zN(pe~d2VFz8+^RHzE|SMk7ktJqu&A?-|76aR4UAihkJ#{5;Q=(+;uI8Jp}wyvn)Vk ztxzp(?WVVjr>?1$x>jH4w6_h?PZrumr*#74S-HNWxq^mhAXUyPtA$qfzv&Vdnw`GD zs)zH#17XN?mFKrk9}4_fF<7&umh!;)S1&yR?qzrUX0@q{c1}wa#F{jx-vXE=Vr`2R zp`!3J>1>|fSEz_16~R+c0K{(}lZ=evRaZdmzfPROy+K*Ff-ylz4q;P)y@}2^ds&oy zN@_sX!J~nn15BsR>fvArIf8b8Llxa_X7Z6%v%zdGj!7Fc9Jww}V81;f-| zmOLlofx(}`3eOE|ca#1+BU_Yi;S!J+SoV+lL0l1H zvzpfq)~z4V5aOg}pxYX=y+Z2hX+cxoE~Aoz+2thxt-6!+nR4zCQG4P(M0|L8k%Fc; z20Uh68d$^6W9aGJ6zUeV7Op{64=QH)tY^!t-3({Pb%)U_o)r)Ktg8Fxhq4hhIYzs2 z-O&-MQ5Y-S2E~;u6qe|_0Sl`WV*~T9L0aM(*aqBlqcIQ+Oei$bcGB_Y!b+{(H-f7D z;9UgR$qi&;eYLI`Q*1OEGmxke8M1|FjhRTFi{{+C32gwu;hkaWRBazxQ3`K}u+R^? 
z(lt!d1SC-KV6wLupwKR@4q+qbaHH7`YfWn<8^(^tn!z(6SE%9B#%&Bo&Kx2*2G!Ed zY*uzsmr&v>9Y|sM{ z9@qbP>HYzl2$>7tl7GJ!<=lMm9wF|H+0781>-Vr~rU)r;faBlUsDUm)z<8O<+k!3t zf$c+S1o}h=yc8o}p7>b-CVV?g53brBd=gt>FLk>s7>&B-h3^8ra{9&TIgg(^oG-2g zljF;vx7L#tBH=RlV8vNdhL9v$51Zu*7@jaG`Pc^tW$@wd3pA_eT7dGrw(pcd_LGjn@&Tpq zmw+_wx&5EVj>jTbp&VD?htB;tS{z!i!T&~6XJgnkL}kf$ohpv_xB`Tmq2m7v&j2w0 zlS43rxy9Q}?f?J)u1Q2eRF1b1rTRaAs+D$K0@z1P7;Qs^Sc>Nm=)0v3?0000 export const BUILTIN_OCR_PROVIDERS: BuiltinOcrProvider[] = Object.values(BUILTIN_OCR_PROVIDERS_MAP) diff --git a/src/renderer/src/hooks/useOcrProvider.tsx b/src/renderer/src/hooks/useOcrProvider.tsx index b0e23c20e3..38afaf81b0 100644 --- a/src/renderer/src/hooks/useOcrProvider.tsx +++ b/src/renderer/src/hooks/useOcrProvider.tsx @@ -1,4 +1,5 @@ import { loggerService } from '@logger' +import PaddleocrLogo from '@renderer/assets/images/providers/paddleocr.png' import TesseractLogo from '@renderer/assets/images/providers/Tesseract.js.png' import { BUILTIN_OCR_PROVIDERS_MAP, DEFAULT_OCR_PROVIDER } from '@renderer/config/ocr' import { getBuiltinOcrProviderLabel } from '@renderer/i18n/label' @@ -80,6 +81,8 @@ export const useOcrProviders = () => { return case 'system': return + case 'paddleocr': + return } } return diff --git a/src/renderer/src/i18n/label.ts b/src/renderer/src/i18n/label.ts index bc389dc82a..bdac2f7230 100644 --- a/src/renderer/src/i18n/label.ts +++ b/src/renderer/src/i18n/label.ts @@ -327,10 +327,12 @@ export const getBuiltInMcpServerDescriptionLabel = (key: string): string => { const builtinOcrProviderKeyMap = { system: 'ocr.builtin.system', - tesseract: '' + tesseract: '', + paddleocr: '' } as const satisfies Record export const getBuiltinOcrProviderLabel = (key: BuiltinOcrProviderId) => { if (key === 'tesseract') return 'Tesseract' + else if (key == 'paddleocr') return 'PaddleOCR' else return getLabel(builtinOcrProviderKeyMap, key) } diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index 
41882c82ea..fc9245cf8e 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -3884,6 +3884,13 @@ "title": "Image" }, "image_provider": "OCR service provider", + "paddleocr": { + "aistudio_access_token": "Access token of AI Studio Community", + "aistudio_url_label": "AI Studio Community", + "api_url": "API URL", + "serving_doc_url_label": "PaddleOCR Serving Documentation", + "tip": "You can refer to the official PaddleOCR documentation to deploy a local service, or deploy a cloud service on the PaddlePaddle AI Studio Community. For the latter case, please provide the access token of the AI Studio Community." + }, "system": { "win": { "langs_tooltip": "Dependent on Windows to provide services, you need to download language packs in the system to support the relevant languages." diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index 90b2eb32b9..4ee4518fbd 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -3884,6 +3884,13 @@ "title": "画像" }, "image_provider": "OCRサービスプロバイダー", + "paddleocr": { + "aistudio_access_token": "AI Studio Community のアクセス・トークン", + "aistudio_url_label": "AI Studio Community", + "api_url": "API URL", + "serving_doc_url_label": "PaddleOCR サービング ドキュメント", + "tip": "ローカルサービスをデプロイするには、公式の PaddleOCR ドキュメントを参照するか、PaddlePaddle AI Studio コミュニティ上でクラウドサービスをデプロイすることができます。後者の場合は、AI Studio コミュニティのアクセストークンを提供してください。" + }, "system": { "win": { "langs_tooltip": "Windows が提供するサービスに依存しており、関連する言語をサポートするには、システムで言語パックをダウンロードする必要があります。" diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index 6371fd9efe..428b3c4028 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -3884,6 +3884,13 @@ "title": "Изображение" }, "image_provider": "Поставщик услуг OCR", + "paddleocr": { + "aistudio_access_token": "Токен доступа сообщества AI 
Studio", + "aistudio_url_label": "Сообщество AI Studio", + "api_url": "URL API", + "serving_doc_url_label": "Документация по PaddleOCR Serving", + "tip": "Вы можете обратиться к официальной документации PaddleOCR, чтобы развернуть локальный сервис, либо развернуть облачный сервис в сообществе PaddlePaddle AI Studio. В последнем случае, пожалуйста, предоставьте токен доступа сообщества AI Studio." + }, "system": { "win": { "langs_tooltip": "Для предоставления служб Windows необходимо загрузить языковой пакет в системе для поддержки соответствующего языка." diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index 71c734ce75..539972ddfc 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -3884,6 +3884,13 @@ "title": "图片" }, "image_provider": "OCR 服务提供商", + "paddleocr": { + "aistudio_access_token": "星河社区访问令牌", + "aistudio_url_label": "星河社区", + "api_url": "API URL", + "serving_doc_url_label": "PaddleOCR 服务化部署文档", + "tip": "您可以参考 PaddleOCR 官方文档部署本地服务,或者在飞桨星河社区部署云服务。对于后一种情况,请填写星河社区访问令牌。" + }, "system": { "win": { "langs_tooltip": "依赖 Windows 提供服务,您需要在系统中下载语言包来支持相关语言。" diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index 4ab0fa3dae..2a60499310 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -3884,6 +3884,13 @@ "title": "圖片" }, "image_provider": "OCR 服務提供商", + "paddleocr": { + "aistudio_access_token": "星河社群存取權杖", + "aistudio_url_label": "星河社群", + "api_url": "API 網址", + "serving_doc_url_label": "PaddleOCR 服務化部署文件", + "tip": "您可以參考 PaddleOCR 官方文件來部署本機服務,或是在飛槳星河社群部署雲端服務。對於後者,請提供星河社群的存取權杖。" + }, "system": { "win": { "langs_tooltip": "依賴 Windows 提供服務,您需要在系統中下載語言包來支援相關語言。" diff --git a/src/renderer/src/i18n/translate/el-gr.json b/src/renderer/src/i18n/translate/el-gr.json index 187110e47d..52f911f148 100644 --- a/src/renderer/src/i18n/translate/el-gr.json +++ 
b/src/renderer/src/i18n/translate/el-gr.json @@ -3884,6 +3884,13 @@ "title": "Εικόνα" }, "image_provider": "Πάροχοι υπηρεσιών OCR", + "paddleocr": { + "aistudio_access_token": "Διακριτικό πρόσβασης της κοινότητας AI Studio", + "aistudio_url_label": "Κοινότητα AI Studio", + "api_url": "Διεύθυνση URL API", + "serving_doc_url_label": "Τεκμηρίωση PaddleOCR Serving", + "tip": "Μπορείτε να ανατρέξετε στην επίσημη τεκμηρίωση του PaddleOCR για να αναπτύξετε μια τοπική υπηρεσία, ή να αναπτύξετε μια υπηρεσία στο cloud στην Κοινότητα PaddlePaddle AI Studio. Στη δεύτερη περίπτωση, παρακαλώ παρέχετε το διακριτικό πρόσβασης (access token) της Κοινότητας AI Studio." + }, "system": { "win": { "langs_tooltip": "Εξαρτάται από τα Windows για την παροχή υπηρεσιών, πρέπει να κατεβάσετε το πακέτο γλώσσας στο σύστημα για να υποστηρίξετε τις σχετικές γλώσσες." diff --git a/src/renderer/src/i18n/translate/es-es.json b/src/renderer/src/i18n/translate/es-es.json index 029b5c5813..d9eb5fb5c4 100644 --- a/src/renderer/src/i18n/translate/es-es.json +++ b/src/renderer/src/i18n/translate/es-es.json @@ -3884,6 +3884,13 @@ "title": "Imagen" }, "image_provider": "Proveedor de servicios OCR", + "paddleocr": { + "aistudio_access_token": "Token de acceso de la comunidad de AI Studio", + "aistudio_url_label": "Comunidad de AI Studio", + "api_url": "URL de la API", + "serving_doc_url_label": "Documentación de PaddleOCR Serving", + "tip": "Puede consultar la documentación oficial de PaddleOCR para implementar un servicio local, o implementar un servicio en la nube en la Comunidad de PaddlePaddle AI Studio. En este último caso, proporcione el token de acceso de la Comunidad de AI Studio." + }, "system": { "win": { "langs_tooltip": "Dependiendo de Windows para proporcionar servicios, necesita descargar el paquete de idioma en el sistema para admitir los idiomas correspondientes." 
diff --git a/src/renderer/src/i18n/translate/fr-fr.json b/src/renderer/src/i18n/translate/fr-fr.json index 08af4d6e30..383f53bec3 100644 --- a/src/renderer/src/i18n/translate/fr-fr.json +++ b/src/renderer/src/i18n/translate/fr-fr.json @@ -3884,6 +3884,13 @@ "title": "Image" }, "image_provider": "Fournisseur de service OCR", + "paddleocr": { + "aistudio_access_token": "Jeton d’accès de la communauté AI Studio", + "aistudio_url_label": "Communauté AI Studio", + "api_url": "URL de l’API", + "serving_doc_url_label": "Documentation de PaddleOCR Serving", + "tip": "Vous pouvez consulter la documentation officielle de PaddleOCR pour déployer un service local, ou déployer un service cloud sur la Communauté PaddlePaddle AI Studio. Dans ce dernier cas, veuillez fournir le jeton d’accès de la Communauté AI Studio." + }, "system": { "win": { "langs_tooltip": "Dépendre de Windows pour fournir des services, vous devez télécharger des packs linguistiques dans le système afin de prendre en charge les langues concernées." diff --git a/src/renderer/src/i18n/translate/pt-pt.json b/src/renderer/src/i18n/translate/pt-pt.json index eff87d6902..3da7c91e18 100644 --- a/src/renderer/src/i18n/translate/pt-pt.json +++ b/src/renderer/src/i18n/translate/pt-pt.json @@ -3884,6 +3884,13 @@ "title": "Imagem" }, "image_provider": "Provedor de serviços OCR", + "paddleocr": { + "aistudio_access_token": "Token de acesso da comunidade AI Studio", + "aistudio_url_label": "Comunidade AI Studio", + "api_url": "URL da API", + "serving_doc_url_label": "Documentação do PaddleOCR Serving", + "tip": "Você pode consultar a documentação oficial do PaddleOCR para implantar um serviço local ou implantar um serviço na nuvem na Comunidade PaddlePaddle AI Studio. No último caso, forneça o token de acesso da Comunidade AI Studio." + }, "system": { "win": { "langs_tooltip": "Dependendo do Windows para fornecer serviços, você precisa baixar pacotes de idiomas no sistema para dar suporte aos idiomas relevantes." 
diff --git a/src/renderer/src/pages/settings/DocProcessSettings/OcrPpocrSettings.tsx b/src/renderer/src/pages/settings/DocProcessSettings/OcrPpocrSettings.tsx new file mode 100644 index 0000000000..634e63b2d3 --- /dev/null +++ b/src/renderer/src/pages/settings/DocProcessSettings/OcrPpocrSettings.tsx @@ -0,0 +1,83 @@ +import { ErrorBoundary } from '@renderer/components/ErrorBoundary' +import { useOcrProvider } from '@renderer/hooks/useOcrProvider' +import { BuiltinOcrProviderIds, isOcrPpocrProvider } from '@renderer/types' +import { Input } from 'antd' +import { startTransition, useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' + +import { SettingHelpLink, SettingHelpText, SettingHelpTextRow, SettingRow, SettingRowTitle } from '..' + +export const OcrPpocrSettings = () => { + // Hack: Hard-coded for now + const SERVING_DOC_URL = 'https://www.paddleocr.ai/latest/version3.x/deployment/serving.html' + const AISTUDIO_URL = 'https://aistudio.baidu.com/pipeline/mine' + + const { t } = useTranslation() + const { provider, updateConfig } = useOcrProvider(BuiltinOcrProviderIds.paddleocr) + + if (!isOcrPpocrProvider(provider)) { + throw new Error('Not PaddleOCR provider.') + } + + const [apiUrl, setApiUrl] = useState(provider.config.apiUrl || '') + const [accessToken, setAccessToken] = useState(provider.config.accessToken || '') + + const onApiUrlChange = useCallback((e: React.ChangeEvent) => { + const value = e.target.value + startTransition(() => { + setApiUrl(value) + }) + }, []) + const onAccessTokenChange = useCallback((e: React.ChangeEvent) => { + const value = e.target.value + startTransition(() => { + setAccessToken(value) + }) + }, []) + + const onBlur = useCallback(() => { + updateConfig({ + apiUrl, + accessToken + }) + }, [apiUrl, accessToken, updateConfig]) + + return ( + + + {t('settings.tool.ocr.paddleocr.api_url')} + + + + + + {t('settings.tool.ocr.paddleocr.aistudio_access_token')} + + + + + + 
{t('settings.tool.ocr.paddleocr.tip')} +
+ + {t('settings.tool.ocr.paddleocr.serving_doc_url_label')} + + + {t('settings.tool.ocr.paddleocr.aistudio_url_label')} + +
+
+
+ ) +} diff --git a/src/renderer/src/pages/settings/DocProcessSettings/OcrProviderSettings.tsx b/src/renderer/src/pages/settings/DocProcessSettings/OcrProviderSettings.tsx index ac069a3b3b..120e5a9e48 100644 --- a/src/renderer/src/pages/settings/DocProcessSettings/OcrProviderSettings.tsx +++ b/src/renderer/src/pages/settings/DocProcessSettings/OcrProviderSettings.tsx @@ -8,6 +8,7 @@ import { Divider, Flex } from 'antd' import styled from 'styled-components' import { SettingGroup, SettingTitle } from '..' +import { OcrPpocrSettings } from './OcrPpocrSettings' import { OcrSystemSettings } from './OcrSystemSettings' import { OcrTesseractSettings } from './OcrTesseractSettings' @@ -32,6 +33,8 @@ const OcrProviderSettings = ({ provider }: Props) => { return case 'system': return + case 'paddleocr': + return default: return null } diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index 5a70c202f8..6f432babf2 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -67,7 +67,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 147, + version: 148, blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs'], migrate }, diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 0ba27d931e..5107b584e1 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -2383,6 +2383,15 @@ const migrateConfig = { logger.error('migrate 147 error', error as Error) return state } + }, + '148': (state: RootState) => { + try { + addOcrProvider(state, BUILTIN_OCR_PROVIDERS_MAP.paddleocr) + return state + } catch (error) { + logger.error('migrate 148 error', error as Error) + return state + } } } diff --git a/src/renderer/src/types/ocr.ts b/src/renderer/src/types/ocr.ts index 692ae7283d..d67cba958d 100644 --- a/src/renderer/src/types/ocr.ts +++ b/src/renderer/src/types/ocr.ts @@ -4,7 +4,8 @@ import { FileMetadata, ImageFileMetadata, 
isImageFileMetadata, TranslateLanguage export const BuiltinOcrProviderIds = { tesseract: 'tesseract', - system: 'system' + system: 'system', + paddleocr: 'paddleocr' } as const export type BuiltinOcrProviderId = keyof typeof BuiltinOcrProviderIds @@ -74,7 +75,7 @@ export type OcrProviderBaseConfig = { enabled?: boolean } -export type OcrProviderConfig = OcrApiProviderConfig | OcrTesseractConfig | OcrSystemConfig +export type OcrProviderConfig = OcrApiProviderConfig | OcrTesseractConfig | OcrSystemConfig | OcrPpocrConfig export type OcrProvider = { id: string @@ -170,3 +171,20 @@ export type OcrSystemProvider = { export const isOcrSystemProvider = (p: OcrProvider): p is OcrSystemProvider => { return p.id === BuiltinOcrProviderIds.system } + +// PaddleOCR Types +export type OcrPpocrConfig = OcrProviderBaseConfig & { + apiUrl?: string + accessToken?: string +} + +export type OcrPpocrProvider = { + id: 'paddleocr' + config: OcrPpocrConfig +} & ImageOcrProvider & + // PdfOcrProvider & + BuiltinOcrProvider + +export const isOcrPpocrProvider = (p: OcrProvider): p is OcrPpocrProvider => { + return p.id === BuiltinOcrProviderIds.paddleocr +} From b6d10656f9d8e480c026bf490eaaf07fa2524b16 Mon Sep 17 00:00:00 2001 From: Chen Tao <70054568+eeee0717@users.noreply.github.com> Date: Thu, 4 Sep 2025 17:23:31 +0800 Subject: [PATCH 3/3] feat: refactor Knowledge Base (#8384) Co-authored-by: icarus Co-authored-by: eeee0717 --- package.json | 17 +- packages/shared/config/types.ts | 2 +- src/main/ipc.ts | 3 +- .../{ => embedjs}/embeddings/Embeddings.ts | 0 .../embeddings/EmbeddingsFactory.ts | 0 .../embeddings/VoyageEmbeddings.ts | 0 .../loader/draftsExportLoader.ts | 0 .../{ => embedjs}/loader/epubLoader.ts | 0 .../knowledge/{ => embedjs}/loader/index.ts | 0 .../{ => embedjs}/loader/noteLoader.ts | 0 .../{ => embedjs}/loader/odLoader.ts | 0 .../langchain/embeddings/EmbeddingsFactory.ts | 63 ++ .../langchain/embeddings/JinaEmbeddings.ts | 199 ++++ 
.../langchain/embeddings/TextEmbeddings.ts | 25 + .../langchain/loader/MarkdownLoader.ts | 97 ++ .../knowledge/langchain/loader/NoteLoader.ts | 50 + .../langchain/loader/YoutubeLoader.ts | 170 +++ src/main/knowledge/langchain/loader/index.ts | 236 +++++ .../knowledge/langchain/retriever/index.ts | 55 + .../langchain/splitter/SrtSplitter.ts | 133 +++ .../knowledge/langchain/splitter/index.ts | 31 + .../preprocess/PreprocessingService.ts | 63 ++ src/main/knowledge/reranker/BaseReranker.ts | 140 +-- .../knowledge/reranker/GeneralReranker.ts | 9 +- src/main/knowledge/reranker/Reranker.ts | 5 +- .../reranker/strategies/BailianStrategy.ts | 18 + .../reranker/strategies/DefaultStrategy.ts | 25 + .../reranker/strategies/JinaStrategy.ts | 33 + .../reranker/strategies/RerankStrategy.ts | 9 + .../reranker/strategies/StrategyFactory.ts | 25 + .../reranker/strategies/TeiStrategy.ts | 26 + .../reranker/strategies/VoyageStrategy.ts | 24 + .../knowledge/reranker/strategies/types.ts | 19 + .../EmbedJsFramework.ts} | 431 ++------ .../services/knowledge/IKnowledgeFramework.ts | 72 ++ .../knowledge/KnowledgeFrameworkFactory.ts | 48 + .../services/knowledge/KnowledgeService.ts | 190 ++++ .../services/knowledge/LangChainFramework.ts | 555 ++++++++++ src/main/services/memory/MemoryService.ts | 2 +- src/main/utils/file.ts | 13 + src/main/utils/knowledge.ts | 13 + src/preload/index.ts | 9 +- src/renderer/index.html | 2 +- .../src/aiCore/chunk/handleToolCallChunk.ts | 13 + .../aiCore/legacy/clients/BaseApiClient.ts | 29 +- .../clients/anthropic/AnthropicAPIClient.ts | 27 +- .../legacy/clients/aws/AwsBedrockAPIClient.ts | 27 +- .../legacy/clients/gemini/GeminiAPIClient.ts | 63 +- .../legacy/clients/openai/OpenAIApiClient.ts | 30 +- .../clients/openai/OpenAIResponseAPIClient.ts | 23 +- .../src/aiCore/tools/KnowledgeSearchTool.ts | 3 +- .../src/components/Popups/VideoPopup.tsx | 206 ++++ src/renderer/src/databases/index.ts | 10 +- src/renderer/src/hooks/useKnowledge.ts | 130 ++- 
.../src/hooks/useKnowledgeBaseForm.ts | 14 +- src/renderer/src/hooks/useSettings.ts | 2 +- src/renderer/src/hooks/useTimer.ts | 2 +- src/renderer/src/i18n/locales/en-us.json | 43 +- src/renderer/src/i18n/locales/ja-jp.json | 43 +- src/renderer/src/i18n/locales/ru-ru.json | 43 +- src/renderer/src/i18n/locales/zh-cn.json | 41 +- src/renderer/src/i18n/locales/zh-tw.json | 43 +- src/renderer/src/i18n/translate/el-gr.json | 41 +- src/renderer/src/i18n/translate/es-es.json | 41 +- src/renderer/src/i18n/translate/fr-fr.json | 41 +- src/renderer/src/i18n/translate/pt-pt.json | 41 +- src/renderer/src/pages/files/FileItem.tsx | 8 +- .../pages/home/Messages/Blocks/VideoBlock.tsx | 14 + .../src/pages/home/Messages/Blocks/index.tsx | 85 +- .../src/pages/home/Messages/MessageVideo.tsx | 112 ++ .../src/pages/knowledge/KnowledgeContent.tsx | 69 +- .../__tests__/AdvancedSettingsPanel.test.tsx | 1 + .../__tests__/GeneralSettingsPanel.test.tsx | 89 +- .../GeneralSettingsPanel.test.tsx.snap | 102 +- .../KnowledgeBaseFormModal.test.tsx.snap | 2 +- .../components/AddKnowledgeBasePopup.tsx | 6 +- .../components/EditKnowledgeBasePopup.tsx | 9 +- .../KnowledgeSearchItem/TextItem.tsx | 31 + .../KnowledgeSearchItem/VideoItem.tsx | 138 +++ .../KnowledgeSearchItem/components.tsx | 54 + .../components/KnowledgeSearchItem/hooks.ts | 72 ++ .../components/KnowledgeSearchItem/index.tsx | 93 ++ .../components/KnowledgeSearchPopup.tsx | 129 +-- .../GeneralSettingsPanel.tsx | 79 +- .../KnowledgeBaseFormModal.tsx | 4 +- .../knowledge/components/MigrationInfoTag.tsx | 57 + .../pages/knowledge/components/QuotaTag.tsx | 16 +- .../knowledge/items/KnowledgeDirectories.tsx | 5 +- .../pages/knowledge/items/KnowledgeFiles.tsx | 12 +- .../pages/knowledge/items/KnowledgeNotes.tsx | 13 +- .../knowledge/items/KnowledgeSitemaps.tsx | 5 +- .../pages/knowledge/items/KnowledgeUrls.tsx | 5 +- .../pages/knowledge/items/KnowledgeVideos.tsx | 174 ++++ src/renderer/src/queue/KnowledgeQueue.ts | 12 +- 
src/renderer/src/services/KnowledgeService.ts | 69 +- .../src/services/StreamProcessingService.ts | 5 + src/renderer/src/services/WebSearchService.ts | 7 +- .../callbacks/imageCallbacks.ts | 15 + .../messageStreaming/callbacks/index.ts | 4 + .../callbacks/videoCallbacks.ts | 39 + src/renderer/src/store/index.ts | 2 +- src/renderer/src/store/knowledge.ts | 16 +- src/renderer/src/store/migrate.ts | 22 +- .../src/store/thunk/knowledgeThunk.ts | 16 +- src/renderer/src/types/chunk.ts | 31 +- src/renderer/src/types/file.ts | 13 +- src/renderer/src/types/index.ts | 102 +- src/renderer/src/types/knowledge.ts | 173 ++++ src/renderer/src/types/newMessage.ts | 12 +- src/renderer/src/utils/fetch.ts | 2 +- src/renderer/src/utils/file.ts | 24 +- src/renderer/src/utils/messageUtils/create.ts | 16 +- src/renderer/src/utils/messageUtils/is.ts | 46 +- yarn.lock | 973 +++++++++++++++++- 114 files changed, 5791 insertions(+), 960 deletions(-) rename src/main/knowledge/{ => embedjs}/embeddings/Embeddings.ts (100%) rename src/main/knowledge/{ => embedjs}/embeddings/EmbeddingsFactory.ts (100%) rename src/main/knowledge/{ => embedjs}/embeddings/VoyageEmbeddings.ts (100%) rename src/main/knowledge/{ => embedjs}/loader/draftsExportLoader.ts (100%) rename src/main/knowledge/{ => embedjs}/loader/epubLoader.ts (100%) rename src/main/knowledge/{ => embedjs}/loader/index.ts (100%) rename src/main/knowledge/{ => embedjs}/loader/noteLoader.ts (100%) rename src/main/knowledge/{ => embedjs}/loader/odLoader.ts (100%) create mode 100644 src/main/knowledge/langchain/embeddings/EmbeddingsFactory.ts create mode 100644 src/main/knowledge/langchain/embeddings/JinaEmbeddings.ts create mode 100644 src/main/knowledge/langchain/embeddings/TextEmbeddings.ts create mode 100644 src/main/knowledge/langchain/loader/MarkdownLoader.ts create mode 100644 src/main/knowledge/langchain/loader/NoteLoader.ts create mode 100644 src/main/knowledge/langchain/loader/YoutubeLoader.ts create mode 100644 
src/main/knowledge/langchain/loader/index.ts create mode 100644 src/main/knowledge/langchain/retriever/index.ts create mode 100644 src/main/knowledge/langchain/splitter/SrtSplitter.ts create mode 100644 src/main/knowledge/langchain/splitter/index.ts create mode 100644 src/main/knowledge/preprocess/PreprocessingService.ts create mode 100644 src/main/knowledge/reranker/strategies/BailianStrategy.ts create mode 100644 src/main/knowledge/reranker/strategies/DefaultStrategy.ts create mode 100644 src/main/knowledge/reranker/strategies/JinaStrategy.ts create mode 100644 src/main/knowledge/reranker/strategies/RerankStrategy.ts create mode 100644 src/main/knowledge/reranker/strategies/StrategyFactory.ts create mode 100644 src/main/knowledge/reranker/strategies/TeiStrategy.ts create mode 100644 src/main/knowledge/reranker/strategies/VoyageStrategy.ts create mode 100644 src/main/knowledge/reranker/strategies/types.ts rename src/main/services/{KnowledgeService.ts => knowledge/EmbedJsFramework.ts} (52%) create mode 100644 src/main/services/knowledge/IKnowledgeFramework.ts create mode 100644 src/main/services/knowledge/KnowledgeFrameworkFactory.ts create mode 100644 src/main/services/knowledge/KnowledgeService.ts create mode 100644 src/main/services/knowledge/LangChainFramework.ts create mode 100644 src/main/utils/knowledge.ts create mode 100644 src/renderer/src/components/Popups/VideoPopup.tsx create mode 100644 src/renderer/src/pages/home/Messages/Blocks/VideoBlock.tsx create mode 100644 src/renderer/src/pages/home/Messages/MessageVideo.tsx create mode 100644 src/renderer/src/pages/knowledge/components/KnowledgeSearchItem/TextItem.tsx create mode 100644 src/renderer/src/pages/knowledge/components/KnowledgeSearchItem/VideoItem.tsx create mode 100644 src/renderer/src/pages/knowledge/components/KnowledgeSearchItem/components.tsx create mode 100644 src/renderer/src/pages/knowledge/components/KnowledgeSearchItem/hooks.ts create mode 100644 
src/renderer/src/pages/knowledge/components/KnowledgeSearchItem/index.tsx create mode 100644 src/renderer/src/pages/knowledge/components/MigrationInfoTag.tsx create mode 100644 src/renderer/src/pages/knowledge/items/KnowledgeVideos.tsx create mode 100644 src/renderer/src/services/messageStreaming/callbacks/videoCallbacks.ts create mode 100644 src/renderer/src/types/knowledge.ts diff --git a/package.json b/package.json index c84bc483d4..a651fe00de 100644 --- a/package.json +++ b/package.json @@ -74,15 +74,23 @@ "@libsql/win32-x64-msvc": "^0.4.7", "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch", "@strongtz/win32-arm64-msvc": "^0.4.7", + "cheerio": "^1.1.2", + "faiss-node": "^0.5.1", "graceful-fs": "^4.2.11", + "html-to-text": "^9.0.5", + "htmlparser2": "^10.0.0", "jsdom": "26.1.0", "node-stream-zip": "^1.15.0", "officeparser": "^4.2.0", "os-proxy-config": "^1.1.2", + "pdf-parse": "^1.1.1", + "react-player": "^3.3.1", + "react-youtube": "^10.1.0", "selection-hook": "^1.0.11", "sharp": "^0.34.3", "tesseract.js": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch", - "turndown": "7.2.0" + "turndown": "7.2.0", + "youtubei.js": "^15.0.1" }, "devDependencies": { "@agentic/exa": "^7.3.3", @@ -127,8 +135,10 @@ "@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch", "@hello-pangea/dnd": "^18.0.1", "@kangfenmao/keyv-storage": "^0.1.0", - "@langchain/community": "^0.3.36", + "@langchain/community": "^0.3.50", + "@langchain/core": "^0.3.68", "@langchain/ollama": "^0.2.1", + "@langchain/openai": "^0.6.7", "@mistralai/mistralai": "^1.7.5", "@modelcontextprotocol/sdk": "^1.17.0", "@mozilla/readability": "^0.6.0", @@ -171,9 +181,11 @@ "@types/cli-progress": "^3", "@types/fs-extra": "^11", "@types/he": "^1", + "@types/html-to-text": "^9", "@types/lodash": "^4.17.5", "@types/markdown-it": "^14", "@types/md5": "^2.3.5", 
+ "@types/mime-types": "^3", "@types/node": "^22.17.1", "@types/pako": "^1.0.2", "@types/react": "^19.0.12", @@ -253,6 +265,7 @@ "markdown-it": "^14.1.0", "mermaid": "^11.10.1", "mime": "^4.0.4", + "mime-types": "^3.0.1", "motion": "^12.10.5", "notion-helper": "^1.3.22", "npx-scope-finder": "^1.2.0", diff --git a/packages/shared/config/types.ts b/packages/shared/config/types.ts index d46717b47e..8fb11cd8b3 100644 --- a/packages/shared/config/types.ts +++ b/packages/shared/config/types.ts @@ -7,7 +7,7 @@ export type LoaderReturn = { loaderType: string status?: ProcessingStatus message?: string - messageSource?: 'preprocess' | 'embedding' + messageSource?: 'preprocess' | 'embedding' | 'validation' } export type FileChangeEventType = 'add' | 'change' | 'unlink' | 'addDir' | 'unlinkDir' diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 8b22fee49c..f989a3d808 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -24,7 +24,7 @@ import DxtService from './services/DxtService' import { ExportService } from './services/ExportService' import { fileStorage as fileManager } from './services/FileStorage' import FileService from './services/FileSystemService' -import KnowledgeService from './services/KnowledgeService' +import KnowledgeService from './services/knowledge/KnowledgeService' import mcpService from './services/MCPService' import MemoryService from './services/memory/MemoryService' import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceService' @@ -524,7 +524,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { } }) - // knowledge base ipcMain.handle(IpcChannel.KnowledgeBase_Create, KnowledgeService.create.bind(KnowledgeService)) ipcMain.handle(IpcChannel.KnowledgeBase_Reset, KnowledgeService.reset.bind(KnowledgeService)) ipcMain.handle(IpcChannel.KnowledgeBase_Delete, KnowledgeService.delete.bind(KnowledgeService)) diff --git a/src/main/knowledge/embeddings/Embeddings.ts 
b/src/main/knowledge/embedjs/embeddings/Embeddings.ts similarity index 100% rename from src/main/knowledge/embeddings/Embeddings.ts rename to src/main/knowledge/embedjs/embeddings/Embeddings.ts diff --git a/src/main/knowledge/embeddings/EmbeddingsFactory.ts b/src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts similarity index 100% rename from src/main/knowledge/embeddings/EmbeddingsFactory.ts rename to src/main/knowledge/embedjs/embeddings/EmbeddingsFactory.ts diff --git a/src/main/knowledge/embeddings/VoyageEmbeddings.ts b/src/main/knowledge/embedjs/embeddings/VoyageEmbeddings.ts similarity index 100% rename from src/main/knowledge/embeddings/VoyageEmbeddings.ts rename to src/main/knowledge/embedjs/embeddings/VoyageEmbeddings.ts diff --git a/src/main/knowledge/loader/draftsExportLoader.ts b/src/main/knowledge/embedjs/loader/draftsExportLoader.ts similarity index 100% rename from src/main/knowledge/loader/draftsExportLoader.ts rename to src/main/knowledge/embedjs/loader/draftsExportLoader.ts diff --git a/src/main/knowledge/loader/epubLoader.ts b/src/main/knowledge/embedjs/loader/epubLoader.ts similarity index 100% rename from src/main/knowledge/loader/epubLoader.ts rename to src/main/knowledge/embedjs/loader/epubLoader.ts diff --git a/src/main/knowledge/loader/index.ts b/src/main/knowledge/embedjs/loader/index.ts similarity index 100% rename from src/main/knowledge/loader/index.ts rename to src/main/knowledge/embedjs/loader/index.ts diff --git a/src/main/knowledge/loader/noteLoader.ts b/src/main/knowledge/embedjs/loader/noteLoader.ts similarity index 100% rename from src/main/knowledge/loader/noteLoader.ts rename to src/main/knowledge/embedjs/loader/noteLoader.ts diff --git a/src/main/knowledge/loader/odLoader.ts b/src/main/knowledge/embedjs/loader/odLoader.ts similarity index 100% rename from src/main/knowledge/loader/odLoader.ts rename to src/main/knowledge/embedjs/loader/odLoader.ts diff --git 
/**
 * Factory that maps an embedding API client configuration to the matching
 * LangChain embeddings implementation.
 */
export default class EmbeddingsFactory {
  /**
   * Creates the embeddings client for the given provider configuration.
   *
   * Selection order: the Ollama provider, the VoyageAI provider, any model
   * accepted by `isJinaEmbeddingsModel`, Azure OpenAI (when `apiVersion` is
   * defined) and finally the generic OpenAI client pointed at `baseURL`.
   *
   * @param embedApiClient Model, provider id, credentials and endpoint.
   * @param dimensions Optional embedding size; forwarded to every client except Ollama.
   * @returns The configured embeddings instance.
   */
  static create({ embedApiClient, dimensions }: { embedApiClient: ApiClient; dimensions?: number }): Embeddings {
    const batchSize = 10
    const { model, provider, apiKey, apiVersion, baseURL } = embedApiClient

    if (provider === SystemProviderIds.ollama) {
      // Drop the `v1` path segment from the configured URL. The original only
      // handled `v1/`; a URL ending in `/v1` (no trailing slash) is now covered too.
      // NOTE(review): presumably the Ollama client wants the server root — confirm.
      const baseUrl = baseURL.includes('v1/') ? baseURL.replace('v1/', '') : baseURL.replace(/\/v1$/, '')

      return new OllamaEmbeddings({
        model,
        baseUrl,
        // The bearer token has to be nested under `headers`. Spreading
        // `{ Authorization }` directly into the options (as before) only added
        // an unrecognised top-level key, so the header was never sent.
        ...(apiKey ? { headers: { Authorization: `Bearer ${apiKey}` } } : {})
      })
    }

    if (provider === SystemProviderIds.voyageai) {
      return new VoyageEmbeddings({
        modelName: model,
        apiKey,
        outputDimension: dimensions,
        batchSize
      })
    }

    if (isJinaEmbeddingsModel(model)) {
      return new JinaEmbeddings({
        model,
        apiKey,
        batchSize,
        dimensions,
        baseUrl: baseURL
      })
    }

    // An explicit API version is what marks the client as Azure OpenAI here.
    if (apiVersion !== undefined) {
      return new AzureOpenAIEmbeddings({
        azureOpenAIApiKey: apiKey,
        azureOpenAIApiVersion: apiVersion,
        azureOpenAIApiDeploymentName: model,
        azureOpenAIEndpoint: baseURL,
        dimensions,
        batchSize
      })
    }

    return new OpenAIEmbeddings({
      model,
      apiKey,
      dimensions,
      batchSize,
      configuration: { baseURL }
    })
  }
}
/**
 * Embeddings client that posts text (or image) inputs to a Jina-style
 * `/embeddings` HTTP endpoint and returns the resulting vectors.
 */
export class JinaEmbeddings extends Embeddings implements JinaEmbeddingsParams {
  model: JinaEmbeddingsParams['model'] = 'jina-clip-v2'

  batchSize = 24

  baseUrl = 'https://api.jina.ai/v1/embeddings'

  stripNewLines = true

  dimensions = 1024

  apiKey: string

  /**
   * @param fields Optional overrides. `apiKey` falls back to the
   *   `JINA_API_KEY` / `JINA_AUTH_TOKEN` environment variables; `baseUrl` is the
   *   API root (with or without a trailing slash) to which `/embeddings` is appended.
   * @throws Error when no API key can be resolved.
   */
  constructor(
    fields?: Partial<JinaEmbeddingsParams> & {
      apiKey?: string
    }
  ) {
    const fieldsWithDefaults = { maxConcurrency: 2, ...fields }
    super(fieldsWithDefaults)

    const apiKey =
      fieldsWithDefaults.apiKey || getEnvironmentVariable('JINA_API_KEY') || getEnvironmentVariable('JINA_AUTH_TOKEN')

    if (!apiKey) throw new Error('Jina API key not found')

    this.apiKey = apiKey
    if (fieldsWithDefaults.baseUrl) {
      this.baseUrl = JinaEmbeddings.toEndpoint(fieldsWithDefaults.baseUrl)
    }
    this.model = fieldsWithDefaults.model ?? this.model
    this.dimensions = fieldsWithDefaults.dimensions ?? this.dimensions
    this.batchSize = fieldsWithDefaults.batchSize ?? this.batchSize
    this.stripNewLines = fieldsWithDefaults.stripNewLines ?? this.stripNewLines
  }

  /**
   * Joins an API root and the `embeddings` path with exactly one slash.
   * The previous plain concatenation produced `…/v1embeddings` whenever the
   * configured root had no trailing slash.
   */
  private static toEndpoint(baseUrl: string): string {
    const trimmed = baseUrl.replace(/\/+$/, '')
    return trimmed.endsWith('/embeddings') ? trimmed : `${trimmed}/embeddings`
  }

  /** Replaces newlines with spaces in string / text inputs when `stripNewLines` is on. */
  private doStripNewLines(input: JinaEmbeddingsInput[]): JinaEmbeddingsInput[] {
    if (!this.stripNewLines) {
      return input
    }
    return input.map((item) => {
      if (typeof item === 'string') {
        return item.replace(/\n/g, ' ')
      }
      if (item.text) {
        return { text: item.text.replace(/\n/g, ' ') }
      }
      // Image inputs (and empty text) are passed through untouched.
      return item
    })
  }

  /**
   * Embeds every input, sending them in batches of `batchSize`.
   * @returns One vector per input, in input order.
   */
  async embedDocuments(input: JinaEmbeddingsInput[]): Promise<number[][]> {
    const batches = chunkArray(this.doStripNewLines(input), this.batchSize)
    const batchResponses = await Promise.all(batches.map((batch) => this.embeddingWithRetry(this.getParams(batch))))

    const embeddings: number[][] = []
    for (let i = 0; i < batchResponses.length; i += 1) {
      const batch = batches[i]
      const batchResponse = batchResponses[i] || []
      // Index by the request batch so the output stays aligned with the input.
      for (let j = 0; j < batch.length; j += 1) {
        embeddings.push(batchResponse[j])
      }
    }

    return embeddings
  }

  /** Embeds a single query input using the `retrieval.query` task. */
  async embedQuery(input: JinaEmbeddingsInput): Promise<number[]> {
    const params = this.getParams(this.doStripNewLines([input]), true)

    const embeddings = (await this.embeddingWithRetry(params)) || [[]]
    return embeddings[0]
  }

  /** Builds the request payload; `jina-clip-v2` document requests carry no task. */
  private getParams(input: JinaEmbeddingsInput[], query?: boolean): EmbeddingCreateParams {
    return {
      model: this.model,
      input,
      dimensions: this.dimensions,
      task: query ? 'retrieval.query' : this.model === 'jina-clip-v2' ? undefined : 'retrieval.passage'
    }
  }

  /**
   * Posts one batch to the endpoint and returns its vectors.
   *
   * The request now goes through the base class `caller`, so the
   * `maxConcurrency` passed to `super()` is actually applied; previously the
   * method issued a bare `fetch` and the setting had no effect.
   *
   * @throws Error with the service's `detail` text when present, or a generic
   *   message for non-OK / malformed responses.
   */
  private async embeddingWithRetry(body: EmbeddingCreateParams): Promise<number[][]> {
    return this.caller.call(async () => {
      const response = await fetch(this.baseUrl, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
          Authorization: `Bearer ${this.apiKey}`
        },
        body: JSON.stringify(body)
      })

      let payload: EmbeddingResponse | EmbeddingErrorResponse | undefined
      try {
        payload = await response.json()
      } catch {
        // Non-JSON body (e.g. a proxy error page); handled by the status check below.
        payload = undefined
      }

      const detail = payload && 'detail' in payload ? payload.detail : undefined
      if (!response.ok || detail) {
        const message = detail ? `${detail}` : `Jina embeddings request failed with status ${response.status}`
        // NOTE(review): `status` is attached so the caller's retry policy can,
        // presumably, tell client errors from transient ones — confirm against AsyncCaller.
        throw Object.assign(new Error(message), { status: response.status })
      }

      const data = (payload as EmbeddingResponse | undefined)?.data
      if (!Array.isArray(data)) {
        throw new Error('Jina embeddings response did not contain embedding data')
      }
      return data.map(({ embedding }) => embedding)
    })
  }
}
/**
 * Loads a markdown file and emits one Document per heading-delimited section.
 * Text that precedes the first heading is reported under the heading
 * 'Introduction' with level 0.
 */
export class MarkdownLoader extends BaseDocumentLoader {
  private path: string
  private md: MarkdownIt

  constructor(path: string) {
    super()
    this.path = path
    this.md = new MarkdownIt()
  }

  /** Reads the file with `readTextFileWithAutoEncoding` and splits it into sections. */
  public async load(): Promise<Document[]> {
    const raw = await readTextFileWithAutoEncoding(this.path)
    return this.parseMarkdown(raw)
  }

  /**
   * Walks the markdown-it token stream, accumulating token content until the
   * next heading starts, and turns every non-empty accumulation into a Document.
   */
  private parseMarkdown(content: string): Document[] {
    type Section = { heading?: string; level?: number; content: string; startLine?: number }

    const tokens = this.md.parse(content, {})
    const sections: Document[] = []
    let open: Section = { content: '' }

    // Emits the section collected so far; whitespace-only sections are dropped.
    const flush = (): void => {
      const text = open.content.trim()
      if (!text) {
        return
      }
      sections.push(
        new Document({
          pageContent: text,
          metadata: {
            source: this.path,
            heading: open.heading || 'Introduction',
            level: open.level || 0,
            startLine: open.startLine || 0
          }
        })
      )
    }

    for (let idx = 0; idx < tokens.length; ) {
      const tok = tokens[idx]

      if (tok.type === 'heading_open') {
        flush()
        open = {
          // The token right after heading_open carries the heading text.
          heading: tokens[idx + 1]?.content || '',
          // Tag is h1..h6; the digit is the level.
          level: parseInt(tok.tag.slice(1), 10),
          content: '',
          startLine: tok.map?.[0] || 0
        }
        // Jump over heading_open, inline and heading_close.
        idx += 3
        continue
      }

      if (tok.content) {
        open.content += tok.content
      }
      if (tok.block && tok.type !== 'heading_close') {
        open.content += '\n'
      }
      idx += 1
    }

    flush()
    return sections
  }
}
/**
 * Wraps an in-memory note as LangChain Documents. The `source` metadata is the
 * given source URL, or 'note' when none is supplied.
 */
export class NoteLoader extends BaseDocumentLoader {
  private text: string
  private sourceUrl?: string

  constructor(
    public _text: string,
    public _sourceUrl?: string
  ) {
    super()
    this.text = _text
    this.sourceUrl = _sourceUrl
  }

  /**
   * Turns the raw note into page contents. This implementation returns the
   * raw text as the only element.
   * @param raw The raw text to be parsed.
   * @returns An array containing the raw text as a single element.
   */
  protected async parse(raw: string): Promise<string[]> {
    return [raw]
  }

  /**
   * Builds one Document per parsed piece. When there is more than one piece,
   * each Document additionally records its 1-based `line` position.
   * @throws Error when `parse` yields a non-string element.
   */
  public async load(): Promise<Document[]> {
    const shared = { source: this.sourceUrl || 'note' }
    const pieces = await this.parse(this.text)

    for (let i = 0; i < pieces.length; i += 1) {
      const piece = pieces[i]
      if (typeof piece !== 'string') {
        throw new Error(`Expected string, at position ${i} got ${typeof piece}`)
      }
    }

    const single = pieces.length === 1
    return pieces.map((pageContent, i) => {
      const metadata = single ? shared : { ...shared, line: i + 1 }
      return new Document({ pageContent, metadata })
    })
  }
}
/**
 * A document loader for YouTube videos. It uses the youtubei.js library to
 * fetch the transcript and, optionally, basic video metadata.
 * @example
 * ```typescript
 * const loader = new YoutubeLoader({
 *   videoId: "VIDEO_ID",
 *   language: "en",
 *   addVideoInfo: true,
 *   transcriptFormat: "srt" // produce SRT-formatted output
 * });
 * const docs = await loader.load();
 * console.log(docs[0].pageContent);
 * ```
 */
export class YoutubeLoader extends BaseDocumentLoader {
  private videoId: string
  private language?: string
  private addVideoInfo: boolean
  // Output format of the transcript text.
  private transcriptFormat: 'text' | 'srt'

  constructor(config: YoutubeConfig) {
    super()
    this.videoId = config.videoId
    this.language = config?.language
    this.addVideoInfo = config?.addVideoInfo ?? false
    // Defaults to plain text so existing callers keep their behaviour.
    this.transcriptFormat = config?.transcriptFormat ?? 'text'
  }

  /**
   * Extracts the videoId from a YouTube video URL.
   * @param url The URL of the YouTube video.
   * @returns The 11-character videoId.
   * @throws Error when no 11-character id can be found.
   */
  private static getVideoID(url: string): string {
    // The dot in `youtu.be` is escaped; unescaped it matched any character.
    const match = url.match(/.*(?:youtu\.be\/|v\/|u\/\w\/|embed\/|watch\?v=)([^#&?]*).*/)
    if (match !== null && match[1].length === 11) {
      return match[1]
    }
    throw new Error('Failed to get youtube video id from the url')
  }

  /**
   * Creates a new YoutubeLoader from a YouTube video URL.
   * @param url The URL of the YouTube video.
   * @param config Optional configuration, excluding the videoId.
   * @returns A new YoutubeLoader instance.
   */
  static createFromUrl(url: string, config?: Omit<YoutubeConfig, 'videoId'>): YoutubeLoader {
    const videoId = YoutubeLoader.getVideoID(url)
    return new YoutubeLoader({ ...config, videoId })
  }

  /**
   * Converts milliseconds to an SRT timestamp (HH:MM:SS,mmm).
   * Non-finite or negative input is treated as 0 and fractional input is
   * rounded, so the result always has the shape an SRT parser expects.
   * @param ms Offset in milliseconds.
   * @returns The formatted timestamp.
   */
  private static formatTimestamp(ms: number): string {
    const wholeMs = Number.isFinite(ms) ? Math.max(0, Math.round(ms)) : 0
    const totalSeconds = Math.floor(wholeMs / 1000)
    const pad = (value: number, width: number): string => value.toString().padStart(width, '0')

    const hours = pad(Math.floor(totalSeconds / 3600), 2)
    const minutes = pad(Math.floor((totalSeconds % 3600) / 60), 2)
    const seconds = pad(totalSeconds % 60, 2)
    const milliseconds = pad(wholeMs % 1000, 3)
    return `${hours}:${minutes}:${seconds},${milliseconds}`
  }

  /**
   * Loads the transcript (and, when `addVideoInfo` is set, basic metadata) of
   * the configured video.
   * @returns A single-element array holding the transcript Document.
   * @throws Error prefixed with 'Failed to get YouTube video transcription'
   *   when the transcript cannot be fetched or contains no segments.
   */
  async load(): Promise<Document[]> {
    const metadata: VideoMetadata = {
      source: this.videoId
    }

    try {
      const youtube = await Innertube.create({
        lang: this.language,
        retrieve_player: false
      })

      const info = await youtube.getInfo(this.videoId)
      const transcriptData = await info.getTranscript()

      const segments = transcriptData.transcript.content?.body?.initial_segments
      if (!segments) {
        throw new Error('Transcript segments not found in the response.')
      }

      let pageContent: string
      if (this.transcriptFormat === 'srt') {
        // One SRT cue per segment: index, time range, text; cues separated by a blank line.
        pageContent = segments
          .map((segment, index) => {
            const startTime = YoutubeLoader.formatTimestamp(Number(segment.start_ms))
            const endTime = YoutubeLoader.formatTimestamp(Number(segment.end_ms))
            const text = segment.snippet?.text || ''
            return `${index + 1}\n${startTime} --> ${endTime}\n${text}`
          })
          .join('\n\n')
      } else {
        // Plain text: segment texts joined by single spaces.
        pageContent = segments.map((segment) => segment.snippet?.text || '').join(' ')
      }

      if (this.addVideoInfo) {
        const basicInfo = info.basic_info
        metadata.description = basicInfo.short_description
        metadata.title = basicInfo.title
        metadata.view_count = basicInfo.view_count
        metadata.author = basicInfo.author
      }

      return [new Document({ pageContent, metadata })]
    } catch (e: unknown) {
      // A thrown value is not guaranteed to be an Error; avoid printing "undefined".
      const reason = e instanceof Error ? e.message : String(e)
      throw new Error(`Failed to get YouTube video transcription: ${reason}`)
    }
  }
}
const logger = loggerService.withContext('KnowledgeService File Loader')

type LoaderInstance =
  | TextLoader
  | PDFLoader
  | PPTXLoader
  | DocxLoader
  | JSONLoader
  | EPubLoader
  | CheerioWebBaseLoader
  | YoutubeLoader
  | SitemapLoader
  | NoteLoader
  | MarkdownLoader

/** Result reported whenever nothing was added to the vector store. */
function emptyLoaderResult(loaderType: string): LoaderReturn {
  return {
    entriesAdded: 0,
    uniqueId: '',
    uniqueIds: [],
    loaderType
  }
}

/**
 * Returns copies of the documents with `metadata.type` set to `type`.
 */
function formatDocument(docs: Document[], type: string): Document[] {
  return docs.map((doc) => ({
    ...doc,
    metadata: {
      ...doc.metadata,
      type: type
    }
  }))
}

/**
 * Shared pipeline: tag the documents, split them, and add the chunks to the
 * vector store under freshly generated UUIDs.
 *
 * @param base Knowledge base settings (chunk size / overlap are read here).
 * @param vectorStore Store that receives the chunks.
 * @param docs Documents produced by a loader.
 * @param loaderType Value written to `metadata.type` and echoed in the result.
 * @param splitterType Optional splitter selector passed to `SplitterFactory`.
 * @returns Number of chunks added together with their ids.
 */
async function processDocuments(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  docs: Document[],
  loaderType: string,
  splitterType?: string
): Promise<LoaderReturn> {
  const formattedDocs = formatDocument(docs, loaderType)
  const splitter = SplitterFactory.create({
    chunkSize: base.chunkSize,
    chunkOverlap: base.chunkOverlap,
    ...(splitterType && { type: splitterType })
  })

  const splitterResults = await splitter.splitDocuments(formattedDocs)
  const ids = splitterResults.map(() => randomUUID())

  await vectorStore.addDocuments(splitterResults, { ids })

  return {
    entriesAdded: splitterResults.length,
    uniqueId: ids[0] || '',
    uniqueIds: ids,
    loaderType
  }
}

/**
 * Runs a loader and feeds its documents through `processDocuments`.
 * Failures are logged and reported as an empty result instead of being thrown.
 *
 * @param identifier Path or URL used only in the log message.
 */
async function executeLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  loaderInstance: LoaderInstance,
  loaderType: string,
  identifier: string,
  splitterType?: string
): Promise<LoaderReturn> {
  try {
    const docs = await loaderInstance.load()
    return await processDocuments(base, vectorStore, docs, loaderType, splitterType)
  } catch (error) {
    // Pass the error object itself so the stack is kept; interpolating it into
    // the message reduced it to its string form.
    logger.error(`Error loading or processing ${identifier} with loader ${loaderType}`, error as Error)
    return emptyLoaderResult(loaderType)
  }
}

/**
 * Maps a lower-cased file extension to its loader class and reported type.
 * NOTE(review): '.doc' is routed to DocxLoader with default options — confirm
 * that legacy .doc files are actually parsed correctly this way.
 */
const FILE_LOADER_MAP: Record<string, { loader: new (path: string) => LoaderInstance; type: string }> = {
  '.pdf': { loader: PDFLoader, type: 'pdf' },
  '.txt': { loader: TextLoader, type: 'text' },
  '.pptx': { loader: PPTXLoader, type: 'pptx' },
  '.docx': { loader: DocxLoader, type: 'docx' },
  '.doc': { loader: DocxLoader, type: 'doc' },
  '.json': { loader: JSONLoader, type: 'json' },
  '.epub': { loader: EPubLoader, type: 'epub' },
  '.md': { loader: MarkdownLoader, type: 'markdown' }
}

/**
 * Indexes a local file. Extensions missing from `FILE_LOADER_MAP` are read
 * with the plain-text loader and reported under their extension (or 'unknown').
 */
export async function addFileLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  file: FileMetadata
): Promise<LoaderReturn> {
  const fileExt = file.ext.toLowerCase()
  const loaderConfig = FILE_LOADER_MAP[fileExt]

  if (!loaderConfig) {
    const fallbackType = fileExt.replace('.', '') || 'unknown'
    return executeLoader(base, vectorStore, new TextLoader(file.path), fallbackType, file.path)
  }

  return executeLoader(base, vectorStore, new loaderConfig.loader(file.path), loaderConfig.type, file.path)
}

/**
 * Indexes a web page ('normal') or a YouTube transcript ('youtube').
 * Any other source yields an empty result.
 */
export async function addWebLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  url: string,
  source: UrlSource
): Promise<LoaderReturn> {
  let loaderInstance: CheerioWebBaseLoader | YoutubeLoader | undefined
  let splitterType: string | undefined

  switch (source) {
    case 'normal':
      loaderInstance = new CheerioWebBaseLoader(url)
      break
    case 'youtube':
      // NOTE(review): the transcript language is fixed to 'zh' here — confirm this is intended.
      loaderInstance = YoutubeLoader.createFromUrl(url, {
        addVideoInfo: true,
        transcriptFormat: 'srt',
        language: 'zh'
      })
      splitterType = 'srt'
      break
  }

  if (!loaderInstance) {
    return emptyLoaderResult(source)
  }

  return executeLoader(base, vectorStore, loaderInstance, source, url, splitterType)
}

/** Indexes every page reachable from the sitemap at `url`. */
export async function addSitemapLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  url: string
): Promise<LoaderReturn> {
  return executeLoader(base, vectorStore, new SitemapLoader(url), 'sitemap', url)
}

/** Indexes an in-memory note; `sourceUrl` becomes the document source. */
export async function addNoteLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  content: string,
  sourceUrl: string
): Promise<LoaderReturn> {
  return executeLoader(base, vectorStore, new NoteLoader(content, sourceUrl), 'note', sourceUrl)
}

/**
 * Indexes a video through its subtitle file. `files` must contain one entry of
 * type TEXT (the subtitles) and one of type VIDEO; the video's path and
 * original name are attached to every chunk under `metadata.video`.
 */
export async function addVideoLoader(
  base: KnowledgeBaseParams,
  vectorStore: FaissStore,
  files: FileMetadata[]
): Promise<LoaderReturn> {
  const srtFile = files.find((f) => f.type === FileTypes.TEXT)
  const videoFile = files.find((f) => f.type === FileTypes.VIDEO)

  if (!srtFile || !videoFile) {
    // Previously this returned silently, leaving no trace of why nothing was indexed.
    logger.warn('Video loader requires both a subtitle file and a video file; nothing was indexed')
    return emptyLoaderResult('video')
  }

  try {
    const originalDocs = await new TextLoader(srtFile.path).load()

    const docsWithVideoMeta = originalDocs.map(
      (doc) =>
        new Document({
          ...doc,
          metadata: {
            ...doc.metadata,
            video: {
              path: videoFile.path,
              name: videoFile.origin_name
            }
          }
        })
    )

    return await processDocuments(base, vectorStore, docsWithVideoMeta, 'video', 'srt')
  } catch (error) {
    logger.error(`Error loading or processing file ${srtFile.path} with loader video`, error as Error)
    return emptyLoaderResult('video')
  }
}
export class RetrieverFactory {
  /**
   * Creates a LangChain retriever from the knowledge base settings.
   *
   * `base.retriever.mode` selects the kind ('bm25', 'hybrid' or anything else
   * for pure vector search; default 'hybrid'), `base.retriever.weight` is the
   * BM25 share of a hybrid retriever (default 0.5) and `base.documentCount`
   * is the number of results requested (default 5).
   *
   * @param base Knowledge base configuration.
   * @param vectorStore An initialised vector store.
   * @param documents Documents used to build the BM25 retriever.
   * @returns The configured retriever.
   * @throws Error in 'bm25' mode when `documents` is empty.
   */
  public createRetriever(base: KnowledgeBaseParams, vectorStore: FaissStore, documents: Document[]): BaseRetriever {
    const retrieverType = base.retriever?.mode ?? 'hybrid'
    const retrieverWeight = base.retriever?.weight ?? 0.5
    const searchK = base.documentCount ?? 5

    logger.info(`Creating retriever of type: ${retrieverType} with k=${searchK}`)

    const noDocuments = documents.length === 0

    if (retrieverType === 'bm25') {
      if (noDocuments) {
        throw new Error('BM25Retriever requires documents, but none were provided or found.')
      }
      logger.info('Create BM25 Retriever')
      return BM25Retriever.fromDocuments(documents, { k: searchK })
    }

    if (retrieverType === 'hybrid') {
      if (noDocuments) {
        // Without documents there is nothing for BM25 to rank.
        logger.warn('No documents provided for BM25 part of hybrid search. Falling back to vector search only.')
        return vectorStore.asRetriever(searchK)
      }

      const vectorstoreRetriever = vectorStore.asRetriever(searchK)
      const bm25Retriever = BM25Retriever.fromDocuments(documents, { k: searchK })

      logger.info('Create Hybrid Retriever')
      return new EnsembleRetriever({
        retrievers: [bm25Retriever, vectorstoreRetriever],
        weights: [retrieverWeight, 1 - retrieverWeight]
      })
    }

    // 'vector' and any unrecognised mode.
    logger.info('Create Vector Retriever')
    return vectorStore.asRetriever(searchK)
  }
}
// One parsed subtitle cue.
interface SrtSegment {
  text: string
  startTime: number // in seconds
  endTime: number // in seconds
}

// Matches the "HH:MM:SS,mmm --> HH:MM:SS,mmm" line of a cue.
const SRT_TIME_RANGE = /(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})/

// Number of cues merged into one chunk.
const SEGMENTS_PER_CHUNK = 5

/** Converts an SRT timestamp (HH:MM:SS,mmm) to seconds. */
function srtTimeToSeconds(time: string): number {
  const parts = time.split(':')
  const secondsAndMs = parts[2].split(',')
  const hours = parseInt(parts[0], 10)
  const minutes = parseInt(parts[1], 10)
  const seconds = parseInt(secondsAndMs[0], 10)
  const milliseconds = parseInt(secondsAndMs[1], 10)

  return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
}

/**
 * Splitter for SRT subtitle documents. Instead of cutting by character count
 * it groups consecutive cues into chunks and records each chunk's start and
 * end time in the metadata.
 */
export class SrtSplitter extends TextSplitter {
  constructor(fields?: Partial<TextSplitterParams>) {
    // chunkSize and chunkOverlap are forwarded to the base class.
    super(fields)
  }

  /** Not supported: this splitter works on whole documents only. */
  splitText(): Promise<string[]> {
    throw new Error('Method not implemented.')
  }

  /**
   * Parses every document as SRT and returns the merged cue chunks.
   * Documents that contain no parseable cue contribute nothing.
   */
  async splitDocuments(documents: Document[]): Promise<Document[]> {
    const allChunks: Document[] = []

    for (const doc of documents) {
      const segments = this.parseSrt(doc.pageContent)
      if (segments.length === 0) continue

      allChunks.push(...this.mergeSegmentsIntoChunks(segments, doc.metadata))
    }

    return allChunks
  }

  /**
   * Parses SRT text into cues. Blocks with fewer than three lines or without
   * a valid time range on their second line are skipped.
   */
  private parseSrt(srt: string): SrtSegment[] {
    const segments: SrtSegment[] = []

    // Strip a leading BOM and normalise CRLF / CR line endings. Splitting the
    // raw text on "\n\n" found no block boundaries in CRLF files, so such
    // files produced zero cues.
    const normalized = srt
      .replace(/^\uFEFF/, '')
      .replace(/\r\n?/g, '\n')
      .trim()
    // Blocks are separated by one or more blank (or whitespace-only) lines.
    const blocks = normalized.split(/\n(?:[ \t]*\n)+/)

    for (const block of blocks) {
      const lines = block.split('\n')
      if (lines.length < 3) continue

      const timeMatch = lines[1].match(SRT_TIME_RANGE)
      if (!timeMatch) continue

      segments.push({
        text: lines.slice(2).join(' ').trim(),
        startTime: srtTimeToSeconds(timeMatch[1]),
        endTime: srtTimeToSeconds(timeMatch[2])
      })
    }

    return segments
  }

  /**
   * Merges cues into chunks of `SEGMENTS_PER_CHUNK`; a final, shorter chunk
   * holds whatever is left over. Each chunk copies `baseMetadata` and adds
   * `startTime` / `endTime`, plus `source_url_with_timestamp` when the base
   * metadata has a `source_url`.
   */
  private mergeSegmentsIntoChunks(segments: SrtSegment[], baseMetadata: Record<string, any>): Document[] {
    const chunks: Document[] = []

    for (let offset = 0; offset < segments.length; offset += SEGMENTS_PER_CHUNK) {
      const group = segments.slice(offset, offset + SEGMENTS_PER_CHUNK)
      const startTime = group[0].startTime
      const endTime = group[group.length - 1].endTime

      const metadata: Record<string, any> = { ...baseMetadata, startTime, endTime }
      if (baseMetadata.source_url) {
        const sourceUrl = String(baseMetadata.source_url)
        // Use '&' when the URL already carries a query string; a second '?'
        // would produce a malformed link.
        const separator = sourceUrl.includes('?') ? '&' : '?'
        metadata.source_url_with_timestamp = `${sourceUrl}${separator}t=${Math.floor(startTime)}s`
      }

      // Cue texts are joined with a space; an empty cue adds no separator,
      // matching the original accumulation.
      const pageContent = group.reduce((text, segment) => text + (text ? ' ' : '') + segment.text, '')

      chunks.push(new Document({ pageContent, metadata }))
    }

    return chunks
  }
}
+ */ + public static create(config: SplitterConfig): TextSplitter { + switch (config.type) { + case 'srt': + return new SrtSplitter({ + chunkSize: config.chunkSize, + chunkOverlap: config.chunkOverlap + }) + case 'recursive': + default: + return new RecursiveCharacterTextSplitter({ + chunkSize: config.chunkSize, + chunkOverlap: config.chunkOverlap + }) + } + } +} diff --git a/src/main/knowledge/preprocess/PreprocessingService.ts b/src/main/knowledge/preprocess/PreprocessingService.ts new file mode 100644 index 0000000000..1c10529be9 --- /dev/null +++ b/src/main/knowledge/preprocess/PreprocessingService.ts @@ -0,0 +1,63 @@ +import PreprocessProvider from '@main/knowledge/preprocess/PreprocessProvider' +import { loggerService } from '@main/services/LoggerService' +import { windowService } from '@main/services/WindowService' +import type { FileMetadata, KnowledgeBaseParams, KnowledgeItem } from '@types' + +const logger = loggerService.withContext('PreprocessingService') + +class PreprocessingService { + public async preprocessFile( + file: FileMetadata, + base: KnowledgeBaseParams, + item: KnowledgeItem, + userId: string + ): Promise { + let fileToProcess: FileMetadata = file + // Check if preprocessing is configured and applicable (e.g., for PDFs) + if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') { + try { + const provider = new PreprocessProvider(base.preprocessProvider.provider, userId) + + // Check if file has already been preprocessed + const alreadyProcessed = await provider.checkIfAlreadyProcessed(file) + if (alreadyProcessed) { + logger.debug(`File already preprocessed, using cached result: ${file.path}`) + return alreadyProcessed + } + + // Execute preprocessing + logger.debug(`Starting preprocess for scanned PDF: ${file.path}`) + const { processedFile, quota } = await provider.parseFile(item.id, file) + fileToProcess = processedFile + + // Notify the UI + const mainWindow = windowService.getMainWindow() + 
mainWindow?.webContents.send('file-preprocess-finished', { + itemId: item.id, + quota: quota + }) + } catch (err) { + logger.error(`Preprocessing failed: ${err}`) + // If preprocessing fails, re-throw the error to be handled by the caller + throw new Error(`Preprocessing failed: ${err}`) + } + } + + return fileToProcess + } + + public async checkQuota(base: KnowledgeBaseParams, userId: string): Promise { + try { + if (base.preprocessProvider && base.preprocessProvider.type === 'preprocess') { + const provider = new PreprocessProvider(base.preprocessProvider.provider, userId) + return await provider.checkQuota() + } + throw new Error('No preprocess provider configured') + } catch (err) { + logger.error(`Failed to check quota: ${err}`) + throw new Error(`Failed to check quota: ${err}`) + } + } +} + +export const preprocessingService = new PreprocessingService() diff --git a/src/main/knowledge/reranker/BaseReranker.ts b/src/main/knowledge/reranker/BaseReranker.ts index c3ac979d25..9483cb3d4e 100644 --- a/src/main/knowledge/reranker/BaseReranker.ts +++ b/src/main/knowledge/reranker/BaseReranker.ts @@ -1,101 +1,46 @@ -import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' -import { KnowledgeBaseParams } from '@types' +import { DEFAULT_DOCUMENT_COUNT, DEFAULT_RELEVANT_SCORE } from '@main/utils/knowledge' +import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types' + +import { MultiModalDocument, RerankStrategy } from './strategies/RerankStrategy' +import { StrategyFactory } from './strategies/StrategyFactory' export default abstract class BaseReranker { protected base: KnowledgeBaseParams + protected strategy: RerankStrategy constructor(base: KnowledgeBaseParams) { if (!base.rerankApiClient) { throw new Error('Rerank model is required') } this.base = base + this.strategy = StrategyFactory.createStrategy(base.rerankApiClient.provider) } - - abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise - - /** - * Get Rerank Request 
Url - */ - protected getRerankUrl() { - if (this.base.rerankApiClient?.provider === 'bailian') { - return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank' - } - - let baseURL = this.base.rerankApiClient?.baseURL - - if (baseURL && baseURL.endsWith('/')) { - // `/` 结尾强制使用rerankBaseURL - return `${baseURL}rerank` - } - - if (baseURL && !baseURL.endsWith('/v1')) { - baseURL = `${baseURL}/v1` - } - - return `${baseURL}/rerank` + abstract rerank(query: string, searchResults: KnowledgeSearchResult[]): Promise + protected getRerankUrl(): string { + return this.strategy.buildUrl(this.base.rerankApiClient?.baseURL) } - - /** - * Get Rerank Request Body - */ - protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) { - const provider = this.base.rerankApiClient?.provider - const documents = searchResults.map((doc) => doc.pageContent) - const topN = this.base.documentCount - - if (provider === 'voyageai') { - return { - model: this.base.rerankApiClient?.model, - query, - documents, - top_k: topN - } - } else if (provider === 'bailian') { - return { - model: this.base.rerankApiClient?.model, - input: { - query, - documents - }, - parameters: { - top_n: topN - } - } - } else if (provider?.includes('tei')) { - return { - query, - texts: documents, - return_text: true - } - } else { - return { - model: this.base.rerankApiClient?.model, - query, - documents, - top_n: topN - } - } + protected getRerankRequestBody(query: string, searchResults: KnowledgeSearchResult[]) { + const documents = this.buildDocuments(searchResults) + const topN = this.base.documentCount ?? 
DEFAULT_DOCUMENT_COUNT + const model = this.base.rerankApiClient?.model + return this.strategy.buildRequestBody(query, documents, topN, model) } + private buildDocuments(searchResults: KnowledgeSearchResult[]): MultiModalDocument[] { + return searchResults.map((doc) => { + const document: MultiModalDocument = {} - /** - * Extract Rerank Result - */ + // 检查是否是图片类型,添加图片内容 + if (doc.metadata?.type === 'image') { + document.image = doc.pageContent + } else { + document.text = doc.pageContent + } + + return document + }) + } protected extractRerankResult(data: any) { - const provider = this.base.rerankApiClient?.provider - if (provider === 'bailian') { - return data.output.results - } else if (provider === 'voyageai') { - return data.data - } else if (provider?.includes('tei')) { - return data.map((item: any) => { - return { - index: item.index, - relevance_score: item.score - } - }) - } else { - return data.results - } + return this.strategy.extractResults(data) } /** @@ -105,35 +50,30 @@ export default abstract class BaseReranker { * @protected */ protected getRerankResult( - searchResults: ExtractChunkData[], - rerankResults: Array<{ - index: number - relevance_score: number - }> + searchResults: KnowledgeSearchResult[], + rerankResults: Array<{ index: number; relevance_score: number }> ) { - const resultMap = new Map(rerankResults.map((result) => [result.index, result.relevance_score || 0])) + const resultMap = new Map( + rerankResults.map((result) => [result.index, result.relevance_score || DEFAULT_RELEVANT_SCORE]) + ) - return searchResults - .map((doc: ExtractChunkData, index: number) => { + const returenResults = searchResults + .map((doc: KnowledgeSearchResult, index: number) => { const score = resultMap.get(index) if (score === undefined) return undefined - - return { - ...doc, - score - } + return { ...doc, score } }) - .filter((doc): doc is ExtractChunkData => doc !== undefined) + .filter((doc): doc is KnowledgeSearchResult => doc !== undefined) .sort((a, b) 
=> b.score - a.score) - } + return returenResults + } public defaultHeaders() { return { Authorization: `Bearer ${this.base.rerankApiClient?.apiKey}`, 'Content-Type': 'application/json' } } - protected formatErrorMessage(url: string, error: any, requestBody: any) { const errorDetails = { url: url, diff --git a/src/main/knowledge/reranker/GeneralReranker.ts b/src/main/knowledge/reranker/GeneralReranker.ts index 5a0e240a9d..e4b3503606 100644 --- a/src/main/knowledge/reranker/GeneralReranker.ts +++ b/src/main/knowledge/reranker/GeneralReranker.ts @@ -1,19 +1,14 @@ -import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' -import { KnowledgeBaseParams } from '@types' +import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types' import { net } from 'electron' import BaseReranker from './BaseReranker' - export default class GeneralReranker extends BaseReranker { constructor(base: KnowledgeBaseParams) { super(base) } - - public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise => { + public rerank = async (query: string, searchResults: KnowledgeSearchResult[]): Promise => { const url = this.getRerankUrl() - const requestBody = this.getRerankRequestBody(query, searchResults) - try { const response = await net.fetch(url, { method: 'POST', diff --git a/src/main/knowledge/reranker/Reranker.ts b/src/main/knowledge/reranker/Reranker.ts index d42376ea20..59de4b0470 100644 --- a/src/main/knowledge/reranker/Reranker.ts +++ b/src/main/knowledge/reranker/Reranker.ts @@ -1,5 +1,4 @@ -import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' -import { KnowledgeBaseParams } from '@types' +import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types' import GeneralReranker from './GeneralReranker' @@ -8,7 +7,7 @@ export default class Reranker { constructor(base: KnowledgeBaseParams) { this.sdk = new GeneralReranker(base) } - public async rerank(query: string, searchResults: ExtractChunkData[]): Promise { + public async 
rerank(query: string, searchResults: KnowledgeSearchResult[]): Promise { return this.sdk.rerank(query, searchResults) } } diff --git a/src/main/knowledge/reranker/strategies/BailianStrategy.ts b/src/main/knowledge/reranker/strategies/BailianStrategy.ts new file mode 100644 index 0000000000..e5932b9f40 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/BailianStrategy.ts @@ -0,0 +1,18 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class BailianStrategy implements RerankStrategy { + buildUrl(): string { + return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank' + } + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) { + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) + + return { + model, + input: { query, documents: textDocuments }, + parameters: { top_n: topN } + } + } + extractResults(data: any) { + return data.output.results + } +} diff --git a/src/main/knowledge/reranker/strategies/DefaultStrategy.ts b/src/main/knowledge/reranker/strategies/DefaultStrategy.ts new file mode 100644 index 0000000000..59ee3fb47b --- /dev/null +++ b/src/main/knowledge/reranker/strategies/DefaultStrategy.ts @@ -0,0 +1,25 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class DefaultStrategy implements RerankStrategy { + buildUrl(baseURL?: string): string { + if (baseURL && baseURL.endsWith('/')) { + return `${baseURL}rerank` + } + if (baseURL && !baseURL.endsWith('/v1')) { + baseURL = `${baseURL}/v1` + } + return `${baseURL}/rerank` + } + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) { + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) 
+ + return { + model, + query, + documents: textDocuments, + top_n: topN + } + } + extractResults(data: any) { + return data.results + } +} diff --git a/src/main/knowledge/reranker/strategies/JinaStrategy.ts b/src/main/knowledge/reranker/strategies/JinaStrategy.ts new file mode 100644 index 0000000000..200190f544 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/JinaStrategy.ts @@ -0,0 +1,33 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class JinaStrategy implements RerankStrategy { + buildUrl(baseURL?: string): string { + if (baseURL && baseURL.endsWith('/')) { + return `${baseURL}rerank` + } + if (baseURL && !baseURL.endsWith('/v1')) { + baseURL = `${baseURL}/v1` + } + return `${baseURL}/rerank` + } + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) { + if (model === 'jina-reranker-m0') { + return { + model, + query, + documents, + top_n: topN + } + } + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) 
+ + return { + model, + query, + documents: textDocuments, + top_n: topN + } + } + extractResults(data: any) { + return data.results + } +} diff --git a/src/main/knowledge/reranker/strategies/RerankStrategy.ts b/src/main/knowledge/reranker/strategies/RerankStrategy.ts new file mode 100644 index 0000000000..a23f630b7d --- /dev/null +++ b/src/main/knowledge/reranker/strategies/RerankStrategy.ts @@ -0,0 +1,9 @@ +export interface MultiModalDocument { + text?: string + image?: string +} +export interface RerankStrategy { + buildUrl(baseURL?: string): string + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string): any + extractResults(data: any): Array<{ index: number; relevance_score: number }> +} diff --git a/src/main/knowledge/reranker/strategies/StrategyFactory.ts b/src/main/knowledge/reranker/strategies/StrategyFactory.ts new file mode 100644 index 0000000000..9e04547e31 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/StrategyFactory.ts @@ -0,0 +1,25 @@ +import { BailianStrategy } from './BailianStrategy' +import { DefaultStrategy } from './DefaultStrategy' +import { JinaStrategy } from './JinaStrategy' +import { RerankStrategy } from './RerankStrategy' +import { TEIStrategy } from './TeiStrategy' +import { isTEIProvider, RERANKER_PROVIDERS } from './types' +import { VoyageAIStrategy } from './VoyageStrategy' + +export class StrategyFactory { + static createStrategy(provider?: string): RerankStrategy { + switch (provider) { + case RERANKER_PROVIDERS.VOYAGEAI: + return new VoyageAIStrategy() + case RERANKER_PROVIDERS.BAILIAN: + return new BailianStrategy() + case RERANKER_PROVIDERS.JINA: + return new JinaStrategy() + default: + if (isTEIProvider(provider)) { + return new TEIStrategy() + } + return new DefaultStrategy() + } + } +} diff --git a/src/main/knowledge/reranker/strategies/TeiStrategy.ts b/src/main/knowledge/reranker/strategies/TeiStrategy.ts new file mode 100644 index 0000000000..58f24661ba --- /dev/null 
+++ b/src/main/knowledge/reranker/strategies/TeiStrategy.ts @@ -0,0 +1,26 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class TEIStrategy implements RerankStrategy { + buildUrl(baseURL?: string): string { + if (baseURL && baseURL.endsWith('/')) { + return `${baseURL}rerank` + } + if (baseURL && !baseURL.endsWith('/v1')) { + baseURL = `${baseURL}/v1` + } + return `${baseURL}/rerank` + } + buildRequestBody(query: string, documents: MultiModalDocument[]) { + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) + return { + query, + texts: textDocuments, + return_text: true + } + } + extractResults(data: any) { + return data.map((item: any) => ({ + index: item.index, + relevance_score: item.score + })) + } +} diff --git a/src/main/knowledge/reranker/strategies/VoyageStrategy.ts b/src/main/knowledge/reranker/strategies/VoyageStrategy.ts new file mode 100644 index 0000000000..e81319f024 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/VoyageStrategy.ts @@ -0,0 +1,24 @@ +import { MultiModalDocument, RerankStrategy } from './RerankStrategy' +export class VoyageAIStrategy implements RerankStrategy { + buildUrl(baseURL?: string): string { + if (baseURL && baseURL.endsWith('/')) { + return `${baseURL}rerank` + } + if (baseURL && !baseURL.endsWith('/v1')) { + baseURL = `${baseURL}/v1` + } + return `${baseURL}/rerank` + } + buildRequestBody(query: string, documents: MultiModalDocument[], topN: number, model?: string) { + const textDocuments = documents.filter((d) => d.text).map((d) => d.text!) 
+ return { + model, + query, + documents: textDocuments, + top_k: topN + } + } + extractResults(data: any) { + return data.data + } +} diff --git a/src/main/knowledge/reranker/strategies/types.ts b/src/main/knowledge/reranker/strategies/types.ts new file mode 100644 index 0000000000..91bfdef8f6 --- /dev/null +++ b/src/main/knowledge/reranker/strategies/types.ts @@ -0,0 +1,19 @@ +import { objectValues } from '@types' + +export const RERANKER_PROVIDERS = { + VOYAGEAI: 'voyageai', + BAILIAN: 'bailian', + JINA: 'jina', + TEI: 'tei' +} as const + +export type RerankProvider = (typeof RERANKER_PROVIDERS)[keyof typeof RERANKER_PROVIDERS] + +export function isTEIProvider(provider?: string): boolean { + return provider?.includes(RERANKER_PROVIDERS.TEI) ?? false +} + +export function isKnownProvider(provider?: string): provider is RerankProvider { + if (!provider) return false + return objectValues(RERANKER_PROVIDERS).some((p) => p === provider) +} diff --git a/src/main/services/KnowledgeService.ts b/src/main/services/knowledge/EmbedJsFramework.ts similarity index 52% rename from src/main/services/KnowledgeService.ts rename to src/main/services/knowledge/EmbedJsFramework.ts index 99879390e4..64ac77434e 100644 --- a/src/main/services/KnowledgeService.ts +++ b/src/main/services/knowledge/EmbedJsFramework.ts @@ -1,111 +1,40 @@ -/** - * Knowledge Service - Manages knowledge bases using RAG (Retrieval-Augmented Generation) - * - * This service handles creation, management, and querying of knowledge bases from various sources - * including files, directories, URLs, sitemaps, and notes. 
- * - * Features: - * - Concurrent task processing with workload management - * - Multiple data source support - * - Vector database integration - * - * For detailed documentation, see: - * @see {@link ../../../docs/technical/KnowledgeService.md} - */ - import * as fs from 'node:fs' import path from 'node:path' import { RAGApplication, RAGApplicationBuilder } from '@cherrystudio/embedjs' -import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import { LibSqlDb } from '@cherrystudio/embedjs-libsql' import { SitemapLoader } from '@cherrystudio/embedjs-loader-sitemap' import { WebLoader } from '@cherrystudio/embedjs-loader-web' import { loggerService } from '@logger' -import Embeddings from '@main/knowledge/embeddings/Embeddings' -import { addFileLoader } from '@main/knowledge/loader' -import { NoteLoader } from '@main/knowledge/loader/noteLoader' -import PreprocessProvider from '@main/knowledge/preprocess/PreprocessProvider' -import Reranker from '@main/knowledge/reranker/Reranker' -import { fileStorage } from '@main/services/FileStorage' -import { windowService } from '@main/services/WindowService' -import { getDataPath } from '@main/utils' +import Embeddings from '@main/knowledge/embedjs/embeddings/Embeddings' +import { addFileLoader } from '@main/knowledge/embedjs/loader' +import { NoteLoader } from '@main/knowledge/embedjs/loader/noteLoader' +import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService' import { getAllFiles } from '@main/utils/file' -import { TraceMethod } from '@mcp-trace/trace-core' import { MB } from '@shared/config/constant' -import type { LoaderReturn } from '@shared/config/types' +import { LoaderReturn } from '@shared/config/types' import { IpcChannel } from '@shared/IpcChannel' -import { FileMetadata, KnowledgeBaseParams, KnowledgeItem } from '@types' +import { FileMetadata, KnowledgeBaseParams, KnowledgeSearchResult } from '@types' import { v4 as uuidv4 } from 'uuid' +import { windowService } 
from '../WindowService' +import { + IKnowledgeFramework, + KnowledgeBaseAddItemOptionsNonNullableAttribute, + LoaderDoneReturn, + LoaderTask, + LoaderTaskItem, + LoaderTaskItemState +} from './IKnowledgeFramework' + const logger = loggerService.withContext('MainKnowledgeService') -export interface KnowledgeBaseAddItemOptions { - base: KnowledgeBaseParams - item: KnowledgeItem - forceReload?: boolean - userId?: string -} - -interface KnowledgeBaseAddItemOptionsNonNullableAttribute { - base: KnowledgeBaseParams - item: KnowledgeItem - forceReload: boolean - userId: string -} - -interface EvaluateTaskWorkload { - workload: number -} - -type LoaderDoneReturn = LoaderReturn | null - -enum LoaderTaskItemState { - PENDING, - PROCESSING, - DONE -} - -interface LoaderTaskItem { - state: LoaderTaskItemState - task: () => Promise - evaluateTaskWorkload: EvaluateTaskWorkload -} - -interface LoaderTask { - loaderTasks: LoaderTaskItem[] - loaderDoneReturn: LoaderDoneReturn -} - -interface LoaderTaskOfSet { - loaderTasks: Set - loaderDoneReturn: LoaderDoneReturn -} - -interface QueueTaskItem { - taskPromise: () => Promise - resolve: () => void - evaluateTaskWorkload: EvaluateTaskWorkload -} - -const loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => { - return { - loaderTasks: new Set(loaderTask.loaderTasks), - loaderDoneReturn: loaderTask.loaderDoneReturn - } -} - -class KnowledgeService { - private storageDir = path.join(getDataPath(), 'KnowledgeBase') - private pendingDeleteFile = path.join(this.storageDir, 'knowledge_pending_delete.json') - // Byte based - private workload = 0 - private processingItemCount = 0 - private knowledgeItemProcessingQueueMappingPromise: Map void> = new Map() +export class EmbedJsFramework implements IKnowledgeFramework { + private storageDir: string private ragApplications: Map = new Map() + private pendingDeleteFile: string private dbInstances: Map = new Map() - private static MAXIMUM_WORKLOAD = 80 * MB - private static 
MAXIMUM_PROCESSING_ITEM_COUNT = 30 + private static ERROR_LOADER_RETURN: LoaderReturn = { entriesAdded: 0, uniqueId: '', @@ -114,7 +43,9 @@ class KnowledgeService { status: 'failed' } - constructor() { + constructor(storageDir: string) { + this.storageDir = storageDir + this.pendingDeleteFile = path.join(this.storageDir, 'knowledge_pending_delete.json') this.initStorageDir() this.cleanupOnStartup() } @@ -229,33 +160,28 @@ class KnowledgeService { logger.info(`Startup cleanup completed: ${deletedCount}/${pendingDeleteIds.length} knowledge bases deleted`) } - private getRagApplication = async ({ - id, - embedApiClient, - dimensions, - documentCount - }: KnowledgeBaseParams): Promise => { - if (this.ragApplications.has(id)) { - return this.ragApplications.get(id)! + private async getRagApplication(base: KnowledgeBaseParams): Promise { + if (this.ragApplications.has(base.id)) { + return this.ragApplications.get(base.id)! } let ragApplication: RAGApplication const embeddings = new Embeddings({ - embedApiClient, - dimensions + embedApiClient: base.embedApiClient, + dimensions: base.dimensions }) try { - const libSqlDb = new LibSqlDb({ path: path.join(this.storageDir, id) }) + const libSqlDb = new LibSqlDb({ path: path.join(this.storageDir, base.id) }) // Save database instance for later closing - this.dbInstances.set(id, libSqlDb) + this.dbInstances.set(base.id, libSqlDb) ragApplication = await new RAGApplicationBuilder() .setModel('NO_MODEL') .setEmbeddingModel(embeddings) .setVectorDatabase(libSqlDb) - .setSearchResultCount(documentCount || 30) + .setSearchResultCount(base.documentCount || 30) .build() - this.ragApplications.set(id, ragApplication) + this.ragApplications.set(base.id, ragApplication) } catch (e) { logger.error('Failed to create RAGApplication:', e as Error) throw new Error(`Failed to create RAGApplication: ${e}`) @@ -263,17 +189,14 @@ class KnowledgeService { return ragApplication } - - public create = async (_: Electron.IpcMainInvokeEvent, base: 
KnowledgeBaseParams): Promise => { + async initialize(base: KnowledgeBaseParams): Promise { await this.getRagApplication(base) } - - public reset = async (_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise => { - const ragApplication = await this.getRagApplication(base) - await ragApplication.reset() + async reset(base: KnowledgeBaseParams): Promise { + const ragApp = await this.getRagApplication(base) + await ragApp.reset() } - - public async delete(_: Electron.IpcMainInvokeEvent, id: string): Promise { + async delete(id: string): Promise { logger.debug(`delete id: ${id}`) await this.cleanupKnowledgeResources(id) @@ -286,15 +209,41 @@ class KnowledgeService { this.pendingDeleteManager.add(id) } } - - private maximumLoad() { - return ( - this.processingItemCount >= KnowledgeService.MAXIMUM_PROCESSING_ITEM_COUNT || - this.workload >= KnowledgeService.MAXIMUM_WORKLOAD - ) + getLoaderTask(options: KnowledgeBaseAddItemOptionsNonNullableAttribute): LoaderTask { + const { item } = options + const getRagApplication = () => this.getRagApplication(options.base) + switch (item.type) { + case 'file': + return this.fileTask(getRagApplication, options) + case 'directory': + return this.directoryTask(getRagApplication, options) + case 'url': + return this.urlTask(getRagApplication, options) + case 'sitemap': + return this.sitemapTask(getRagApplication, options) + case 'note': + return this.noteTask(getRagApplication, options) + default: + return { + loaderTasks: [], + loaderDoneReturn: null + } + } } + + async remove(options: { uniqueIds: string[]; base: KnowledgeBaseParams }): Promise { + const ragApp = await this.getRagApplication(options.base) + for (const id of options.uniqueIds) { + await ragApp.deleteLoader(id) + } + } + async search(options: { search: string; base: KnowledgeBaseParams }): Promise { + const ragApp = await this.getRagApplication(options.base) + return await ragApp.search(options.search) + } + private fileTask( - ragApplication: 
RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload, userId } = options @@ -307,7 +256,8 @@ class KnowledgeService { task: async () => { try { // Add preprocessing logic - const fileToProcess: FileMetadata = await this.preprocessing(file, base, item, userId) + const ragApplication = await getRagApplication() + const fileToProcess: FileMetadata = await preprocessingService.preprocessFile(file, base, item, userId) // Use processed file for loading return addFileLoader(ragApplication, fileToProcess, base, forceReload) @@ -318,7 +268,7 @@ class KnowledgeService { .catch((e) => { logger.error(`Error in addFileLoader for ${file.name}: ${e}`) const errorResult: LoaderReturn = { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: e.message, messageSource: 'embedding' } @@ -328,7 +278,7 @@ class KnowledgeService { } catch (e: any) { logger.error(`Preprocessing failed for ${file.name}: ${e}`) const errorResult: LoaderReturn = { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: e.message, messageSource: 'preprocess' } @@ -345,7 +295,7 @@ class KnowledgeService { return loaderTask } private directoryTask( - ragApplication: RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload } = options @@ -372,8 +322,9 @@ class KnowledgeService { for (const file of files) { loaderTasks.push({ state: LoaderTaskItemState.PENDING, - task: () => - addFileLoader(ragApplication, file, base, forceReload) + task: async () => { + const ragApplication = await getRagApplication() + return addFileLoader(ragApplication, file, base, forceReload) .then((result) => { loaderDoneReturn.entriesAdded += 1 processedFiles += 1 @@ -384,11 +335,12 @@ class KnowledgeService { .catch((err) => { logger.error('Failed to add dir 
loader:', err) return { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: `Failed to add dir loader: ${err.message}`, messageSource: 'embedding' } - }), + }) + }, evaluateTaskWorkload: { workload: file.size } }) } @@ -400,7 +352,7 @@ class KnowledgeService { } private urlTask( - ragApplication: RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload } = options @@ -410,7 +362,8 @@ class KnowledgeService { loaderTasks: [ { state: LoaderTaskItemState.PENDING, - task: () => { + task: async () => { + const ragApplication = await getRagApplication() const loaderReturn = ragApplication.addLoader( new WebLoader({ urlOrContent: content, @@ -434,7 +387,7 @@ class KnowledgeService { .catch((err) => { logger.error('Failed to add url loader:', err) return { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: `Failed to add url loader: ${err.message}`, messageSource: 'embedding' } @@ -449,7 +402,7 @@ class KnowledgeService { } private sitemapTask( - ragApplication: RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload } = options @@ -459,8 +412,9 @@ class KnowledgeService { loaderTasks: [ { state: LoaderTaskItemState.PENDING, - task: () => - ragApplication + task: async () => { + const ragApplication = await getRagApplication() + return ragApplication .addLoader( new SitemapLoader({ url: content, chunkSize: base.chunkSize, chunkOverlap: base.chunkOverlap }) as any, forceReload @@ -478,11 +432,12 @@ class KnowledgeService { .catch((err) => { logger.error('Failed to add sitemap loader:', err) return { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: `Failed to add sitemap loader: ${err.message}`, messageSource: 'embedding' } - }), + }) + }, 
evaluateTaskWorkload: { workload: 20 * MB } } ], @@ -492,7 +447,7 @@ class KnowledgeService { } private noteTask( - ragApplication: RAGApplication, + getRagApplication: () => Promise, options: KnowledgeBaseAddItemOptionsNonNullableAttribute ): LoaderTask { const { base, item, forceReload } = options @@ -505,7 +460,8 @@ class KnowledgeService { loaderTasks: [ { state: LoaderTaskItemState.PENDING, - task: () => { + task: async () => { + const ragApplication = await getRagApplication() const loaderReturn = ragApplication.addLoader( new NoteLoader({ text: content, @@ -528,7 +484,7 @@ class KnowledgeService { .catch((err) => { logger.error('Failed to add note loader:', err) return { - ...KnowledgeService.ERROR_LOADER_RETURN, + ...EmbedJsFramework.ERROR_LOADER_RETURN, message: `Failed to add note loader: ${err.message}`, messageSource: 'embedding' } @@ -541,199 +497,4 @@ class KnowledgeService { } return loaderTask } - - private processingQueueHandle() { - const getSubtasksUntilMaximumLoad = (): QueueTaskItem[] => { - const queueTaskList: QueueTaskItem[] = [] - that: for (const [task, resolve] of this.knowledgeItemProcessingQueueMappingPromise) { - for (const item of task.loaderTasks) { - if (this.maximumLoad()) { - break that - } - - const { state, task: taskPromise, evaluateTaskWorkload } = item - - if (state !== LoaderTaskItemState.PENDING) { - continue - } - - const { workload } = evaluateTaskWorkload - this.workload += workload - this.processingItemCount += 1 - item.state = LoaderTaskItemState.PROCESSING - queueTaskList.push({ - taskPromise: () => - taskPromise().then(() => { - this.workload -= workload - this.processingItemCount -= 1 - task.loaderTasks.delete(item) - if (task.loaderTasks.size === 0) { - this.knowledgeItemProcessingQueueMappingPromise.delete(task) - resolve() - } - this.processingQueueHandle() - }), - resolve: () => {}, - evaluateTaskWorkload - }) - } - } - return queueTaskList - } - const subTasks = getSubtasksUntilMaximumLoad() - if 
(subTasks.length > 0) { - const subTaskPromises = subTasks.map(({ taskPromise }) => taskPromise()) - Promise.all(subTaskPromises).then(() => { - subTasks.forEach(({ resolve }) => resolve()) - }) - } - } - - private appendProcessingQueue(task: LoaderTask): Promise { - return new Promise((resolve) => { - this.knowledgeItemProcessingQueueMappingPromise.set(loaderTaskIntoOfSet(task), () => { - resolve(task.loaderDoneReturn!) - }) - }) - } - - public add = (_: Electron.IpcMainInvokeEvent, options: KnowledgeBaseAddItemOptions): Promise => { - return new Promise((resolve) => { - const { base, item, forceReload = false, userId = '' } = options - const optionsNonNullableAttribute = { base, item, forceReload, userId } - this.getRagApplication(base) - .then((ragApplication) => { - const task = (() => { - switch (item.type) { - case 'file': - return this.fileTask(ragApplication, optionsNonNullableAttribute) - case 'directory': - return this.directoryTask(ragApplication, optionsNonNullableAttribute) - case 'url': - return this.urlTask(ragApplication, optionsNonNullableAttribute) - case 'sitemap': - return this.sitemapTask(ragApplication, optionsNonNullableAttribute) - case 'note': - return this.noteTask(ragApplication, optionsNonNullableAttribute) - default: - return null - } - })() - - if (task) { - this.appendProcessingQueue(task).then(() => { - resolve(task.loaderDoneReturn!) 
- }) - this.processingQueueHandle() - } else { - resolve({ - ...KnowledgeService.ERROR_LOADER_RETURN, - message: 'Unsupported item type', - messageSource: 'embedding' - }) - } - }) - .catch((err) => { - logger.error('Failed to add item:', err) - resolve({ - ...KnowledgeService.ERROR_LOADER_RETURN, - message: `Failed to add item: ${err.message}`, - messageSource: 'embedding' - }) - }) - }) - } - - @TraceMethod({ spanName: 'remove', tag: 'Knowledge' }) - public async remove( - _: Electron.IpcMainInvokeEvent, - { uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams } - ): Promise { - const ragApplication = await this.getRagApplication(base) - logger.debug(`Remove Item UniqueId: ${uniqueId}`) - for (const id of uniqueIds) { - await ragApplication.deleteLoader(id) - } - } - - @TraceMethod({ spanName: 'RagSearch', tag: 'Knowledge' }) - public async search( - _: Electron.IpcMainInvokeEvent, - { search, base }: { search: string; base: KnowledgeBaseParams } - ): Promise { - const ragApplication = await this.getRagApplication(base) - return await ragApplication.search(search) - } - - @TraceMethod({ spanName: 'rerank', tag: 'Knowledge' }) - public async rerank( - _: Electron.IpcMainInvokeEvent, - { search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] } - ): Promise { - if (results.length === 0) { - return results - } - return await new Reranker(base).rerank(search, results) - } - - public getStorageDir = (): string => { - return this.storageDir - } - - private preprocessing = async ( - file: FileMetadata, - base: KnowledgeBaseParams, - item: KnowledgeItem, - userId: string - ): Promise => { - let fileToProcess: FileMetadata = file - if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') { - try { - const provider = new PreprocessProvider(base.preprocessProvider.provider, userId) - const filePath = fileStorage.getFilePathById(file) - // Check if file has already been preprocessed 
- const alreadyProcessed = await provider.checkIfAlreadyProcessed(file) - if (alreadyProcessed) { - logger.debug(`File already preprocess processed, using cached result: ${filePath}`) - return alreadyProcessed - } - - // Execute preprocessing - logger.debug(`Starting preprocess processing for scanned PDF: ${filePath}`) - const { processedFile, quota } = await provider.parseFile(item.id, file) - fileToProcess = processedFile - const mainWindow = windowService.getMainWindow() - mainWindow?.webContents.send('file-preprocess-finished', { - itemId: item.id, - quota: quota - }) - } catch (err) { - logger.error(`Preprocess processing failed: ${err}`) - // If preprocessing fails, use original file - // fileToProcess = file - throw new Error(`Preprocess processing failed: ${err}`) - } - } - - return fileToProcess - } - - public checkQuota = async ( - _: Electron.IpcMainInvokeEvent, - base: KnowledgeBaseParams, - userId: string - ): Promise => { - try { - if (base.preprocessProvider && base.preprocessProvider.type === 'preprocess') { - const provider = new PreprocessProvider(base.preprocessProvider.provider, userId) - return await provider.checkQuota() - } - throw new Error('No preprocess provider configured') - } catch (err) { - logger.error(`Failed to check quota: ${err}`) - throw new Error(`Failed to check quota: ${err}`) - } - } } - -export default new KnowledgeService() diff --git a/src/main/services/knowledge/IKnowledgeFramework.ts b/src/main/services/knowledge/IKnowledgeFramework.ts new file mode 100644 index 0000000000..2afbbec713 --- /dev/null +++ b/src/main/services/knowledge/IKnowledgeFramework.ts @@ -0,0 +1,72 @@ +import { LoaderReturn } from '@shared/config/types' +import { KnowledgeBaseParams, KnowledgeItem, KnowledgeSearchResult } from '@types' + +export interface KnowledgeBaseAddItemOptions { + base: KnowledgeBaseParams + item: KnowledgeItem + forceReload?: boolean + userId?: string +} + +export interface KnowledgeBaseAddItemOptionsNonNullableAttribute { + 
base: KnowledgeBaseParams + item: KnowledgeItem + forceReload: boolean + userId: string +} + +export interface EvaluateTaskWorkload { + workload: number +} + +export type LoaderDoneReturn = LoaderReturn | null + +export enum LoaderTaskItemState { + PENDING, + PROCESSING, + DONE +} + +export interface LoaderTaskItem { + state: LoaderTaskItemState + task: () => Promise + evaluateTaskWorkload: EvaluateTaskWorkload +} + +export interface LoaderTask { + loaderTasks: LoaderTaskItem[] + loaderDoneReturn: LoaderDoneReturn +} + +export interface LoaderTaskOfSet { + loaderTasks: Set + loaderDoneReturn: LoaderDoneReturn +} + +export interface QueueTaskItem { + taskPromise: () => Promise + resolve: () => void + evaluateTaskWorkload: EvaluateTaskWorkload +} + +export const loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => { + return { + loaderTasks: new Set(loaderTask.loaderTasks), + loaderDoneReturn: loaderTask.loaderDoneReturn + } +} + +export interface IKnowledgeFramework { + /** 为给定知识库初始化框架资源 */ + initialize(base: KnowledgeBaseParams): Promise + /** 重置知识库,删除其所有内容 */ + reset(base: KnowledgeBaseParams): Promise + /** 删除与知识库关联的资源,包括文件 */ + delete(id: string): Promise + /** 生成用于添加条目的任务对象,由队列处理 */ + getLoaderTask(options: KnowledgeBaseAddItemOptionsNonNullableAttribute): LoaderTask + /** 从知识库中删除特定条目 */ + remove(options: { uniqueIds: string[]; base: KnowledgeBaseParams }): Promise + /** 搜索知识库 */ + search(options: { search: string; base: KnowledgeBaseParams }): Promise +} diff --git a/src/main/services/knowledge/KnowledgeFrameworkFactory.ts b/src/main/services/knowledge/KnowledgeFrameworkFactory.ts new file mode 100644 index 0000000000..cf26749564 --- /dev/null +++ b/src/main/services/knowledge/KnowledgeFrameworkFactory.ts @@ -0,0 +1,48 @@ +import path from 'node:path' + +import { KnowledgeBaseParams } from '@types' +import { app } from 'electron' + +import { EmbedJsFramework } from './EmbedJsFramework' +import { IKnowledgeFramework } from './IKnowledgeFramework' 
+import { LangChainFramework } from './LangChainFramework' +class KnowledgeFrameworkFactory { + private static instance: KnowledgeFrameworkFactory + private frameworks: Map = new Map() + private storageDir: string + + private constructor(storageDir: string) { + this.storageDir = storageDir + } + + public static getInstance(storageDir: string): KnowledgeFrameworkFactory { + if (!KnowledgeFrameworkFactory.instance) { + KnowledgeFrameworkFactory.instance = new KnowledgeFrameworkFactory(storageDir) + } + return KnowledgeFrameworkFactory.instance + } + + public getFramework(base: KnowledgeBaseParams): IKnowledgeFramework { + const frameworkType = base.framework || 'embedjs' // 如果未指定,默认为 embedjs + if (this.frameworks.has(frameworkType)) { + return this.frameworks.get(frameworkType)! + } + let framework: IKnowledgeFramework + switch (frameworkType) { + case 'langchain': + framework = new LangChainFramework(this.storageDir) + break + case 'embedjs': + default: + framework = new EmbedJsFramework(this.storageDir) + break + } + + this.frameworks.set(frameworkType, framework) + return framework + } +} + +export const knowledgeFrameworkFactory = KnowledgeFrameworkFactory.getInstance( + path.join(app.getPath('userData'), 'Data', 'KnowledgeBase') +) diff --git a/src/main/services/knowledge/KnowledgeService.ts b/src/main/services/knowledge/KnowledgeService.ts new file mode 100644 index 0000000000..f34a2b31b6 --- /dev/null +++ b/src/main/services/knowledge/KnowledgeService.ts @@ -0,0 +1,190 @@ +import * as fs from 'node:fs' +import path from 'node:path' + +import { loggerService } from '@logger' +import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService' +import Reranker from '@main/knowledge/reranker/Reranker' +import { TraceMethod } from '@mcp-trace/trace-core' +import { MB } from '@shared/config/constant' +import { LoaderReturn } from '@shared/config/types' +import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types' +import { app } from 
'electron' + +import { + KnowledgeBaseAddItemOptions, + LoaderTask, + loaderTaskIntoOfSet, + LoaderTaskItemState, + LoaderTaskOfSet, + QueueTaskItem +} from './IKnowledgeFramework' +import { knowledgeFrameworkFactory } from './KnowledgeFrameworkFactory' + +const logger = loggerService.withContext('MainKnowledgeService') + +class KnowledgeService { + private storageDir = path.join(app.getPath('userData'), 'Data', 'KnowledgeBase') + + private workload = 0 + private processingItemCount = 0 + private knowledgeItemProcessingQueueMappingPromise: Map void> = new Map() + private static MAXIMUM_WORKLOAD = 80 * MB + private static MAXIMUM_PROCESSING_ITEM_COUNT = 30 + private static ERROR_LOADER_RETURN: LoaderReturn = { + entriesAdded: 0, + uniqueId: '', + uniqueIds: [''], + loaderType: '', + status: 'failed' + } + + constructor() { + this.initStorageDir() + } + + private initStorageDir = (): void => { + if (!fs.existsSync(this.storageDir)) { + fs.mkdirSync(this.storageDir, { recursive: true }) + } + } + + private maximumLoad() { + return ( + this.processingItemCount >= KnowledgeService.MAXIMUM_PROCESSING_ITEM_COUNT || + this.workload >= KnowledgeService.MAXIMUM_WORKLOAD + ) + } + + private processingQueueHandle() { + const getSubtasksUntilMaximumLoad = (): QueueTaskItem[] => { + const queueTaskList: QueueTaskItem[] = [] + that: for (const [task, resolve] of this.knowledgeItemProcessingQueueMappingPromise) { + for (const item of task.loaderTasks) { + if (this.maximumLoad()) { + break that + } + + const { state, task: taskPromise, evaluateTaskWorkload } = item + + if (state !== LoaderTaskItemState.PENDING) { + continue + } + + const { workload } = evaluateTaskWorkload + this.workload += workload + this.processingItemCount += 1 + item.state = LoaderTaskItemState.PROCESSING + queueTaskList.push({ + taskPromise: () => + taskPromise().then(() => { + this.workload -= workload + this.processingItemCount -= 1 + task.loaderTasks.delete(item) + if (task.loaderTasks.size === 0) { + 
this.knowledgeItemProcessingQueueMappingPromise.delete(task) + resolve() + } + this.processingQueueHandle() + }), + resolve: () => {}, + evaluateTaskWorkload + }) + } + } + return queueTaskList + } + const subTasks = getSubtasksUntilMaximumLoad() + if (subTasks.length > 0) { + const subTaskPromises = subTasks.map(({ taskPromise }) => taskPromise()) + Promise.all(subTaskPromises).then(() => { + subTasks.forEach(({ resolve }) => resolve()) + }) + } + } + + private appendProcessingQueue(task: LoaderTask): Promise { + return new Promise((resolve) => { + this.knowledgeItemProcessingQueueMappingPromise.set(loaderTaskIntoOfSet(task), () => { + resolve(task.loaderDoneReturn!) + }) + }) + } + + public async create(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise { + logger.info(`Creating knowledge base: ${JSON.stringify(base)}`) + const framework = knowledgeFrameworkFactory.getFramework(base) + await framework.initialize(base) + } + public async reset(_: Electron.IpcMainInvokeEvent, { base }: { base: KnowledgeBaseParams }): Promise { + const framework = knowledgeFrameworkFactory.getFramework(base) + await framework.reset(base) + } + + public async delete(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams, id: string): Promise { + logger.info(`Deleting knowledge base: ${JSON.stringify(base)}`) + const framework = knowledgeFrameworkFactory.getFramework(base) + await framework.delete(id) + } + + public add = async (_: Electron.IpcMainInvokeEvent, options: KnowledgeBaseAddItemOptions): Promise => { + logger.info(`Adding item to knowledge base: ${JSON.stringify(options)}`) + return new Promise((resolve) => { + const { base, item, forceReload = false, userId = '' } = options + const framework = knowledgeFrameworkFactory.getFramework(base) + + const task = framework.getLoaderTask({ base, item, forceReload, userId }) + + if (task) { + this.appendProcessingQueue(task).then(() => { + resolve(task.loaderDoneReturn!) 
+ }) + this.processingQueueHandle() + } else { + resolve({ + ...KnowledgeService.ERROR_LOADER_RETURN, + message: 'Unsupported item type', + messageSource: 'embedding' + }) + } + }) + } + + public async remove( + _: Electron.IpcMainInvokeEvent, + { uniqueIds, base }: { uniqueIds: string[]; base: KnowledgeBaseParams } + ): Promise { + logger.info(`Removing items from knowledge base: ${JSON.stringify({ uniqueIds, base })}`) + const framework = knowledgeFrameworkFactory.getFramework(base) + await framework.remove({ uniqueIds, base }) + } + public async search( + _: Electron.IpcMainInvokeEvent, + { search, base }: { search: string; base: KnowledgeBaseParams } + ): Promise { + logger.info(`Searching knowledge base: ${JSON.stringify({ search, base })}`) + const framework = knowledgeFrameworkFactory.getFramework(base) + return framework.search({ search, base }) + } + + @TraceMethod({ spanName: 'rerank', tag: 'Knowledge' }) + public async rerank( + _: Electron.IpcMainInvokeEvent, + { search, base, results }: { search: string; base: KnowledgeBaseParams; results: KnowledgeSearchResult[] } + ): Promise { + logger.info(`Reranking knowledge base: ${JSON.stringify({ search, base, results })}`) + if (results.length === 0) { + return results + } + return await new Reranker(base).rerank(search, results) + } + + public getStorageDir = (): string => { + return this.storageDir + } + + public async checkQuota(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams, userId: string): Promise { + return preprocessingService.checkQuota(base, userId) + } +} + +export default new KnowledgeService() diff --git a/src/main/services/knowledge/LangChainFramework.ts b/src/main/services/knowledge/LangChainFramework.ts new file mode 100644 index 0000000000..b82242e102 --- /dev/null +++ b/src/main/services/knowledge/LangChainFramework.ts @@ -0,0 +1,555 @@ +import * as fs from 'node:fs' +import path from 'node:path' + +import { FaissStore } from '@langchain/community/vectorstores/faiss' +import type 
{ Document } from '@langchain/core/documents' +import { loggerService } from '@logger' +import TextEmbeddings from '@main/knowledge/langchain/embeddings/TextEmbeddings' +import { + addFileLoader, + addNoteLoader, + addSitemapLoader, + addVideoLoader, + addWebLoader +} from '@main/knowledge/langchain/loader' +import { RetrieverFactory } from '@main/knowledge/langchain/retriever' +import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService' +import { getAllFiles } from '@main/utils/file' +import { getUrlSource } from '@main/utils/knowledge' +import { MB } from '@shared/config/constant' +import { LoaderReturn } from '@shared/config/types' +import { IpcChannel } from '@shared/IpcChannel' +import { + FileMetadata, + isKnowledgeDirectoryItem, + isKnowledgeFileItem, + isKnowledgeNoteItem, + isKnowledgeSitemapItem, + isKnowledgeUrlItem, + isKnowledgeVideoItem, + KnowledgeBaseParams, + KnowledgeSearchResult +} from '@types' +import { uuidv4 } from 'zod/v4' + +import { windowService } from '../WindowService' +import { + IKnowledgeFramework, + KnowledgeBaseAddItemOptionsNonNullableAttribute, + LoaderDoneReturn, + LoaderTask, + LoaderTaskItem, + LoaderTaskItemState +} from './IKnowledgeFramework' + +const logger = loggerService.withContext('LangChainFramework') + +export class LangChainFramework implements IKnowledgeFramework { + private storageDir: string + + private static ERROR_LOADER_RETURN: LoaderReturn = { + entriesAdded: 0, + uniqueId: '', + uniqueIds: [''], + loaderType: '', + status: 'failed' + } + + constructor(storageDir: string) { + this.storageDir = storageDir + this.initStorageDir() + } + private initStorageDir = (): void => { + if (!fs.existsSync(this.storageDir)) { + fs.mkdirSync(this.storageDir, { recursive: true }) + } + } + + private async createDatabase(base: KnowledgeBaseParams): Promise { + const dbPath = path.join(this.storageDir, base.id) + const embeddings = this.getEmbeddings(base) + const vectorStore = new 
FaissStore(embeddings, {}) + + const mockDocument: Document = { + pageContent: 'Create Database Document', + metadata: {} + } + + await vectorStore.addDocuments([mockDocument], { ids: ['1'] }) + await vectorStore.save(dbPath) + await vectorStore.delete({ ids: ['1'] }) + await vectorStore.save(dbPath) + } + + private getEmbeddings(base: KnowledgeBaseParams): TextEmbeddings { + return new TextEmbeddings({ + embedApiClient: base.embedApiClient, + dimensions: base.dimensions + }) + } + + private async getVectorStore(base: KnowledgeBaseParams): Promise { + const embeddings = this.getEmbeddings(base) + const vectorStore = await FaissStore.load(path.join(this.storageDir, base.id), embeddings) + + return vectorStore + } + + async initialize(base: KnowledgeBaseParams): Promise { + await this.createDatabase(base) + } + async reset(base: KnowledgeBaseParams): Promise { + const dbPath = path.join(this.storageDir, base.id) + if (fs.existsSync(dbPath)) { + fs.rmSync(dbPath, { recursive: true }) + } + } + + async delete(id: string): Promise { + const dbPath = path.join(this.storageDir, id) + if (fs.existsSync(dbPath)) { + fs.rmSync(dbPath, { recursive: true }) + } + } + getLoaderTask(options: KnowledgeBaseAddItemOptionsNonNullableAttribute): LoaderTask { + const { item } = options + const getStore = () => this.getVectorStore(options.base) + switch (item.type) { + case 'file': + return this.fileTask(getStore, options) + case 'directory': + return this.directoryTask(getStore, options) + case 'url': + return this.urlTask(getStore, options) + case 'sitemap': + return this.sitemapTask(getStore, options) + case 'note': + return this.noteTask(getStore, options) + case 'video': + return this.videoTask(getStore, options) + default: + return { + loaderTasks: [], + loaderDoneReturn: null + } + } + } + async remove(options: { uniqueIds: string[]; base: KnowledgeBaseParams }): Promise { + const { uniqueIds, base } = options + const vectorStore = await this.getVectorStore(base) + 
logger.info(`[ KnowledgeService Remove Item UniqueIds: ${uniqueIds}]`) + + await vectorStore.delete({ ids: uniqueIds }) + await vectorStore.save(path.join(this.storageDir, base.id)) + } + async search(options: { search: string; base: KnowledgeBaseParams }): Promise { + const { search, base } = options + logger.info(`search base: ${JSON.stringify(base)}`) + + try { + const vectorStore = await this.getVectorStore(base) + + // 如果是 bm25 或 hybrid 模式,则从数据库获取所有文档 + const documents: Document[] = await this.getAllDocuments(base) + if (documents.length === 0) return [] + + const retrieverFactory = new RetrieverFactory() + const retriever = retrieverFactory.createRetriever(base, vectorStore, documents) + + const results = await retriever.invoke(search) + logger.info(`Search Results: ${JSON.stringify(results)}`) + + // VectorStoreRetriever 和 EnsembleRetriever 会将分数附加到 metadata.score + // BM25Retriever 默认不返回分数,所以我们需要处理这种情况 + return results.map((item) => { + return { + pageContent: item.pageContent, + metadata: item.metadata, + // 如果 metadata 中没有 score,提供一个默认值 + score: typeof item.metadata.score === 'number' ? 
item.metadata.score : 0 + } + }) + } catch (error: any) { + logger.error(`Error during search in knowledge base ${base.id}: ${error.message}`) + return [] + } + } + + private fileTask( + getVectorStore: () => Promise, + options: KnowledgeBaseAddItemOptionsNonNullableAttribute + ): LoaderTask { + const { base, item, userId } = options + + if (!isKnowledgeFileItem(item)) { + logger.error(`Invalid item type for fileTask: expected 'file', got '${item.type}'`) + return { + loaderTasks: [], + loaderDoneReturn: { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: `Invalid item type: expected 'file', got '${item.type}'`, + messageSource: 'validation' + } + } + } + + const file = item.content + + const loaderTask: LoaderTask = { + loaderTasks: [ + { + state: LoaderTaskItemState.PENDING, + task: async () => { + try { + const vectorStore = await getVectorStore() + + // 添加预处理逻辑 + const fileToProcess: FileMetadata = await preprocessingService.preprocessFile(file, base, item, userId) + + // 使用处理后的文件进行加载 + return addFileLoader(base, vectorStore, fileToProcess) + .then((result) => { + loaderTask.loaderDoneReturn = result + return result + }) + .then(async () => { + await vectorStore.save(path.join(this.storageDir, base.id)) + }) + .catch((e) => { + logger.error(`Error in addFileLoader for ${file.name}: ${e}`) + const errorResult: LoaderReturn = { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: e.message, + messageSource: 'embedding' + } + loaderTask.loaderDoneReturn = errorResult + return errorResult + }) + } catch (e: any) { + logger.error(`Preprocessing failed for ${file.name}: ${e}`) + const errorResult: LoaderReturn = { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: e.message, + messageSource: 'preprocess' + } + loaderTask.loaderDoneReturn = errorResult + return errorResult + } + }, + evaluateTaskWorkload: { workload: file.size } + } + ], + loaderDoneReturn: null + } + + return loaderTask + } + private directoryTask( + getVectorStore: () => Promise, + 
options: KnowledgeBaseAddItemOptionsNonNullableAttribute + ): LoaderTask { + const { base, item } = options + + if (!isKnowledgeDirectoryItem(item)) { + logger.error(`Invalid item type for directoryTask: expected 'directory', got '${item.type}'`) + return { + loaderTasks: [], + loaderDoneReturn: { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: `Invalid item type: expected 'directory', got '${item.type}'`, + messageSource: 'validation' + } + } + } + + const directory = item.content + const files = getAllFiles(directory) + const totalFiles = files.length + let processedFiles = 0 + + const sendDirectoryProcessingPercent = (totalFiles: number, processedFiles: number) => { + const mainWindow = windowService.getMainWindow() + mainWindow?.webContents.send(IpcChannel.DirectoryProcessingPercent, { + itemId: item.id, + percent: (processedFiles / totalFiles) * 100 + }) + } + + const loaderDoneReturn: LoaderDoneReturn = { + entriesAdded: 0, + uniqueId: `DirectoryLoader_${uuidv4()}`, + uniqueIds: [], + loaderType: 'DirectoryLoader' + } + const loaderTasks: LoaderTaskItem[] = [] + for (const file of files) { + loaderTasks.push({ + state: LoaderTaskItemState.PENDING, + task: async () => { + const vectorStore = await getVectorStore() + return addFileLoader(base, vectorStore, file) + .then((result) => { + loaderDoneReturn.entriesAdded += 1 + processedFiles += 1 + sendDirectoryProcessingPercent(totalFiles, processedFiles) + loaderDoneReturn.uniqueIds.push(result.uniqueId) + return result + }) + .then(async () => { + await vectorStore.save(path.join(this.storageDir, base.id)) + }) + .catch((err) => { + logger.error(err) + return { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: `Failed to add dir loader: ${err.message}`, + messageSource: 'embedding' + } + }) + }, + evaluateTaskWorkload: { workload: file.size } + }) + } + + return { + loaderTasks, + loaderDoneReturn + } + } + + private urlTask( + getVectorStore: () => Promise, + options: 
KnowledgeBaseAddItemOptionsNonNullableAttribute + ): LoaderTask { + const { base, item } = options + + if (!isKnowledgeUrlItem(item)) { + logger.error(`Invalid item type for urlTask: expected 'url', got '${item.type}'`) + return { + loaderTasks: [], + loaderDoneReturn: { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: `Invalid item type: expected 'url', got '${item.type}'`, + messageSource: 'validation' + } + } + } + + const url = item.content + + const loaderTask: LoaderTask = { + loaderTasks: [ + { + state: LoaderTaskItemState.PENDING, + task: async () => { + // 使用处理后的网页进行加载 + const vectorStore = await getVectorStore() + return addWebLoader(base, vectorStore, url, getUrlSource(url)) + .then((result) => { + loaderTask.loaderDoneReturn = result + return result + }) + .then(async () => { + await vectorStore.save(path.join(this.storageDir, base.id)) + }) + .catch((e) => { + logger.error(`Error in addWebLoader for ${url}: ${e}`) + const errorResult: LoaderReturn = { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: e.message, + messageSource: 'embedding' + } + loaderTask.loaderDoneReturn = errorResult + return errorResult + }) + }, + evaluateTaskWorkload: { workload: 2 * MB } + } + ], + loaderDoneReturn: null + } + return loaderTask + } + + private sitemapTask( + getVectorStore: () => Promise, + options: KnowledgeBaseAddItemOptionsNonNullableAttribute + ): LoaderTask { + const { base, item } = options + + if (!isKnowledgeSitemapItem(item)) { + logger.error(`Invalid item type for sitemapTask: expected 'sitemap', got '${item.type}'`) + return { + loaderTasks: [], + loaderDoneReturn: { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: `Invalid item type: expected 'sitemap', got '${item.type}'`, + messageSource: 'validation' + } + } + } + + const url = item.content + + const loaderTask: LoaderTask = { + loaderTasks: [ + { + state: LoaderTaskItemState.PENDING, + task: async () => { + // 使用处理后的网页进行加载 + const vectorStore = await getVectorStore() + return 
addSitemapLoader(base, vectorStore, url) + .then((result) => { + loaderTask.loaderDoneReturn = result + return result + }) + .then(async () => { + await vectorStore.save(path.join(this.storageDir, base.id)) + }) + .catch((e) => { + logger.error(`Error in addWebLoader for ${url}: ${e}`) + const errorResult: LoaderReturn = { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: e.message, + messageSource: 'embedding' + } + loaderTask.loaderDoneReturn = errorResult + return errorResult + }) + }, + evaluateTaskWorkload: { workload: 2 * MB } + } + ], + loaderDoneReturn: null + } + return loaderTask + } + + private noteTask( + getVectorStore: () => Promise, + options: KnowledgeBaseAddItemOptionsNonNullableAttribute + ): LoaderTask { + const { base, item } = options + + if (!isKnowledgeNoteItem(item)) { + logger.error(`Invalid item type for noteTask: expected 'note', got '${item.type}'`) + return { + loaderTasks: [], + loaderDoneReturn: { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: `Invalid item type: expected 'note', got '${item.type}'`, + messageSource: 'validation' + } + } + } + + const content = item.content + const sourceUrl = item.sourceUrl ?? 
'' + + logger.info(`noteTask ${content}, ${sourceUrl}`) + + const encoder = new TextEncoder() + const contentBytes = encoder.encode(content) + const loaderTask: LoaderTask = { + loaderTasks: [ + { + state: LoaderTaskItemState.PENDING, + task: async () => { + // 使用处理后的笔记进行加载 + const vectorStore = await getVectorStore() + return addNoteLoader(base, vectorStore, content, sourceUrl) + .then((result) => { + loaderTask.loaderDoneReturn = result + return result + }) + .then(async () => { + await vectorStore.save(path.join(this.storageDir, base.id)) + }) + .catch((e) => { + logger.error(`Error in addNoteLoader for ${sourceUrl}: ${e}`) + const errorResult: LoaderReturn = { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: e.message, + messageSource: 'embedding' + } + loaderTask.loaderDoneReturn = errorResult + return errorResult + }) + }, + evaluateTaskWorkload: { workload: contentBytes.length } + } + ], + loaderDoneReturn: null + } + return loaderTask + } + + private videoTask( + getVectorStore: () => Promise, + options: KnowledgeBaseAddItemOptionsNonNullableAttribute + ): LoaderTask { + const { base, item } = options + + if (!isKnowledgeVideoItem(item)) { + logger.error(`Invalid item type for videoTask: expected 'video', got '${item.type}'`) + return { + loaderTasks: [], + loaderDoneReturn: { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: `Invalid item type: expected 'video', got '${item.type}'`, + messageSource: 'validation' + } + } + } + + const files = item.content + + const loaderTask: LoaderTask = { + loaderTasks: [ + { + state: LoaderTaskItemState.PENDING, + task: async () => { + const vectorStore = await getVectorStore() + return addVideoLoader(base, vectorStore, files) + .then((result) => { + loaderTask.loaderDoneReturn = result + return result + }) + .then(async () => { + await vectorStore.save(path.join(this.storageDir, base.id)) + }) + .catch((e) => { + logger.error(`Preprocessing failed for ${files[0].name}: ${e}`) + const errorResult: 
LoaderReturn = { + ...LangChainFramework.ERROR_LOADER_RETURN, + message: e.message, + messageSource: 'preprocess' + } + loaderTask.loaderDoneReturn = errorResult + return errorResult + }) + }, + evaluateTaskWorkload: { workload: files[0].size } + } + ], + loaderDoneReturn: null + } + return loaderTask + } + + private async getAllDocuments(base: KnowledgeBaseParams): Promise { + logger.info(`Fetching all documents from database for knowledge base: ${base.id}`) + + try { + const results = (await this.getVectorStore(base)).docstore._docs + + const documents: Document[] = Array.from(results.values()) + logger.info(`Fetched ${documents.length} documents for BM25/Hybrid retriever.`) + return documents + } catch (e) { + logger.error(`Could not fetch documents from database for base ${base.id}: ${e}`) + // 如果表不存在或查询失败,返回空数组 + return [] + } + } +} diff --git a/src/main/services/memory/MemoryService.ts b/src/main/services/memory/MemoryService.ts index aba341391f..85f182b686 100644 --- a/src/main/services/memory/MemoryService.ts +++ b/src/main/services/memory/MemoryService.ts @@ -1,6 +1,6 @@ import { Client, createClient } from '@libsql/client' import { loggerService } from '@logger' -import Embeddings from '@main/knowledge/embeddings/Embeddings' +import Embeddings from '@main/knowledge/embedjs/embeddings/Embeddings' import type { AddMemoryOptions, AssistantMessage, diff --git a/src/main/utils/file.ts b/src/main/utils/file.ts index 2f622d3544..97e87bf9ea 100644 --- a/src/main/utils/file.ts +++ b/src/main/utils/file.ts @@ -205,6 +205,19 @@ export async function readTextFileWithAutoEncoding(filePath: string): Promise { + const filePath = path.join(getFilesDir(), `${file.id}${file.ext}`) + const data = await fs.promises.readFile(filePath) + const base64 = data.toString('base64') + const ext = path.extname(filePath).slice(1) == 'jpg' ? 
'jpeg' : path.extname(filePath).slice(1) + const mime = `image/${ext}` + return { + mime, + base64, + data: `data:${mime};base64,${base64}` + } +} + /** * 递归扫描目录,获取符合条件的文件和目录结构 * @param dirPath 当前要扫描的路径 diff --git a/src/main/utils/knowledge.ts b/src/main/utils/knowledge.ts new file mode 100644 index 0000000000..cce85829d0 --- /dev/null +++ b/src/main/utils/knowledge.ts @@ -0,0 +1,13 @@ +export const DEFAULT_DOCUMENT_COUNT = 6 +export const DEFAULT_RELEVANT_SCORE = 0 +export type UrlSource = 'normal' | 'github' | 'youtube' + +const youtubeRegex = /^(https?:\/\/)?(www\.)?(youtube\.com|youtu\.be|youtube\.be|yt\.be)/i + +export function getUrlSource(url: string): UrlSource { + if (youtubeRegex.test(url)) { + return 'youtube' + } else { + return 'normal' + } +} diff --git a/src/preload/index.ts b/src/preload/index.ts index 2244264753..0600b6b310 100644 --- a/src/preload/index.ts +++ b/src/preload/index.ts @@ -1,4 +1,3 @@ -import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import { electronAPI } from '@electron-toolkit/preload' import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core' import { SpanContext } from '@opentelemetry/api' @@ -14,6 +13,7 @@ import { FileUploadResponse, KnowledgeBaseParams, KnowledgeItem, + KnowledgeSearchResult, MCPServer, MemoryConfig, MemoryListOptions, @@ -166,7 +166,8 @@ const api = { selectFolder: (options?: OpenDialogOptions) => ipcRenderer.invoke(IpcChannel.File_SelectFolder, options), saveImage: (name: string, data: string) => ipcRenderer.invoke(IpcChannel.File_SaveImage, name, data), binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId), - base64Image: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId), + base64Image: (fileId: string): Promise<{ mime: string; base64: string; data: string }> => + ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId), saveBase64Image: (data: string) => ipcRenderer.invoke(IpcChannel.File_SaveBase64Image, data), 
savePastedImage: (imageData: Uint8Array, extension?: string) => ipcRenderer.invoke(IpcChannel.File_SavePastedImage, imageData, extension), @@ -215,7 +216,7 @@ const api = { create: (base: KnowledgeBaseParams, context?: SpanContext) => tracedInvoke(IpcChannel.KnowledgeBase_Create, context, base), reset: (base: KnowledgeBaseParams) => ipcRenderer.invoke(IpcChannel.KnowledgeBase_Reset, base), - delete: (id: string) => ipcRenderer.invoke(IpcChannel.KnowledgeBase_Delete, id), + delete: (base: KnowledgeBaseParams, id: string) => ipcRenderer.invoke(IpcChannel.KnowledgeBase_Delete, base, id), add: ({ base, item, @@ -232,7 +233,7 @@ const api = { search: ({ search, base }: { search: string; base: KnowledgeBaseParams }, context?: SpanContext) => tracedInvoke(IpcChannel.KnowledgeBase_Search, context, { search, base }), rerank: ( - { search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }, + { search, base, results }: { search: string; base: KnowledgeBaseParams; results: KnowledgeSearchResult[] }, context?: SpanContext ) => tracedInvoke(IpcChannel.KnowledgeBase_Rerank, context, { search, base, results }), checkQuota: ({ base, userId }: { base: KnowledgeBaseParams; userId: string }) => diff --git a/src/renderer/index.html b/src/renderer/index.html index 239d9c794c..7b9bfdc329 100644 --- a/src/renderer/index.html +++ b/src/renderer/index.html @@ -5,7 +5,7 @@ + content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' 'unsafe-inline' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; media-src 'self' file:; frame-src * file:" /> Cherry Studio