mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 06:30:10 +08:00
Merge branch 'main' into develop
This commit is contained in:
commit
5cee35b167
@ -92,9 +92,9 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||
releaseInfo:
|
||||
releaseNotes: |
|
||||
⚠️ 注意:升级前请备份数据,否则将无法降级
|
||||
重构消息结构,支持不同类型消息按时间顺序显示
|
||||
智能体支持导入和导出
|
||||
快捷面板增加网络搜索引擎选择
|
||||
显示设置增加缩放控制按钮
|
||||
支持添加自定义小程序
|
||||
性能优化和错误修复
|
||||
优化软件启动速度
|
||||
优化软件进入后台后性能问题
|
||||
修复导出对话时自动重命名失败问题
|
||||
防止输入法切换期间误发消息问题
|
||||
修复群组消息重发功能问题及富文本粘贴兼容性问题
|
||||
改进 MCP 服务处理及 IPC 注册逻辑
|
||||
|
||||
@ -85,12 +85,19 @@ export default defineConfig({
|
||||
miniWindow: resolve(__dirname, 'src/renderer/miniWindow.html')
|
||||
},
|
||||
output: {
|
||||
manualChunks: (id) => {
|
||||
manualChunks: (id: string) => {
|
||||
// 检测所有 worker 文件,提取 worker 名称作为 chunk 名
|
||||
if (id.includes('.worker') && id.endsWith('?worker')) {
|
||||
const workerName = id.split('/').pop()?.split('.')[0] || 'worker'
|
||||
return `workers/${workerName}`
|
||||
}
|
||||
|
||||
// All node_modules are in the vendor chunk
|
||||
if (id.includes('node_modules')) {
|
||||
return 'vendor'
|
||||
}
|
||||
|
||||
// Other modules use default chunk splitting strategy
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "CherryStudio",
|
||||
"version": "1.3.5",
|
||||
"version": "1.3.6",
|
||||
"private": true,
|
||||
"description": "A powerful AI assistant for producer.",
|
||||
"main": "./out/main/index.js",
|
||||
|
||||
@ -2,12 +2,11 @@ import '@main/config'
|
||||
|
||||
import { electronApp, optimizer } from '@electron-toolkit/utils'
|
||||
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { app, BrowserWindow, ipcMain } from 'electron'
|
||||
import { app } from 'electron'
|
||||
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
|
||||
import Logger from 'electron-log'
|
||||
|
||||
import { isDev, isMac, isWin } from './constant'
|
||||
import { isDev } from './constant'
|
||||
import { registerIpc } from './ipc'
|
||||
import { configManager } from './services/ConfigManager'
|
||||
import mcpService from './services/MCPService'
|
||||
@ -85,18 +84,6 @@ if (!app.requestSingleInstanceLock()) {
|
||||
.then((name) => console.log(`Added Extension: ${name}`))
|
||||
.catch((err) => console.log('An error occurred: ', err))
|
||||
}
|
||||
ipcMain.handle(IpcChannel.System_GetDeviceType, () => {
|
||||
return isMac ? 'mac' : isWin ? 'windows' : 'linux'
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.System_GetHostname, () => {
|
||||
return require('os').hostname()
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
|
||||
const win = BrowserWindow.fromWebContents(e.sender)
|
||||
win && win.webContents.toggleDevTools()
|
||||
})
|
||||
})
|
||||
|
||||
registerProtocolClient(app)
|
||||
@ -128,7 +115,7 @@ if (!app.requestSingleInstanceLock()) {
|
||||
app.on('will-quit', async () => {
|
||||
// event.preventDefault()
|
||||
try {
|
||||
await mcpService.cleanup()
|
||||
await mcpService().cleanup()
|
||||
} catch (error) {
|
||||
Logger.error('Error cleaning up MCP service:', error)
|
||||
}
|
||||
|
||||
@ -19,7 +19,7 @@ import FileService from './services/FileService'
|
||||
import FileStorage from './services/FileStorage'
|
||||
import { GeminiService } from './services/GeminiService'
|
||||
import KnowledgeService from './services/KnowledgeService'
|
||||
import mcpService from './services/MCPService'
|
||||
import { getMcpInstance } from './services/MCPService'
|
||||
import * as NutstoreService from './services/NutstoreService'
|
||||
import ObsidianVaultService from './services/ObsidianVaultService'
|
||||
import { ProxyConfig, proxyManager } from './services/ProxyManager'
|
||||
@ -204,6 +204,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text))
|
||||
ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text))
|
||||
|
||||
// system
|
||||
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
|
||||
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
|
||||
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
|
||||
const win = BrowserWindow.fromWebContents(e.sender)
|
||||
win && win.webContents.toggleDevTools()
|
||||
})
|
||||
|
||||
// backup
|
||||
ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup)
|
||||
ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore)
|
||||
@ -301,16 +309,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
)
|
||||
|
||||
// Register MCP handlers
|
||||
ipcMain.handle(IpcChannel.Mcp_RemoveServer, mcpService.removeServer)
|
||||
ipcMain.handle(IpcChannel.Mcp_RestartServer, mcpService.restartServer)
|
||||
ipcMain.handle(IpcChannel.Mcp_StopServer, mcpService.stopServer)
|
||||
ipcMain.handle(IpcChannel.Mcp_ListTools, mcpService.listTools)
|
||||
ipcMain.handle(IpcChannel.Mcp_CallTool, mcpService.callTool)
|
||||
ipcMain.handle(IpcChannel.Mcp_ListPrompts, mcpService.listPrompts)
|
||||
ipcMain.handle(IpcChannel.Mcp_GetPrompt, mcpService.getPrompt)
|
||||
ipcMain.handle(IpcChannel.Mcp_ListResources, mcpService.listResources)
|
||||
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource)
|
||||
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo)
|
||||
ipcMain.handle(IpcChannel.Mcp_RemoveServer, (event, server) => getMcpInstance().removeServer(event, server))
|
||||
ipcMain.handle(IpcChannel.Mcp_RestartServer, (event, server) => getMcpInstance().restartServer(event, server))
|
||||
ipcMain.handle(IpcChannel.Mcp_StopServer, (event, server) => getMcpInstance().stopServer(event, server))
|
||||
ipcMain.handle(IpcChannel.Mcp_ListTools, (event, server) => getMcpInstance().listTools(event, server))
|
||||
ipcMain.handle(IpcChannel.Mcp_CallTool, (event, params) => getMcpInstance().callTool(event, params))
|
||||
ipcMain.handle(IpcChannel.Mcp_ListPrompts, (event, server) => getMcpInstance().listPrompts(event, server))
|
||||
ipcMain.handle(IpcChannel.Mcp_GetPrompt, (event, params) => getMcpInstance().getPrompt(event, params))
|
||||
ipcMain.handle(IpcChannel.Mcp_ListResources, (event, server) => getMcpInstance().listResources(event, server))
|
||||
ipcMain.handle(IpcChannel.Mcp_GetResource, (event, params) => getMcpInstance().getResource(event, params))
|
||||
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, () => getMcpInstance().getInstallInfo())
|
||||
|
||||
ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name))
|
||||
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))
|
||||
|
||||
@ -38,7 +38,7 @@ export default abstract class BaseReranker {
|
||||
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
|
||||
const provider = this.base.rerankModelProvider
|
||||
const documents = searchResults.map((doc) => doc.pageContent)
|
||||
const topN = this.base.topN || 10
|
||||
const topN = this.base.documentCount
|
||||
|
||||
if (provider === 'voyageai') {
|
||||
return {
|
||||
|
||||
@ -69,20 +69,18 @@ function withCache<T extends unknown[], R>(
|
||||
}
|
||||
|
||||
class McpService {
|
||||
private static instance: McpService | null = null
|
||||
private clients: Map<string, Client> = new Map()
|
||||
private pendingClients: Map<string, Promise<Client>> = new Map()
|
||||
|
||||
private getServerKey(server: MCPServer): string {
|
||||
return JSON.stringify({
|
||||
baseUrl: server.baseUrl,
|
||||
command: server.command,
|
||||
args: server.args,
|
||||
registryUrl: server.registryUrl,
|
||||
env: server.env,
|
||||
id: server.id
|
||||
})
|
||||
public static getInstance(): McpService {
|
||||
if (!McpService.instance) {
|
||||
McpService.instance = new McpService()
|
||||
}
|
||||
return McpService.instance
|
||||
}
|
||||
|
||||
constructor() {
|
||||
private constructor() {
|
||||
this.initClient = this.initClient.bind(this)
|
||||
this.listTools = this.listTools.bind(this)
|
||||
this.callTool = this.callTool.bind(this)
|
||||
@ -97,9 +95,26 @@ class McpService {
|
||||
this.cleanup = this.cleanup.bind(this)
|
||||
}
|
||||
|
||||
private getServerKey(server: MCPServer): string {
|
||||
return JSON.stringify({
|
||||
baseUrl: server.baseUrl,
|
||||
command: server.command,
|
||||
args: server.args,
|
||||
registryUrl: server.registryUrl,
|
||||
env: server.env,
|
||||
id: server.id
|
||||
})
|
||||
}
|
||||
|
||||
async initClient(server: MCPServer): Promise<Client> {
|
||||
const serverKey = this.getServerKey(server)
|
||||
|
||||
// If there's a pending initialization, wait for it
|
||||
const pendingClient = this.pendingClients.get(serverKey)
|
||||
if (pendingClient) {
|
||||
return pendingClient
|
||||
}
|
||||
|
||||
// Check if we already have a client for this server configuration
|
||||
const existingClient = this.clients.get(serverKey)
|
||||
if (existingClient) {
|
||||
@ -114,209 +129,226 @@ class McpService {
|
||||
} else {
|
||||
return existingClient
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error(`[MCP] Error pinging server ${server.name}:`, error)
|
||||
} catch (error: any) {
|
||||
Logger.error(`[MCP] Error pinging server ${server.name}:`, error?.message)
|
||||
this.clients.delete(serverKey)
|
||||
}
|
||||
}
|
||||
// Create new client instance for each connection
|
||||
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
|
||||
|
||||
const args = [...(server.args || [])]
|
||||
// Create a promise for the initialization process
|
||||
const initPromise = (async () => {
|
||||
try {
|
||||
// Create new client instance for each connection
|
||||
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
|
||||
|
||||
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||
const authProvider = new McpOAuthClientProvider({
|
||||
serverUrlHash: crypto
|
||||
.createHash('md5')
|
||||
.update(server.baseUrl || '')
|
||||
.digest('hex')
|
||||
})
|
||||
const args = [...(server.args || [])]
|
||||
|
||||
const initTransport = async (): Promise<
|
||||
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||
> => {
|
||||
// Create appropriate transport based on configuration
|
||||
if (server.type === 'inMemory') {
|
||||
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
|
||||
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
|
||||
// start the in-memory server with the given name and environment variables
|
||||
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
|
||||
try {
|
||||
await inMemoryServer.connect(serverTransport)
|
||||
Logger.info(`[MCP] In-memory server started: ${server.name}`)
|
||||
} catch (error: Error | any) {
|
||||
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
|
||||
throw new Error(`Failed to start in-memory server: ${error.message}`)
|
||||
}
|
||||
// set the client transport to the client
|
||||
return clientTransport
|
||||
} else if (server.baseUrl) {
|
||||
if (server.type === 'streamableHttp') {
|
||||
const options: StreamableHTTPClientTransportOptions = {
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
|
||||
} else if (server.type === 'sse') {
|
||||
const options: SSEClientTransportOptions = {
|
||||
eventSourceInit: {
|
||||
fetch: async (url, init) => {
|
||||
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
|
||||
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||
const authProvider = new McpOAuthClientProvider({
|
||||
serverUrlHash: crypto
|
||||
.createHash('md5')
|
||||
.update(server.baseUrl || '')
|
||||
.digest('hex')
|
||||
})
|
||||
|
||||
// Get tokens from authProvider to make sure using the latest tokens
|
||||
if (authProvider && typeof authProvider.tokens === 'function') {
|
||||
try {
|
||||
const tokens = await authProvider.tokens()
|
||||
if (tokens && tokens.access_token) {
|
||||
headers['Authorization'] = `Bearer ${tokens.access_token}`
|
||||
const initTransport = async (): Promise<
|
||||
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||
> => {
|
||||
// Create appropriate transport based on configuration
|
||||
if (server.type === 'inMemory') {
|
||||
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
|
||||
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
|
||||
// start the in-memory server with the given name and environment variables
|
||||
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
|
||||
try {
|
||||
await inMemoryServer.connect(serverTransport)
|
||||
Logger.info(`[MCP] In-memory server started: ${server.name}`)
|
||||
} catch (error: Error | any) {
|
||||
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
|
||||
throw new Error(`Failed to start in-memory server: ${error.message}`)
|
||||
}
|
||||
// set the client transport to the client
|
||||
return clientTransport
|
||||
} else if (server.baseUrl) {
|
||||
if (server.type === 'streamableHttp') {
|
||||
const options: StreamableHTTPClientTransportOptions = {
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
|
||||
} else if (server.type === 'sse') {
|
||||
const options: SSEClientTransportOptions = {
|
||||
eventSourceInit: {
|
||||
fetch: async (url, init) => {
|
||||
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
|
||||
|
||||
// Get tokens from authProvider to make sure using the latest tokens
|
||||
if (authProvider && typeof authProvider.tokens === 'function') {
|
||||
try {
|
||||
const tokens = await authProvider.tokens()
|
||||
if (tokens && tokens.access_token) {
|
||||
headers['Authorization'] = `Bearer ${tokens.access_token}`
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error('Failed to fetch tokens:', error)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error('Failed to fetch tokens:', error)
|
||||
|
||||
return fetch(url, { ...init, headers })
|
||||
}
|
||||
},
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
return new SSEClientTransport(new URL(server.baseUrl!), options)
|
||||
} else {
|
||||
throw new Error('Invalid server type')
|
||||
}
|
||||
} else if (server.command) {
|
||||
let cmd = server.command
|
||||
|
||||
if (server.command === 'npx') {
|
||||
cmd = await getBinaryPath('bun')
|
||||
Logger.info(`[MCP] Using command: ${cmd}`)
|
||||
|
||||
// add -x to args if args exist
|
||||
if (args && args.length > 0) {
|
||||
if (!args.includes('-y')) {
|
||||
args.unshift('-y')
|
||||
}
|
||||
if (!args.includes('x')) {
|
||||
args.unshift('x')
|
||||
}
|
||||
}
|
||||
if (server.registryUrl) {
|
||||
server.env = {
|
||||
...server.env,
|
||||
NPM_CONFIG_REGISTRY: server.registryUrl
|
||||
}
|
||||
|
||||
return fetch(url, { ...init, headers })
|
||||
// if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory
|
||||
if (server.name.includes('mcp-auto-install')) {
|
||||
const binPath = await getBinaryPath()
|
||||
makeSureDirExists(binPath)
|
||||
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json')
|
||||
}
|
||||
}
|
||||
} else if (server.command === 'uvx' || server.command === 'uv') {
|
||||
cmd = await getBinaryPath(server.command)
|
||||
if (server.registryUrl) {
|
||||
server.env = {
|
||||
...server.env,
|
||||
UV_DEFAULT_INDEX: server.registryUrl,
|
||||
PIP_INDEX_URL: server.registryUrl
|
||||
}
|
||||
}
|
||||
},
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
return new SSEClientTransport(new URL(server.baseUrl!), options)
|
||||
} else {
|
||||
throw new Error('Invalid server type')
|
||||
}
|
||||
} else if (server.command) {
|
||||
let cmd = server.command
|
||||
|
||||
if (server.command === 'npx') {
|
||||
cmd = await getBinaryPath('bun')
|
||||
Logger.info(`[MCP] Using command: ${cmd}`)
|
||||
|
||||
// add -x to args if args exist
|
||||
if (args && args.length > 0) {
|
||||
if (!args.includes('-y')) {
|
||||
args.unshift('-y')
|
||||
}
|
||||
if (!args.includes('x')) {
|
||||
args.unshift('x')
|
||||
}
|
||||
}
|
||||
if (server.registryUrl) {
|
||||
server.env = {
|
||||
...server.env,
|
||||
NPM_CONFIG_REGISTRY: server.registryUrl
|
||||
}
|
||||
|
||||
// if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory
|
||||
if (server.name.includes('mcp-auto-install')) {
|
||||
const binPath = await getBinaryPath()
|
||||
makeSureDirExists(binPath)
|
||||
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json')
|
||||
}
|
||||
}
|
||||
} else if (server.command === 'uvx' || server.command === 'uv') {
|
||||
cmd = await getBinaryPath(server.command)
|
||||
if (server.registryUrl) {
|
||||
server.env = {
|
||||
...server.env,
|
||||
UV_DEFAULT_INDEX: server.registryUrl,
|
||||
PIP_INDEX_URL: server.registryUrl
|
||||
}
|
||||
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
|
||||
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
||||
const loginShellEnv = await this.getLoginShellEnv()
|
||||
const stdioTransport = new StdioClientTransport({
|
||||
command: cmd,
|
||||
args,
|
||||
env: {
|
||||
...loginShellEnv,
|
||||
...server.env
|
||||
},
|
||||
stderr: 'pipe'
|
||||
})
|
||||
stdioTransport.stderr?.on('data', (data) =>
|
||||
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
|
||||
)
|
||||
return stdioTransport
|
||||
} else {
|
||||
throw new Error('Either baseUrl or command must be provided')
|
||||
}
|
||||
}
|
||||
|
||||
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
|
||||
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
||||
const loginShellEnv = await this.getLoginShellEnv()
|
||||
const stdioTransport = new StdioClientTransport({
|
||||
command: cmd,
|
||||
args,
|
||||
env: {
|
||||
...loginShellEnv,
|
||||
...server.env
|
||||
},
|
||||
stderr: 'pipe'
|
||||
})
|
||||
stdioTransport.stderr?.on('data', (data) =>
|
||||
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
|
||||
)
|
||||
return stdioTransport
|
||||
} else {
|
||||
throw new Error('Either baseUrl or command must be provided')
|
||||
}
|
||||
}
|
||||
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
|
||||
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
|
||||
// Create an event emitter for the OAuth callback
|
||||
const events = new EventEmitter()
|
||||
|
||||
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
|
||||
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
|
||||
// Create an event emitter for the OAuth callback
|
||||
const events = new EventEmitter()
|
||||
// Create a callback server
|
||||
const callbackServer = new CallBackServer({
|
||||
port: authProvider.config.callbackPort,
|
||||
path: authProvider.config.callbackPath || '/oauth/callback',
|
||||
events
|
||||
})
|
||||
|
||||
// Create a callback server
|
||||
const callbackServer = new CallBackServer({
|
||||
port: authProvider.config.callbackPort,
|
||||
path: authProvider.config.callbackPath || '/oauth/callback',
|
||||
events
|
||||
})
|
||||
// Set a timeout to close the callback server
|
||||
const timeoutId = setTimeout(() => {
|
||||
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
|
||||
callbackServer.close()
|
||||
}, 300000) // 5 minutes timeout
|
||||
|
||||
// Set a timeout to close the callback server
|
||||
const timeoutId = setTimeout(() => {
|
||||
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
|
||||
callbackServer.close()
|
||||
}, 300000) // 5 minutes timeout
|
||||
try {
|
||||
// Wait for the authorization code
|
||||
const authCode = await callbackServer.waitForAuthCode()
|
||||
Logger.info(`[MCP] Received auth code: ${authCode}`)
|
||||
|
||||
try {
|
||||
// Wait for the authorization code
|
||||
const authCode = await callbackServer.waitForAuthCode()
|
||||
Logger.info(`[MCP] Received auth code: ${authCode}`)
|
||||
// Complete the OAuth flow
|
||||
await transport.finishAuth(authCode)
|
||||
|
||||
// Complete the OAuth flow
|
||||
await transport.finishAuth(authCode)
|
||||
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
|
||||
|
||||
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
|
||||
const newTransport = await initTransport()
|
||||
// Try to connect again
|
||||
await client.connect(newTransport)
|
||||
|
||||
const newTransport = await initTransport()
|
||||
// Try to connect again
|
||||
await client.connect(newTransport)
|
||||
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
|
||||
} catch (oauthError) {
|
||||
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
|
||||
throw new Error(
|
||||
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
|
||||
)
|
||||
} finally {
|
||||
// Clear the timeout and close the callback server
|
||||
clearTimeout(timeoutId)
|
||||
callbackServer.close()
|
||||
}
|
||||
}
|
||||
|
||||
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
|
||||
} catch (oauthError) {
|
||||
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
|
||||
throw new Error(
|
||||
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
|
||||
)
|
||||
try {
|
||||
const transport = await initTransport()
|
||||
try {
|
||||
await client.connect(transport)
|
||||
} catch (error: Error | any) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
(error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))
|
||||
) {
|
||||
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
|
||||
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
|
||||
} else {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// Store the new client in the cache
|
||||
this.clients.set(serverKey, client)
|
||||
|
||||
Logger.info(`[MCP] Activated server: ${server.name}`)
|
||||
return client
|
||||
} catch (error: any) {
|
||||
Logger.error(`[MCP] Error activating server ${server.name}:`, error?.message)
|
||||
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
|
||||
}
|
||||
} finally {
|
||||
// Clear the timeout and close the callback server
|
||||
clearTimeout(timeoutId)
|
||||
callbackServer.close()
|
||||
// Clean up the pending promise when done
|
||||
this.pendingClients.delete(serverKey)
|
||||
}
|
||||
}
|
||||
})()
|
||||
|
||||
try {
|
||||
const transport = await initTransport()
|
||||
try {
|
||||
await client.connect(transport)
|
||||
} catch (error: Error | any) {
|
||||
if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) {
|
||||
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
|
||||
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
|
||||
} else {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
// Store the pending promise
|
||||
this.pendingClients.set(serverKey, initPromise)
|
||||
|
||||
// Store the new client in the cache
|
||||
this.clients.set(serverKey, client)
|
||||
|
||||
Logger.info(`[MCP] Activated server: ${server.name}`)
|
||||
return client
|
||||
} catch (error: any) {
|
||||
Logger.error(`[MCP] Error activating server ${server.name}:`, error)
|
||||
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
|
||||
}
|
||||
return initPromise
|
||||
}
|
||||
|
||||
async closeClient(serverKey: string) {
|
||||
@ -358,8 +390,8 @@ class McpService {
|
||||
for (const [key] of this.clients) {
|
||||
try {
|
||||
await this.closeClient(key)
|
||||
} catch (error) {
|
||||
Logger.error(`[MCP] Failed to close client: ${error}`)
|
||||
} catch (error: any) {
|
||||
Logger.error(`[MCP] Failed to close client: ${error?.message}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -380,8 +412,8 @@ class McpService {
|
||||
serverTools.push(serverTool)
|
||||
})
|
||||
return serverTools
|
||||
} catch (error) {
|
||||
Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error)
|
||||
} catch (error: any) {
|
||||
Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error?.message)
|
||||
return []
|
||||
}
|
||||
}
|
||||
@ -440,8 +472,8 @@ class McpService {
|
||||
* List prompts available on an MCP server
|
||||
*/
|
||||
private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> {
|
||||
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
|
||||
const client = await this.initClient(server)
|
||||
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
|
||||
try {
|
||||
const { prompts } = await client.listPrompts()
|
||||
return prompts.map((prompt: any) => ({
|
||||
@ -450,8 +482,11 @@ class McpService {
|
||||
serverId: server.id,
|
||||
serverName: server.name
|
||||
}))
|
||||
} catch (error) {
|
||||
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error)
|
||||
} catch (error: any) {
|
||||
// -32601 is the code for the method not found
|
||||
if (error?.code !== -32601) {
|
||||
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error?.message)
|
||||
}
|
||||
return []
|
||||
}
|
||||
}
|
||||
@ -509,8 +544,8 @@ class McpService {
|
||||
* List resources available on an MCP server (implementation)
|
||||
*/
|
||||
private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> {
|
||||
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
|
||||
const client = await this.initClient(server)
|
||||
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
|
||||
try {
|
||||
const result = await client.listResources()
|
||||
const resources = result.resources || []
|
||||
@ -520,8 +555,11 @@ class McpService {
|
||||
serverName: server.name
|
||||
}))
|
||||
return serverResources
|
||||
} catch (error) {
|
||||
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error)
|
||||
} catch (error: any) {
|
||||
// -32601 is the code for the method not found
|
||||
if (error?.code !== -32601) {
|
||||
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error?.message)
|
||||
}
|
||||
return []
|
||||
}
|
||||
}
|
||||
@ -564,7 +602,7 @@ class McpService {
|
||||
contents: contents
|
||||
}
|
||||
} catch (error: Error | any) {
|
||||
Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error)
|
||||
Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error.message)
|
||||
throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`)
|
||||
}
|
||||
}
|
||||
@ -603,5 +641,13 @@ class McpService {
|
||||
})
|
||||
}
|
||||
|
||||
const mcpService = new McpService()
|
||||
export default mcpService
|
||||
let mcpInstance: ReturnType<typeof McpService.getInstance> | null = null
|
||||
|
||||
export const getMcpInstance = () => {
|
||||
if (!mcpInstance) {
|
||||
mcpInstance = McpService.getInstance()
|
||||
}
|
||||
return mcpInstance
|
||||
}
|
||||
|
||||
export default McpService.getInstance
|
||||
|
||||
@ -75,7 +75,8 @@ export class WindowService {
|
||||
sandbox: false,
|
||||
webSecurity: false,
|
||||
webviewTag: true,
|
||||
allowRunningInsecureContent: true
|
||||
allowRunningInsecureContent: true,
|
||||
backgroundThrottling: false
|
||||
}
|
||||
})
|
||||
|
||||
@ -323,6 +324,11 @@ export class WindowService {
|
||||
|
||||
event.preventDefault()
|
||||
|
||||
if (mainWindow.isFullScreen()) {
|
||||
mainWindow.setFullScreen(false)
|
||||
return
|
||||
}
|
||||
|
||||
mainWindow.hide()
|
||||
|
||||
//for mac users, should hide dock icon if close to tray
|
||||
@ -440,7 +446,8 @@ export class WindowService {
|
||||
preload: join(__dirname, '../preload/index.js'),
|
||||
sandbox: false,
|
||||
webSecurity: false,
|
||||
webviewTag: true
|
||||
webviewTag: true,
|
||||
backgroundThrottling: false
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ function getLoginShellEnvironment(): Promise<Record<string, string>> {
|
||||
commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command
|
||||
}
|
||||
|
||||
Logger.log(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
|
||||
Logger.log(`[ShellEnv] Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
|
||||
|
||||
const child = spawn(shellPath, commandArgs, {
|
||||
cwd: homeDirectory, // Run the command in the user's home directory
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="initial-scale=1, width=device-width" />
|
||||
<meta http-equiv="Content-Security-Policy"
|
||||
content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" />
|
||||
content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' 'unsafe-inline' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" />
|
||||
<title>Cherry Studio</title>
|
||||
|
||||
<style>
|
||||
@ -21,7 +21,7 @@
|
||||
flex-direction: row;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
display: none;
|
||||
display: flex;
|
||||
}
|
||||
|
||||
#spinner img {
|
||||
@ -36,6 +36,9 @@
|
||||
<div id="spinner">
|
||||
<img src="/src/assets/images/logo.png" />
|
||||
</div>
|
||||
<script>
|
||||
console.time('init')
|
||||
</script>
|
||||
<script type="module" src="/src/init.ts"></script>
|
||||
<script type="module" src="/src/entryPoint.tsx"></script>
|
||||
</body>
|
||||
|
||||
@ -24,6 +24,11 @@ export function useAppInit() {
|
||||
const avatar = useLiveQuery(() => db.settings.get('image://avatar'))
|
||||
const { theme } = useTheme()
|
||||
|
||||
useEffect(() => {
|
||||
document.getElementById('spinner')?.remove()
|
||||
console.timeEnd('init')
|
||||
}, [])
|
||||
|
||||
useUpdateHandler()
|
||||
useFullScreenNotice()
|
||||
|
||||
@ -32,7 +37,6 @@ export function useAppInit() {
|
||||
}, [avatar, dispatch])
|
||||
|
||||
useEffect(() => {
|
||||
document.getElementById('spinner')?.remove()
|
||||
runAsyncFunction(async () => {
|
||||
const { isPackaged } = await window.api.getAppInfo()
|
||||
if (isPackaged && autoCheckUpdate) {
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit'
|
||||
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
|
||||
import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp'
|
||||
import { MCPServer } from '@renderer/types'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { useMemo } from 'react'
|
||||
|
||||
const ipcRenderer = window.electron.ipcRenderer
|
||||
|
||||
@ -14,9 +14,14 @@ ipcRenderer.on(IpcChannel.Mcp_AddServer, (_event, server: MCPServer) => {
|
||||
store.dispatch(addMCPServer(server))
|
||||
})
|
||||
|
||||
const selectMcpServers = (state) => state.mcp.servers
|
||||
const selectActiveMcpServers = createSelector([selectMcpServers], (servers) =>
|
||||
servers.filter((server) => server.isActive)
|
||||
)
|
||||
|
||||
export const useMCPServers = () => {
|
||||
const mcpServers = useAppSelector((state) => state.mcp.servers)
|
||||
const activedMcpServers = useMemo(() => mcpServers.filter((server) => server.isActive), [mcpServers])
|
||||
const mcpServers = useAppSelector(selectMcpServers)
|
||||
const activedMcpServers = useAppSelector(selectActiveMcpServers)
|
||||
const dispatch = useAppDispatch()
|
||||
|
||||
return {
|
||||
|
||||
@ -181,6 +181,7 @@
|
||||
"input.upload": "Upload image or document file",
|
||||
"input.upload.document": "Upload document file (model does not support images)",
|
||||
"input.web_search": "Web search",
|
||||
"input.web_search.settings": "Web Search Settings",
|
||||
"input.web_search.button.ok": "Go to Settings",
|
||||
"input.web_search.enable": "Enable web search",
|
||||
"input.web_search.enable_content": "Need to check web search connectivity in settings first",
|
||||
|
||||
@ -181,6 +181,7 @@
|
||||
"input.upload": "画像またはドキュメントをアップロード",
|
||||
"input.upload.document": "ドキュメントをアップロード(モデルは画像をサポートしません)",
|
||||
"input.web_search": "ウェブ検索",
|
||||
"input.web_search.settings": "ウェブ検索設定",
|
||||
"input.web_search.button.ok": "設定に移動",
|
||||
"input.web_search.enable": "ウェブ検索を有効にする",
|
||||
"input.web_search.enable_content": "ウェブ検索の接続性を先に設定で確認する必要があります",
|
||||
|
||||
@ -181,6 +181,7 @@
|
||||
"input.upload": "Загрузить изображение или документ",
|
||||
"input.upload.document": "Загрузить документ (модель не поддерживает изображения)",
|
||||
"input.web_search": "Веб-поиск",
|
||||
"input.web_search.settings": "Настройки веб-поиска",
|
||||
"input.web_search.button.ok": "Перейти в Настройки",
|
||||
"input.web_search.enable": "Включить веб-поиск",
|
||||
"input.web_search.enable_content": "Необходимо предварительно проверить подключение к веб-поиску в настройках",
|
||||
|
||||
@ -190,6 +190,7 @@
|
||||
"input.upload.upload_from_local": "上传本地文件...",
|
||||
"input.upload.document": "上传文档(模型不支持图片)",
|
||||
"input.web_search": "网络搜索",
|
||||
"input.web_search.settings": "网络搜索设置",
|
||||
"input.web_search.button.ok": "去设置",
|
||||
"input.web_search.enable": "开启网络搜索",
|
||||
"input.web_search.enable_content": "需要先在设置中检查网络搜索连通性",
|
||||
|
||||
@ -181,6 +181,7 @@
|
||||
"input.upload": "上傳圖片或文件",
|
||||
"input.upload.document": "上傳文件(模型不支援圖片)",
|
||||
"input.web_search": "網路搜尋",
|
||||
"input.web_search.settings": "網路搜尋設定",
|
||||
"input.web_search.button.ok": "去設定",
|
||||
"input.web_search.enable": "開啟網路搜尋",
|
||||
"input.web_search.enable_content": "需要先在設定中開啟網路搜尋",
|
||||
|
||||
@ -5,13 +5,6 @@ import { startNutstoreAutoSync } from './services/NutstoreService'
|
||||
import storeSyncService from './services/StoreSyncService'
|
||||
import store from './store'
|
||||
|
||||
function initSpinner() {
|
||||
const spinner = document.getElementById('spinner')
|
||||
if (spinner) {
|
||||
spinner.style.display = 'flex'
|
||||
}
|
||||
}
|
||||
|
||||
function initKeyv() {
|
||||
window.keyv = new KeyvStorage()
|
||||
window.keyv.init()
|
||||
@ -34,7 +27,6 @@ function initStoreSync() {
|
||||
storeSyncService.subscribe()
|
||||
}
|
||||
|
||||
initSpinner()
|
||||
initKeyv()
|
||||
initAutoSync()
|
||||
initStoreSync()
|
||||
|
||||
@ -581,7 +581,27 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
||||
|
||||
const onPaste = useCallback(
|
||||
async (event: ClipboardEvent) => {
|
||||
// 1. 文件/图片粘贴
|
||||
// 优先处理文本粘贴
|
||||
const clipboardText = event.clipboardData?.getData('text')
|
||||
if (clipboardText) {
|
||||
// 1. 文本粘贴
|
||||
if (pasteLongTextAsFile && clipboardText.length > pasteLongTextThreshold) {
|
||||
// 长文本直接转文件,阻止默认粘贴
|
||||
event.preventDefault()
|
||||
|
||||
const tempFilePath = await window.api.file.create('pasted_text.txt')
|
||||
await window.api.file.write(tempFilePath, clipboardText)
|
||||
const selectedFile = await window.api.file.get(tempFilePath)
|
||||
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
|
||||
setText(text) // 保持输入框内容不变
|
||||
setTimeout(() => resizeTextArea(), 50)
|
||||
return
|
||||
}
|
||||
// 短文本走默认粘贴行为,直接返回
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 文件/图片粘贴(仅在无文本时处理)
|
||||
if (event.clipboardData?.files && event.clipboardData.files.length > 0) {
|
||||
event.preventDefault()
|
||||
for (const file of event.clipboardData.files) {
|
||||
@ -626,23 +646,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 文本粘贴
|
||||
const clipboardText = event.clipboardData?.getData('text')
|
||||
if (pasteLongTextAsFile && clipboardText && clipboardText.length > pasteLongTextThreshold) {
|
||||
// 长文本直接转文件,阻止默认粘贴
|
||||
event.preventDefault()
|
||||
|
||||
const tempFilePath = await window.api.file.create('pasted_text.txt')
|
||||
await window.api.file.write(tempFilePath, clipboardText)
|
||||
const selectedFile = await window.api.file.get(tempFilePath)
|
||||
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
|
||||
setText(text) // 保持输入框内容不变
|
||||
setTimeout(() => resizeTextArea(), 50)
|
||||
return
|
||||
}
|
||||
|
||||
// 短文本走默认粘贴行为
|
||||
// 其他情况默认粘贴
|
||||
},
|
||||
[model, pasteLongTextAsFile, pasteLongTextThreshold, resizeTextArea, supportExts, t, text]
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant'
|
||||
import { useMCPServers } from '@renderer/hooks/useMCPServers'
|
||||
import { EventEmitter } from '@renderer/services/EventService'
|
||||
import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
|
||||
import { delay, runAsyncFunction } from '@renderer/utils'
|
||||
import { Form, Input, Tooltip } from 'antd'
|
||||
import { Plus, SquareTerminal } from 'lucide-react'
|
||||
import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'
|
||||
@ -109,6 +110,11 @@ const extractPromptContent = (response: any): string | null => {
|
||||
return null
|
||||
}
|
||||
|
||||
// Add static variable before component definition
|
||||
let isFirstResourcesListCall = true
|
||||
let isFirstPromptListCall = true
|
||||
const initMcpDelay = 3
|
||||
|
||||
const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => {
|
||||
const { activedMcpServers } = useMCPServers()
|
||||
const { t } = useTranslation()
|
||||
@ -308,6 +314,11 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
|
||||
const promptList = useMemo(async () => {
|
||||
const prompts: MCPPrompt[] = []
|
||||
|
||||
if (isFirstPromptListCall) {
|
||||
await delay(initMcpDelay)
|
||||
isFirstPromptListCall = false
|
||||
}
|
||||
|
||||
for (const server of activedMcpServers) {
|
||||
const serverPrompts = await window.api.mcp.listPrompts(server)
|
||||
prompts.push(...serverPrompts)
|
||||
@ -319,7 +330,8 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
|
||||
icon: <SquareTerminal />,
|
||||
action: () => handlePromptSelect(prompt as MCPPromptWithArgs)
|
||||
}))
|
||||
}, [handlePromptSelect, activedMcpServers])
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [activedMcpServers])
|
||||
|
||||
const openPromptList = useCallback(async () => {
|
||||
const prompts = await promptList
|
||||
@ -380,33 +392,42 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
|
||||
const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
let isMounted = true
|
||||
runAsyncFunction(async () => {
|
||||
let isMounted = true
|
||||
|
||||
const fetchResources = async () => {
|
||||
const resources: MCPResource[] = []
|
||||
for (const server of activedMcpServers) {
|
||||
const serverResources = await window.api.mcp.listResources(server)
|
||||
resources.push(...serverResources)
|
||||
const fetchResources = async () => {
|
||||
const resources: MCPResource[] = []
|
||||
|
||||
for (const server of activedMcpServers) {
|
||||
const serverResources = await window.api.mcp.listResources(server)
|
||||
resources.push(...serverResources)
|
||||
}
|
||||
|
||||
if (isMounted) {
|
||||
setResourcesList(
|
||||
resources.map((resource) => ({
|
||||
label: resource.name,
|
||||
description: resource.description,
|
||||
icon: <SquareTerminal />,
|
||||
action: () => handleResourceSelect(resource)
|
||||
}))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (isMounted) {
|
||||
setResourcesList(
|
||||
resources.map((resource) => ({
|
||||
label: resource.name,
|
||||
description: resource.description,
|
||||
icon: <SquareTerminal />,
|
||||
action: () => handleResourceSelect(resource)
|
||||
}))
|
||||
)
|
||||
// Avoid mcp following the software startup, affecting the startup speed
|
||||
if (isFirstResourcesListCall) {
|
||||
await delay(initMcpDelay)
|
||||
isFirstResourcesListCall = false
|
||||
fetchResources()
|
||||
}
|
||||
}
|
||||
|
||||
fetchResources()
|
||||
|
||||
return () => {
|
||||
isMounted = false
|
||||
}
|
||||
}, [activedMcpServers, handleResourceSelect])
|
||||
return () => {
|
||||
isMounted = false
|
||||
}
|
||||
})
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [activedMcpServers])
|
||||
|
||||
const openResourcesList = useCallback(async () => {
|
||||
const resources = resourcesList
|
||||
|
||||
@ -22,18 +22,6 @@ const TokenCount: FC<Props> = ({ estimateTokenCount, inputTokenCount, contextCou
|
||||
}
|
||||
|
||||
const formatMaxCount = (max: number) => {
|
||||
if (max == 100) {
|
||||
return (
|
||||
<span
|
||||
style={{
|
||||
fontSize: '16px',
|
||||
position: 'relative',
|
||||
top: '1px'
|
||||
}}>
|
||||
∞
|
||||
</span>
|
||||
)
|
||||
}
|
||||
return max.toString()
|
||||
}
|
||||
|
||||
@ -43,7 +31,7 @@ const TokenCount: FC<Props> = ({ estimateTokenCount, inputTokenCount, contextCou
|
||||
<HStack justifyContent="space-between" w="100%">
|
||||
<Text>{t('chat.input.context_count.tip')}</Text>
|
||||
<Text>
|
||||
{contextCount.current} / {contextCount.max == 20 ? '∞' : contextCount.max}
|
||||
{contextCount.current} / {contextCount.max}
|
||||
</Text>
|
||||
</HStack>
|
||||
<Divider style={{ margin: '5px 0' }} />
|
||||
|
||||
@ -79,7 +79,7 @@ const WebSearchButton: FC<Props> = ({ ref, assistant, ToolbarButton }) => {
|
||||
}
|
||||
|
||||
items.push({
|
||||
label: '前往设置' + '...',
|
||||
label: t('chat.input.web_search.settings'),
|
||||
icon: <Settings />,
|
||||
action: () => navigate('/settings/web-search')
|
||||
})
|
||||
|
||||
@ -38,13 +38,14 @@ const MainTextBlock: React.FC<Props> = ({ block, citationBlockId, role, mentions
|
||||
// Use the passed citationBlockId directly in the selector
|
||||
const { renderInputMessageAsMarkdown } = useSettings()
|
||||
|
||||
const formattedCitations = useSelector((state: RootState) => {
|
||||
const citations = selectFormattedCitationsByBlockId(state, citationBlockId)
|
||||
return citations.map((citation) => ({
|
||||
const rawCitations = useSelector((state: RootState) => selectFormattedCitationsByBlockId(state, citationBlockId))
|
||||
|
||||
const formattedCitations = useMemo(() => {
|
||||
return rawCitations.map((citation) => ({
|
||||
...citation,
|
||||
content: citation.content ? cleanMarkdownContent(citation.content) : citation.content
|
||||
}))
|
||||
})
|
||||
}, [rawCitations])
|
||||
|
||||
const processedContent = useMemo(() => {
|
||||
let content = block.content
|
||||
|
||||
@ -7,7 +7,7 @@ import type { Topic } from '@renderer/types'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import { classNames } from '@renderer/utils'
|
||||
import { Popover } from 'antd'
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import styled, { css } from 'styled-components'
|
||||
|
||||
import MessageItem from './Message'
|
||||
@ -31,7 +31,8 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => {
|
||||
const prevMessageLengthRef = useRef(messageLength)
|
||||
const [selectedIndex, setSelectedIndex] = useState(messageLength - 1)
|
||||
|
||||
const getSelectedMessageId = useCallback(() => {
|
||||
const selectedMessageId = useMemo(() => {
|
||||
if (messages.length === 1) return messages[0]?.id
|
||||
const selectedMessage = messages.find((message) => message.foldSelected)
|
||||
if (selectedMessage) {
|
||||
return selectedMessage.id
|
||||
@ -41,9 +42,10 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => {
|
||||
|
||||
const setSelectedMessage = useCallback(
|
||||
(message: Message) => {
|
||||
messages.forEach(async (m) => {
|
||||
await editMessage(m.id, { foldSelected: m.id === message.id })
|
||||
})
|
||||
// 前一个
|
||||
editMessage(selectedMessageId, { foldSelected: false })
|
||||
// 当前选中的消息
|
||||
editMessage(message.id, { foldSelected: true })
|
||||
|
||||
setTimeout(() => {
|
||||
const messageElement = document.getElementById(`message-${message.id}`)
|
||||
@ -52,7 +54,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => {
|
||||
}
|
||||
}, 200)
|
||||
},
|
||||
[editMessage, messages]
|
||||
[editMessage, selectedMessageId]
|
||||
)
|
||||
|
||||
const isGrouped = messageLength > 1 && messages.every((m) => m.role === 'assistant')
|
||||
@ -67,8 +69,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => {
|
||||
setSelectedMessage(lastMessage)
|
||||
}
|
||||
} else {
|
||||
const selectedId = getSelectedMessageId()
|
||||
const newIndex = messages.findIndex((msg) => msg.id === selectedId)
|
||||
const newIndex = messages.findIndex((msg) => msg.id === selectedMessageId)
|
||||
if (newIndex !== -1) {
|
||||
setSelectedIndex(newIndex)
|
||||
}
|
||||
@ -147,7 +148,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => {
|
||||
}, [messages, setSelectedMessage])
|
||||
|
||||
const renderMessage = useCallback(
|
||||
(message: Message & { index: number }, index: number) => {
|
||||
(message: Message & { index: number }) => {
|
||||
const isGridGroupMessage = isGrid && message.role === 'assistant' && isGrouped
|
||||
const messageProps = {
|
||||
isGrouped,
|
||||
@ -164,13 +165,13 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => {
|
||||
<MessageWrapper
|
||||
id={`message-${message.id}`}
|
||||
$layout={multiModelMessageStyle}
|
||||
$selected={index === selectedIndex}
|
||||
// $selected={index === selectedIndex}
|
||||
$isGrouped={isGrouped}
|
||||
key={message.id}
|
||||
className={classNames({
|
||||
'group-message-wrapper': message.role === 'assistant' && isHorizontal && isGrouped,
|
||||
[multiModelMessageStyle]: isGrouped,
|
||||
selected: message.id === getSelectedMessageId()
|
||||
selected: message.id === selectedMessageId
|
||||
})}>
|
||||
<MessageItem {...messageProps} />
|
||||
</MessageWrapper>
|
||||
@ -183,7 +184,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => {
|
||||
content={
|
||||
<MessageWrapper
|
||||
$layout={multiModelMessageStyle}
|
||||
$selected={index === selectedIndex}
|
||||
// $selected={index === selectedIndex}
|
||||
$isGrouped={isGrouped}
|
||||
$isInPopover={true}>
|
||||
<MessageItem {...messageProps} />
|
||||
@ -204,11 +205,10 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => {
|
||||
isGrouped,
|
||||
isHorizontal,
|
||||
multiModelMessageStyle,
|
||||
selectedIndex,
|
||||
topic,
|
||||
hidePresetMessages,
|
||||
gridPopoverTrigger,
|
||||
getSelectedMessageId
|
||||
selectedMessageId
|
||||
]
|
||||
)
|
||||
|
||||
@ -235,7 +235,7 @@ const MessageGroup = ({ messages, topic, hidePresetMessages }: Props) => {
|
||||
})
|
||||
}}
|
||||
messages={messages}
|
||||
selectMessageId={getSelectedMessageId()}
|
||||
selectMessageId={selectedMessageId}
|
||||
setSelectedMessage={setSelectedMessage}
|
||||
topic={topic}
|
||||
/>
|
||||
@ -297,7 +297,7 @@ const GridContainer = styled.div<{ $count: number; $layout: MultiModelMessageSty
|
||||
|
||||
interface MessageWrapperProps {
|
||||
$layout: 'fold' | 'horizontal' | 'vertical' | 'grid'
|
||||
$selected: boolean
|
||||
// $selected: boolean
|
||||
$isGrouped: boolean
|
||||
$isInPopover?: boolean
|
||||
}
|
||||
|
||||
@ -34,7 +34,6 @@ const MessageTools: FC<Props> = ({ blocks }) => {
|
||||
return 'Invalid Result'
|
||||
}
|
||||
}, [toolResponse])
|
||||
const { renderedMarkdown: styledResult } = useShikiWithMarkdownIt(`\`\`\`json\n${resultString}\n\`\`\``)
|
||||
|
||||
if (!toolResponse) {
|
||||
return null
|
||||
@ -54,8 +53,6 @@ const MessageTools: FC<Props> = ({ blocks }) => {
|
||||
// Format tool responses for collapse items
|
||||
const getCollapseItems = () => {
|
||||
const items: { key: string; label: React.ReactNode; children: React.ReactNode }[] = []
|
||||
// Add tool responses
|
||||
// for (const toolResponse of toolResponses) {
|
||||
const { id, tool, status, response } = toolResponse
|
||||
const isInvoking = status === 'invoking'
|
||||
const isDone = status === 'done'
|
||||
@ -122,11 +119,10 @@ const MessageTools: FC<Props> = ({ blocks }) => {
|
||||
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
||||
fontSize: '12px'
|
||||
}}>
|
||||
<div className="markdown" dangerouslySetInnerHTML={{ __html: styledResult }} />
|
||||
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={resultString} />
|
||||
</ToolResponseContainer>
|
||||
)
|
||||
})
|
||||
// }
|
||||
|
||||
return items
|
||||
}
|
||||
@ -139,7 +135,6 @@ const MessageTools: FC<Props> = ({ blocks }) => {
|
||||
switch (parsedResult.content[0]?.type) {
|
||||
case 'text':
|
||||
return <PreviewBlock>{parsedResult.content[0].text}</PreviewBlock>
|
||||
// TODO: support other types
|
||||
default:
|
||||
return <PreviewBlock>{content}</PreviewBlock>
|
||||
}
|
||||
@ -177,7 +172,6 @@ const MessageTools: FC<Props> = ({ blocks }) => {
|
||||
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
|
||||
fontSize
|
||||
}}>
|
||||
{/* mode swtich tabs */}
|
||||
<Tabs
|
||||
tabBarExtraContent={
|
||||
<ActionButton
|
||||
@ -203,7 +197,16 @@ const MessageTools: FC<Props> = ({ blocks }) => {
|
||||
{
|
||||
key: 'raw',
|
||||
label: t('message.tools.raw'),
|
||||
children: <div className="markdown" dangerouslySetInnerHTML={{ __html: styledResult }} />
|
||||
children: (
|
||||
<CollapsedContent
|
||||
isExpanded={true}
|
||||
resultString={
|
||||
typeof expandedResponse.content === 'string'
|
||||
? expandedResponse.content
|
||||
: JSON.stringify(expandedResponse.content, null, 2)
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
]}
|
||||
/>
|
||||
@ -214,6 +217,19 @@ const MessageTools: FC<Props> = ({ blocks }) => {
|
||||
)
|
||||
}
|
||||
|
||||
// New component to handle collapsed content
|
||||
const CollapsedContent: FC<{ isExpanded: boolean; resultString: string }> = ({ isExpanded, resultString }) => {
|
||||
const { renderedMarkdown: styledResult } = useShikiWithMarkdownIt(
|
||||
isExpanded ? `\`\`\`json\n${resultString}\n\`\`\`` : ''
|
||||
)
|
||||
|
||||
if (!isExpanded) {
|
||||
return null
|
||||
}
|
||||
|
||||
return <div className="markdown" dangerouslySetInnerHTML={{ __html: styledResult }} />
|
||||
}
|
||||
|
||||
const CollapseContainer = styled(Collapse)`
|
||||
margin-top: 10px;
|
||||
margin-bottom: 12px;
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import { CopyOutlined } from '@ant-design/icons'
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
|
||||
import { getFileFromUrl, getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
|
||||
import { searchKnowledgeBase } from '@renderer/services/KnowledgeService'
|
||||
import { FileType, KnowledgeBase } from '@renderer/types'
|
||||
import { Input, List, message, Modal, Spin, Tooltip, Typography } from 'antd'
|
||||
import { useRef, useState } from 'react'
|
||||
@ -38,29 +37,8 @@ const PopupContainer: React.FC<Props> = ({ base, resolve }) => {
|
||||
setSearchKeyword(value.trim())
|
||||
setLoading(true)
|
||||
try {
|
||||
const searchResults = await window.api.knowledgeBase.search({
|
||||
search: value,
|
||||
base: getKnowledgeBaseParams(base)
|
||||
})
|
||||
let rerankResult = searchResults
|
||||
if (base.rerankModel) {
|
||||
rerankResult = await window.api.knowledgeBase.rerank({
|
||||
search: value,
|
||||
base: getKnowledgeBaseParams(base),
|
||||
results: searchResults
|
||||
})
|
||||
}
|
||||
const results = await Promise.all(
|
||||
rerankResult.map(async (item) => {
|
||||
const file = await getFileFromUrl(item.metadata.source)
|
||||
return { ...item, file }
|
||||
})
|
||||
)
|
||||
const filteredResults = results.filter((item) => {
|
||||
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
|
||||
return item.score >= threshold
|
||||
})
|
||||
setResults(filteredResults)
|
||||
const searchResults = await searchKnowledgeBase(value, base)
|
||||
setResults(searchResults)
|
||||
} catch (error) {
|
||||
console.error('Search failed:', error)
|
||||
} finally {
|
||||
|
||||
@ -29,7 +29,6 @@ interface FormData {
|
||||
chunkOverlap?: number
|
||||
threshold?: number
|
||||
rerankModel?: string
|
||||
topN?: number
|
||||
}
|
||||
|
||||
interface Props extends ShowParams {
|
||||
@ -95,8 +94,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
threshold: values.threshold ?? undefined,
|
||||
rerankModel: values.rerankModel
|
||||
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
|
||||
: undefined,
|
||||
topN: values.topN
|
||||
: undefined
|
||||
}
|
||||
updateKnowledgeBase(newBase)
|
||||
setOpen(false)
|
||||
@ -283,23 +281,6 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="topN"
|
||||
label={t('knowledge.topN')}
|
||||
layout="horizontal"
|
||||
initialValue={base.topN}
|
||||
rules={[
|
||||
{
|
||||
validator(_, value) {
|
||||
if (value && (value < 0 || value > 30)) {
|
||||
return Promise.reject(new Error(t('knowledge.topN_too_large_or_small')))
|
||||
}
|
||||
return Promise.resolve()
|
||||
}
|
||||
}
|
||||
]}>
|
||||
<InputNumber placeholder={t('knowledge.topN_placeholder')} style={{ width: '100%' }} />
|
||||
</Form.Item>
|
||||
<Alert
|
||||
message={t('knowledge.chunk_size_change_warning')}
|
||||
type="warning"
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import { runAsyncFunction } from '@renderer/utils'
|
||||
import { getShikiInstance } from '@renderer/utils/shiki'
|
||||
import { Card } from 'antd'
|
||||
import MarkdownIt from 'markdown-it'
|
||||
@ -30,9 +31,11 @@ const MCPDescription = ({ searchKey }: McpDescriptionProps) => {
|
||||
}, [md, searchKey])
|
||||
|
||||
useEffect(() => {
|
||||
const sk = getShikiInstance(theme)
|
||||
md.current.use(sk)
|
||||
getMcpInfo()
|
||||
runAsyncFunction(async () => {
|
||||
const sk = await getShikiInstance(theme)
|
||||
md.current.use(sk)
|
||||
getMcpInfo()
|
||||
})
|
||||
}, [getMcpInfo, theme])
|
||||
|
||||
return (
|
||||
|
||||
@ -47,8 +47,8 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
|
||||
rerankBaseURL: rerankHost,
|
||||
rerankApiKey: rerankAiProvider.getApiKey() || 'secret',
|
||||
rerankModel: base.rerankModel?.id,
|
||||
rerankModelProvider: base.rerankModel?.provider,
|
||||
topN: base.topN
|
||||
rerankModelProvider: base.rerankModel?.provider
|
||||
// topN: base.topN
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,6 +88,51 @@ export const getKnowledgeSourceUrl = async (item: ExtractChunkData & { file: Fil
|
||||
return item.metadata.source
|
||||
}
|
||||
|
||||
export const searchKnowledgeBase = async (
|
||||
query: string,
|
||||
base: KnowledgeBase,
|
||||
rewrite?: string
|
||||
): Promise<Array<ExtractChunkData & { file: FileType | null }>> => {
|
||||
try {
|
||||
const baseParams = getKnowledgeBaseParams(base)
|
||||
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
|
||||
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
|
||||
|
||||
// 执行搜索
|
||||
const searchResults = await window.api.knowledgeBase.search({
|
||||
search: query,
|
||||
base: baseParams
|
||||
})
|
||||
|
||||
// 过滤阈值不达标的结果
|
||||
const filteredResults = searchResults.filter((item) => item.score >= threshold)
|
||||
|
||||
// 如果有rerank模型,执行重排
|
||||
let rerankResults = filteredResults
|
||||
if (base.rerankModel && filteredResults.length > 0) {
|
||||
rerankResults = await window.api.knowledgeBase.rerank({
|
||||
search: rewrite || query,
|
||||
base: baseParams,
|
||||
results: filteredResults
|
||||
})
|
||||
}
|
||||
|
||||
// 限制文档数量
|
||||
const limitedResults = rerankResults.slice(0, documentCount)
|
||||
|
||||
// 处理文件信息
|
||||
return await Promise.all(
|
||||
limitedResults.map(async (item) => {
|
||||
const file = await getFileFromUrl(item.metadata.source)
|
||||
return { ...item, file }
|
||||
})
|
||||
)
|
||||
} catch (error) {
|
||||
Logger.error(`Error searching knowledge base ${base.name}:`, error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
export const processKnowledgeSearch = async (
|
||||
extractResults: ExtractResults,
|
||||
knowledgeBaseIds: string[] | undefined
|
||||
@ -100,6 +145,7 @@ export const processKnowledgeSearch = async (
|
||||
Logger.log('No valid question found in extractResults.knowledge')
|
||||
return []
|
||||
}
|
||||
|
||||
const questions = extractResults.knowledge.question
|
||||
const rewrite = extractResults.knowledge.rewrite
|
||||
|
||||
@ -109,73 +155,35 @@ export const processKnowledgeSearch = async (
|
||||
return []
|
||||
}
|
||||
|
||||
const referencesPromises = bases.map(async (base) => {
|
||||
try {
|
||||
const baseParams = getKnowledgeBaseParams(base)
|
||||
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
|
||||
// 为每个知识库执行多问题搜索
|
||||
const baseSearchPromises = bases.map(async (base) => {
|
||||
// 为每个问题搜索并合并结果
|
||||
const allResults = await Promise.all(questions.map((question) => searchKnowledgeBase(question, base, rewrite)))
|
||||
|
||||
const allSearchResultsPromises = questions.map((question) =>
|
||||
window.api.knowledgeBase
|
||||
.search({
|
||||
search: question,
|
||||
base: baseParams
|
||||
})
|
||||
.then((results) =>
|
||||
results.filter((item) => {
|
||||
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
|
||||
return item.score >= threshold
|
||||
})
|
||||
)
|
||||
)
|
||||
// 合并结果并去重
|
||||
const flatResults = allResults.flat()
|
||||
const uniqueResults = Array.from(
|
||||
new Map(flatResults.map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
|
||||
).sort((a, b) => b.score - a.score)
|
||||
|
||||
const allSearchResults = await Promise.all(allSearchResultsPromises)
|
||||
|
||||
const searchResults = Array.from(
|
||||
new Map(allSearchResults.flat().map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
|
||||
).sort((a, b) => b.score - a.score)
|
||||
|
||||
Logger.log(`Knowledge base ${base.name} search results:`, searchResults)
|
||||
|
||||
let rerankResults = searchResults
|
||||
if (base.rerankModel && searchResults.length > 0) {
|
||||
rerankResults = await window.api.knowledgeBase.rerank({
|
||||
search: rewrite,
|
||||
base: baseParams,
|
||||
results: searchResults
|
||||
})
|
||||
}
|
||||
|
||||
if (rerankResults.length > 0) {
|
||||
rerankResults = rerankResults.slice(0, documentCount)
|
||||
}
|
||||
|
||||
const processdResults = await Promise.all(
|
||||
rerankResults.map(async (item) => {
|
||||
const file = await getFileFromUrl(item.metadata.source)
|
||||
return { ...item, file }
|
||||
})
|
||||
)
|
||||
|
||||
return await Promise.all(
|
||||
processdResults.map(async (item, index) => {
|
||||
// const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
|
||||
return {
|
||||
id: index + 1, // 搜索多个库会导致ID重复
|
||||
// 转换为引用格式
|
||||
return await Promise.all(
|
||||
uniqueResults.map(
|
||||
async (item, index) =>
|
||||
({
|
||||
id: index + 1,
|
||||
content: item.pageContent,
|
||||
sourceUrl: await getKnowledgeSourceUrl(item),
|
||||
type: 'file' // 需要映射 baseItem.type是'localPathLoader' -> 'file'
|
||||
} as KnowledgeReference
|
||||
})
|
||||
type: 'file'
|
||||
}) as KnowledgeReference
|
||||
)
|
||||
} catch (error) {
|
||||
Logger.error(`Error searching knowledge base ${base.name}:`, error)
|
||||
return []
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
const resultsPerBase = await Promise.all(referencesPromises)
|
||||
|
||||
// 汇总所有知识库的结果
|
||||
const resultsPerBase = await Promise.all(baseSearchPromises)
|
||||
const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref)
|
||||
|
||||
// 重新为引用分配ID
|
||||
return allReferencesRaw.map((ref, index) => ({
|
||||
...ref,
|
||||
|
||||
@ -372,7 +372,7 @@ export interface KnowledgeBase {
|
||||
chunkOverlap?: number
|
||||
threshold?: number
|
||||
rerankModel?: Model
|
||||
topN?: number
|
||||
// topN?: number
|
||||
}
|
||||
|
||||
export type KnowledgeBaseParams = {
|
||||
@ -388,7 +388,7 @@ export type KnowledgeBaseParams = {
|
||||
rerankBaseURL?: string
|
||||
rerankModel?: string
|
||||
rerankModelProvider?: string
|
||||
topN?: number
|
||||
documentCount?: number
|
||||
}
|
||||
|
||||
export type GenerateImageParams = {
|
||||
|
||||
32
src/renderer/src/utils/highlighter.ts
Normal file
32
src/renderer/src/utils/highlighter.ts
Normal file
@ -0,0 +1,32 @@
|
||||
import { bundledLanguages, bundledThemes, createHighlighter, type Highlighter } from 'shiki'
|
||||
|
||||
let highlighterPromise: Promise<Highlighter> | null = null
|
||||
|
||||
export async function getHighlighter() {
|
||||
if (!highlighterPromise) {
|
||||
highlighterPromise = createHighlighter({
|
||||
langs: ['javascript', 'typescript', 'python', 'java', 'markdown', 'json'],
|
||||
themes: ['one-light', 'material-theme-darker']
|
||||
})
|
||||
}
|
||||
|
||||
return await highlighterPromise
|
||||
}
|
||||
|
||||
export async function loadLanguageIfNeeded(highlighter: Highlighter, language: string) {
|
||||
if (!highlighter.getLoadedLanguages().includes(language)) {
|
||||
const languageImportFn = bundledLanguages[language]
|
||||
if (languageImportFn) {
|
||||
await highlighter.loadLanguage(await languageImportFn())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function loadThemeIfNeeded(highlighter: Highlighter, theme: string) {
|
||||
if (!highlighter.getLoadedThemes().includes(theme)) {
|
||||
const themeImportFn = bundledThemes[theme]
|
||||
if (themeImportFn) {
|
||||
await highlighter.loadTheme(await themeImportFn())
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,11 +1,13 @@
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import { ThemeMode } from '@renderer/types'
|
||||
import { MarkdownItShikiOptions, setupMarkdownIt } from '@shikijs/markdown-it'
|
||||
import { setupMarkdownIt } from '@shikijs/markdown-it'
|
||||
import MarkdownIt from 'markdown-it'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { BuiltinLanguage, BuiltinTheme, bundledLanguages, createHighlighter } from 'shiki'
|
||||
import { getTokenStyleObject, ThemedToken } from 'shiki/core'
|
||||
|
||||
import { runAsyncFunction } from '.'
|
||||
import { getHighlighter } from './highlighter'
|
||||
|
||||
/**
|
||||
* Shiki token 样式转换为 React 样式对象
|
||||
*
|
||||
@ -44,19 +46,9 @@ const defaultOptions = {
|
||||
defaultColor: 'light'
|
||||
}
|
||||
|
||||
const initHighlighter = async (options: MarkdownItShikiOptions) => {
|
||||
const themeNames = ('themes' in options ? Object.values(options.themes) : [options.theme]).filter(
|
||||
Boolean
|
||||
) as BuiltinTheme[]
|
||||
return await createHighlighter({
|
||||
themes: themeNames,
|
||||
langs: options.langs || (Object.keys(bundledLanguages) as BuiltinLanguage[])
|
||||
})
|
||||
}
|
||||
export async function getShikiInstance(theme: ThemeMode) {
|
||||
const highlighter = await getHighlighter()
|
||||
|
||||
const highlighter = await initHighlighter(defaultOptions)
|
||||
|
||||
export function getShikiInstance(theme: ThemeMode) {
|
||||
const options = {
|
||||
...defaultOptions,
|
||||
defaultColor: theme
|
||||
@ -77,9 +69,11 @@ export function useShikiWithMarkdownIt(content: string) {
|
||||
)
|
||||
const { theme } = useTheme()
|
||||
useEffect(() => {
|
||||
const sk = getShikiInstance(theme)
|
||||
md.current.use(sk)
|
||||
setRenderedMarkdown(md.current.render(content))
|
||||
runAsyncFunction(async () => {
|
||||
const sk = await getShikiInstance(theme)
|
||||
md.current.use(sk)
|
||||
setRenderedMarkdown(md.current.render(content))
|
||||
})
|
||||
}, [content, theme])
|
||||
return {
|
||||
renderedMarkdown
|
||||
|
||||
Loading…
Reference in New Issue
Block a user