Merge branch 'develop' into mutiple-select

This commit is contained in:
自由的世界人 2025-05-19 15:48:43 +08:00 committed by GitHub
commit 76d1b6db3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
55 changed files with 1627 additions and 837 deletions

View File

@ -92,9 +92,9 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo: releaseInfo:
releaseNotes: | releaseNotes: |
⚠️ 注意:升级前请备份数据,否则将无法降级 ⚠️ 注意:升级前请备份数据,否则将无法降级
重构消息结构,支持不同类型消息按时间顺序显示 优化软件启动速度
智能体支持导入和导出 优化软件进入后台后性能问题
快捷面板增加网络搜索引擎选择 修复导出对话时自动重命名失败问题
显示设置增加缩放控制按钮 防止输入法切换期间误发消息问题
支持添加自定义小程序 修复群组消息重发功能问题及富文本粘贴兼容性问题
性能优化和错误修复 改进 MCP 服务处理及 IPC 注册逻辑

View File

@ -85,12 +85,19 @@ export default defineConfig({
miniWindow: resolve(__dirname, 'src/renderer/miniWindow.html') miniWindow: resolve(__dirname, 'src/renderer/miniWindow.html')
}, },
output: { output: {
manualChunks: (id) => { manualChunks: (id: string) => {
// 检测所有 worker 文件,提取 worker 名称作为 chunk 名 // 检测所有 worker 文件,提取 worker 名称作为 chunk 名
if (id.includes('.worker') && id.endsWith('?worker')) { if (id.includes('.worker') && id.endsWith('?worker')) {
const workerName = id.split('/').pop()?.split('.')[0] || 'worker' const workerName = id.split('/').pop()?.split('.')[0] || 'worker'
return `workers/${workerName}` return `workers/${workerName}`
} }
// All node_modules are in the vendor chunk
if (id.includes('node_modules')) {
return 'vendor'
}
// Other modules use default chunk splitting strategy
return undefined return undefined
} }
} }

View File

@ -1,6 +1,6 @@
{ {
"name": "CherryStudio", "name": "CherryStudio",
"version": "1.3.5", "version": "1.3.6",
"private": true, "private": true,
"description": "A powerful AI assistant for producer.", "description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js", "main": "./out/main/index.js",

View File

@ -2,12 +2,11 @@ import '@main/config'
import { electronApp, optimizer } from '@electron-toolkit/utils' import { electronApp, optimizer } from '@electron-toolkit/utils'
import { replaceDevtoolsFont } from '@main/utils/windowUtil' import { replaceDevtoolsFont } from '@main/utils/windowUtil'
import { IpcChannel } from '@shared/IpcChannel' import { app } from 'electron'
import { app, BrowserWindow, ipcMain } from 'electron'
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer' import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
import Logger from 'electron-log' import Logger from 'electron-log'
import { isDev, isMac, isWin } from './constant' import { isDev } from './constant'
import { registerIpc } from './ipc' import { registerIpc } from './ipc'
import { configManager } from './services/ConfigManager' import { configManager } from './services/ConfigManager'
import mcpService from './services/MCPService' import mcpService from './services/MCPService'
@ -85,18 +84,6 @@ if (!app.requestSingleInstanceLock()) {
.then((name) => console.log(`Added Extension: ${name}`)) .then((name) => console.log(`Added Extension: ${name}`))
.catch((err) => console.log('An error occurred: ', err)) .catch((err) => console.log('An error occurred: ', err))
} }
ipcMain.handle(IpcChannel.System_GetDeviceType, () => {
return isMac ? 'mac' : isWin ? 'windows' : 'linux'
})
ipcMain.handle(IpcChannel.System_GetHostname, () => {
return require('os').hostname()
})
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
})
}) })
registerProtocolClient(app) registerProtocolClient(app)
@ -128,7 +115,7 @@ if (!app.requestSingleInstanceLock()) {
app.on('will-quit', async () => { app.on('will-quit', async () => {
// event.preventDefault() // event.preventDefault()
try { try {
await mcpService.cleanup() await mcpService().cleanup()
} catch (error) { } catch (error) {
Logger.error('Error cleaning up MCP service:', error) Logger.error('Error cleaning up MCP service:', error)
} }

View File

@ -19,7 +19,7 @@ import FileService from './services/FileService'
import FileStorage from './services/FileStorage' import FileStorage from './services/FileStorage'
import { GeminiService } from './services/GeminiService' import { GeminiService } from './services/GeminiService'
import KnowledgeService from './services/KnowledgeService' import KnowledgeService from './services/KnowledgeService'
import mcpService from './services/MCPService' import { getMcpInstance } from './services/MCPService'
import * as NutstoreService from './services/NutstoreService' import * as NutstoreService from './services/NutstoreService'
import ObsidianVaultService from './services/ObsidianVaultService' import ObsidianVaultService from './services/ObsidianVaultService'
import { ProxyConfig, proxyManager } from './services/ProxyManager' import { ProxyConfig, proxyManager } from './services/ProxyManager'
@ -204,6 +204,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text)) ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text))
ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text)) ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text))
// system
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
})
// backup // backup
ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup) ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup)
ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore) ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore)
@ -301,16 +309,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
) )
// Register MCP handlers // Register MCP handlers
ipcMain.handle(IpcChannel.Mcp_RemoveServer, mcpService.removeServer) ipcMain.handle(IpcChannel.Mcp_RemoveServer, (event, server) => getMcpInstance().removeServer(event, server))
ipcMain.handle(IpcChannel.Mcp_RestartServer, mcpService.restartServer) ipcMain.handle(IpcChannel.Mcp_RestartServer, (event, server) => getMcpInstance().restartServer(event, server))
ipcMain.handle(IpcChannel.Mcp_StopServer, mcpService.stopServer) ipcMain.handle(IpcChannel.Mcp_StopServer, (event, server) => getMcpInstance().stopServer(event, server))
ipcMain.handle(IpcChannel.Mcp_ListTools, mcpService.listTools) ipcMain.handle(IpcChannel.Mcp_ListTools, (event, server) => getMcpInstance().listTools(event, server))
ipcMain.handle(IpcChannel.Mcp_CallTool, mcpService.callTool) ipcMain.handle(IpcChannel.Mcp_CallTool, (event, params) => getMcpInstance().callTool(event, params))
ipcMain.handle(IpcChannel.Mcp_ListPrompts, mcpService.listPrompts) ipcMain.handle(IpcChannel.Mcp_ListPrompts, (event, server) => getMcpInstance().listPrompts(event, server))
ipcMain.handle(IpcChannel.Mcp_GetPrompt, mcpService.getPrompt) ipcMain.handle(IpcChannel.Mcp_GetPrompt, (event, params) => getMcpInstance().getPrompt(event, params))
ipcMain.handle(IpcChannel.Mcp_ListResources, mcpService.listResources) ipcMain.handle(IpcChannel.Mcp_ListResources, (event, server) => getMcpInstance().listResources(event, server))
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource) ipcMain.handle(IpcChannel.Mcp_GetResource, (event, params) => getMcpInstance().getResource(event, params))
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo) ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, () => getMcpInstance().getInstallInfo())
ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name)) ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name))
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name)) ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))

View File

@ -38,7 +38,7 @@ export default abstract class BaseReranker {
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) { protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
const provider = this.base.rerankModelProvider const provider = this.base.rerankModelProvider
const documents = searchResults.map((doc) => doc.pageContent) const documents = searchResults.map((doc) => doc.pageContent)
const topN = this.base.topN || 10 const topN = this.base.documentCount
if (provider === 'voyageai') { if (provider === 'voyageai') {
return { return {

View File

@ -4,6 +4,7 @@ import path from 'node:path'
import { createInMemoryMCPServer } from '@main/mcpServers/factory' import { createInMemoryMCPServer } from '@main/mcpServers/factory'
import { makeSureDirExists } from '@main/utils' import { makeSureDirExists } from '@main/utils'
import { buildFunctionCallToolName } from '@main/utils/mcp'
import { getBinaryName, getBinaryPath } from '@main/utils/process' import { getBinaryName, getBinaryPath } from '@main/utils/process'
import { Client } from '@modelcontextprotocol/sdk/client/index.js' import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { SSEClientTransport, SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js' import { SSEClientTransport, SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js'
@ -68,20 +69,18 @@ function withCache<T extends unknown[], R>(
} }
class McpService { class McpService {
private static instance: McpService | null = null
private clients: Map<string, Client> = new Map() private clients: Map<string, Client> = new Map()
private pendingClients: Map<string, Promise<Client>> = new Map()
private getServerKey(server: MCPServer): string { public static getInstance(): McpService {
return JSON.stringify({ if (!McpService.instance) {
baseUrl: server.baseUrl, McpService.instance = new McpService()
command: server.command, }
args: server.args, return McpService.instance
registryUrl: server.registryUrl,
env: server.env,
id: server.id
})
} }
constructor() { private constructor() {
this.initClient = this.initClient.bind(this) this.initClient = this.initClient.bind(this)
this.listTools = this.listTools.bind(this) this.listTools = this.listTools.bind(this)
this.callTool = this.callTool.bind(this) this.callTool = this.callTool.bind(this)
@ -96,9 +95,26 @@ class McpService {
this.cleanup = this.cleanup.bind(this) this.cleanup = this.cleanup.bind(this)
} }
private getServerKey(server: MCPServer): string {
return JSON.stringify({
baseUrl: server.baseUrl,
command: server.command,
args: server.args,
registryUrl: server.registryUrl,
env: server.env,
id: server.id
})
}
async initClient(server: MCPServer): Promise<Client> { async initClient(server: MCPServer): Promise<Client> {
const serverKey = this.getServerKey(server) const serverKey = this.getServerKey(server)
// If there's a pending initialization, wait for it
const pendingClient = this.pendingClients.get(serverKey)
if (pendingClient) {
return pendingClient
}
// Check if we already have a client for this server configuration // Check if we already have a client for this server configuration
const existingClient = this.clients.get(serverKey) const existingClient = this.clients.get(serverKey)
if (existingClient) { if (existingClient) {
@ -113,209 +129,226 @@ class McpService {
} else { } else {
return existingClient return existingClient
} }
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Error pinging server ${server.name}:`, error) Logger.error(`[MCP] Error pinging server ${server.name}:`, error?.message)
this.clients.delete(serverKey) this.clients.delete(serverKey)
} }
} }
// Create new client instance for each connection
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
const args = [...(server.args || [])] // Create a promise for the initialization process
const initPromise = (async () => {
try {
// Create new client instance for each connection
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport const args = [...(server.args || [])]
const authProvider = new McpOAuthClientProvider({
serverUrlHash: crypto
.createHash('md5')
.update(server.baseUrl || '')
.digest('hex')
})
const initTransport = async (): Promise< // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport const authProvider = new McpOAuthClientProvider({
> => { serverUrlHash: crypto
// Create appropriate transport based on configuration .createHash('md5')
if (server.type === 'inMemory') { .update(server.baseUrl || '')
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) .digest('hex')
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() })
// start the in-memory server with the given name and environment variables
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
try {
await inMemoryServer.connect(serverTransport)
Logger.info(`[MCP] In-memory server started: ${server.name}`)
} catch (error: Error | any) {
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
throw new Error(`Failed to start in-memory server: ${error.message}`)
}
// set the client transport to the client
return clientTransport
} else if (server.baseUrl) {
if (server.type === 'streamableHttp') {
const options: StreamableHTTPClientTransportOptions = {
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
} else if (server.type === 'sse') {
const options: SSEClientTransportOptions = {
eventSourceInit: {
fetch: async (url, init) => {
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
// Get tokens from authProvider to make sure using the latest tokens const initTransport = async (): Promise<
if (authProvider && typeof authProvider.tokens === 'function') { StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
try { > => {
const tokens = await authProvider.tokens() // Create appropriate transport based on configuration
if (tokens && tokens.access_token) { if (server.type === 'inMemory') {
headers['Authorization'] = `Bearer ${tokens.access_token}` Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
// start the in-memory server with the given name and environment variables
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
try {
await inMemoryServer.connect(serverTransport)
Logger.info(`[MCP] In-memory server started: ${server.name}`)
} catch (error: Error | any) {
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
throw new Error(`Failed to start in-memory server: ${error.message}`)
}
// set the client transport to the client
return clientTransport
} else if (server.baseUrl) {
if (server.type === 'streamableHttp') {
const options: StreamableHTTPClientTransportOptions = {
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
} else if (server.type === 'sse') {
const options: SSEClientTransportOptions = {
eventSourceInit: {
fetch: async (url, init) => {
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
// Get tokens from authProvider to make sure using the latest tokens
if (authProvider && typeof authProvider.tokens === 'function') {
try {
const tokens = await authProvider.tokens()
if (tokens && tokens.access_token) {
headers['Authorization'] = `Bearer ${tokens.access_token}`
}
} catch (error) {
Logger.error('Failed to fetch tokens:', error)
}
} }
} catch (error) {
Logger.error('Failed to fetch tokens:', error) return fetch(url, { ...init, headers })
} }
},
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new SSEClientTransport(new URL(server.baseUrl!), options)
} else {
throw new Error('Invalid server type')
}
} else if (server.command) {
let cmd = server.command
if (server.command === 'npx') {
cmd = await getBinaryPath('bun')
Logger.info(`[MCP] Using command: ${cmd}`)
// add -x to args if args exist
if (args && args.length > 0) {
if (!args.includes('-y')) {
args.unshift('-y')
}
if (!args.includes('x')) {
args.unshift('x')
}
}
if (server.registryUrl) {
server.env = {
...server.env,
NPM_CONFIG_REGISTRY: server.registryUrl
} }
return fetch(url, { ...init, headers }) // if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory
if (server.name.includes('mcp-auto-install')) {
const binPath = await getBinaryPath()
makeSureDirExists(binPath)
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json')
}
}
} else if (server.command === 'uvx' || server.command === 'uv') {
cmd = await getBinaryPath(server.command)
if (server.registryUrl) {
server.env = {
...server.env,
UV_DEFAULT_INDEX: server.registryUrl,
PIP_INDEX_URL: server.registryUrl
}
} }
},
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new SSEClientTransport(new URL(server.baseUrl!), options)
} else {
throw new Error('Invalid server type')
}
} else if (server.command) {
let cmd = server.command
if (server.command === 'npx') {
cmd = await getBinaryPath('bun')
Logger.info(`[MCP] Using command: ${cmd}`)
// add -x to args if args exist
if (args && args.length > 0) {
if (!args.includes('-y')) {
args.unshift('-y')
}
if (!args.includes('x')) {
args.unshift('x')
}
}
if (server.registryUrl) {
server.env = {
...server.env,
NPM_CONFIG_REGISTRY: server.registryUrl
} }
// if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
if (server.name.includes('mcp-auto-install')) { // Logger.info(`[MCP] Environment variables for server:`, server.env)
const binPath = await getBinaryPath() const loginShellEnv = await this.getLoginShellEnv()
makeSureDirExists(binPath) const stdioTransport = new StdioClientTransport({
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json') command: cmd,
} args,
} env: {
} else if (server.command === 'uvx' || server.command === 'uv') { ...loginShellEnv,
cmd = await getBinaryPath(server.command) ...server.env
if (server.registryUrl) { },
server.env = { stderr: 'pipe'
...server.env, })
UV_DEFAULT_INDEX: server.registryUrl, stdioTransport.stderr?.on('data', (data) =>
PIP_INDEX_URL: server.registryUrl Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
} )
return stdioTransport
} else {
throw new Error('Either baseUrl or command must be provided')
} }
} }
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
// Logger.info(`[MCP] Environment variables for server:`, server.env) Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
const loginShellEnv = await this.getLoginShellEnv() // Create an event emitter for the OAuth callback
const stdioTransport = new StdioClientTransport({ const events = new EventEmitter()
command: cmd,
args,
env: {
...loginShellEnv,
...server.env
},
stderr: 'pipe'
})
stdioTransport.stderr?.on('data', (data) =>
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
)
return stdioTransport
} else {
throw new Error('Either baseUrl or command must be provided')
}
}
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { // Create a callback server
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`) const callbackServer = new CallBackServer({
// Create an event emitter for the OAuth callback port: authProvider.config.callbackPort,
const events = new EventEmitter() path: authProvider.config.callbackPath || '/oauth/callback',
events
})
// Create a callback server // Set a timeout to close the callback server
const callbackServer = new CallBackServer({ const timeoutId = setTimeout(() => {
port: authProvider.config.callbackPort, Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
path: authProvider.config.callbackPath || '/oauth/callback', callbackServer.close()
events }, 300000) // 5 minutes timeout
})
// Set a timeout to close the callback server try {
const timeoutId = setTimeout(() => { // Wait for the authorization code
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`) const authCode = await callbackServer.waitForAuthCode()
callbackServer.close() Logger.info(`[MCP] Received auth code: ${authCode}`)
}, 300000) // 5 minutes timeout
try { // Complete the OAuth flow
// Wait for the authorization code await transport.finishAuth(authCode)
const authCode = await callbackServer.waitForAuthCode()
Logger.info(`[MCP] Received auth code: ${authCode}`)
// Complete the OAuth flow Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
await transport.finishAuth(authCode)
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`) const newTransport = await initTransport()
// Try to connect again
await client.connect(newTransport)
const newTransport = await initTransport() Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
// Try to connect again } catch (oauthError) {
await client.connect(newTransport) Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
throw new Error(
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
)
} finally {
// Clear the timeout and close the callback server
clearTimeout(timeoutId)
callbackServer.close()
}
}
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`) try {
} catch (oauthError) { const transport = await initTransport()
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError) try {
throw new Error( await client.connect(transport)
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` } catch (error: Error | any) {
) if (
error instanceof Error &&
(error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))
) {
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
} else {
throw error
}
}
// Store the new client in the cache
this.clients.set(serverKey, client)
Logger.info(`[MCP] Activated server: ${server.name}`)
return client
} catch (error: any) {
Logger.error(`[MCP] Error activating server ${server.name}:`, error?.message)
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
}
} finally { } finally {
// Clear the timeout and close the callback server // Clean up the pending promise when done
clearTimeout(timeoutId) this.pendingClients.delete(serverKey)
callbackServer.close()
} }
} })()
try { // Store the pending promise
const transport = await initTransport() this.pendingClients.set(serverKey, initPromise)
try {
await client.connect(transport)
} catch (error: Error | any) {
if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) {
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
} else {
throw error
}
}
// Store the new client in the cache return initPromise
this.clients.set(serverKey, client)
Logger.info(`[MCP] Activated server: ${server.name}`)
return client
} catch (error: any) {
Logger.error(`[MCP] Error activating server ${server.name}:`, error)
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
}
} }
async closeClient(serverKey: string) { async closeClient(serverKey: string) {
@ -357,8 +390,8 @@ class McpService {
for (const [key] of this.clients) { for (const [key] of this.clients) {
try { try {
await this.closeClient(key) await this.closeClient(key)
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to close client: ${error}`) Logger.error(`[MCP] Failed to close client: ${error?.message}`)
} }
} }
} }
@ -372,15 +405,15 @@ class McpService {
tools.map((tool: any) => { tools.map((tool: any) => {
const serverTool: MCPTool = { const serverTool: MCPTool = {
...tool, ...tool,
id: `f${nanoid()}`, id: buildFunctionCallToolName(server.name, tool.name),
serverId: server.id, serverId: server.id,
serverName: server.name serverName: server.name
} }
serverTools.push(serverTool) serverTools.push(serverTool)
}) })
return serverTools return serverTools
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error) Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error?.message)
return [] return []
} }
} }
@ -439,8 +472,8 @@ class McpService {
* List prompts available on an MCP server * List prompts available on an MCP server
*/ */
private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> { private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> {
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
const client = await this.initClient(server) const client = await this.initClient(server)
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
try { try {
const { prompts } = await client.listPrompts() const { prompts } = await client.listPrompts()
return prompts.map((prompt: any) => ({ return prompts.map((prompt: any) => ({
@ -449,8 +482,11 @@ class McpService {
serverId: server.id, serverId: server.id,
serverName: server.name serverName: server.name
})) }))
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error) // -32601 is the code for the method not found
if (error?.code !== -32601) {
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error?.message)
}
return [] return []
} }
} }
@ -508,8 +544,8 @@ class McpService {
* List resources available on an MCP server (implementation) * List resources available on an MCP server (implementation)
*/ */
private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> { private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> {
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
const client = await this.initClient(server) const client = await this.initClient(server)
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
try { try {
const result = await client.listResources() const result = await client.listResources()
const resources = result.resources || [] const resources = result.resources || []
@ -519,8 +555,11 @@ class McpService {
serverName: server.name serverName: server.name
})) }))
return serverResources return serverResources
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error) // -32601 is the code for the method not found
if (error?.code !== -32601) {
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error?.message)
}
return [] return []
} }
} }
@ -563,7 +602,7 @@ class McpService {
contents: contents contents: contents
} }
} catch (error: Error | any) { } catch (error: Error | any) {
Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error) Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error.message)
throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`) throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`)
} }
} }
@ -602,5 +641,13 @@ class McpService {
}) })
} }
const mcpService = new McpService() let mcpInstance: ReturnType<typeof McpService.getInstance> | null = null
export default mcpService
export const getMcpInstance = () => {
if (!mcpInstance) {
mcpInstance = McpService.getInstance()
}
return mcpInstance
}
export default McpService.getInstance

View File

@ -75,7 +75,8 @@ export class WindowService {
sandbox: false, sandbox: false,
webSecurity: false, webSecurity: false,
webviewTag: true, webviewTag: true,
allowRunningInsecureContent: true allowRunningInsecureContent: true,
backgroundThrottling: false
} }
}) })
@ -323,6 +324,11 @@ export class WindowService {
event.preventDefault() event.preventDefault()
if (mainWindow.isFullScreen()) {
mainWindow.setFullScreen(false)
return
}
mainWindow.hide() mainWindow.hide()
//for mac users, should hide dock icon if close to tray //for mac users, should hide dock icon if close to tray
@ -440,7 +446,8 @@ export class WindowService {
preload: join(__dirname, '../preload/index.js'), preload: join(__dirname, '../preload/index.js'),
sandbox: false, sandbox: false,
webSecurity: false, webSecurity: false,
webviewTag: true webviewTag: true,
backgroundThrottling: false
} }
}) })

View File

@ -47,7 +47,7 @@ function getLoginShellEnvironment(): Promise<Record<string, string>> {
commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command
} }
Logger.log(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) Logger.log(`[ShellEnv] Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
const child = spawn(shellPath, commandArgs, { const child = spawn(shellPath, commandArgs, {
cwd: homeDirectory, // Run the command in the user's home directory cwd: homeDirectory, // Run the command in the user's home directory

34
src/main/utils/mcp.ts Normal file
View File

@ -0,0 +1,34 @@
export function buildFunctionCallToolName(serverName: string, toolName: string) {
const sanitizedServer = serverName.trim().replace(/-/g, '_')
const sanitizedTool = toolName.trim().replace(/-/g, '_')
// Combine server name and tool name
let name = sanitizedTool
if (!sanitizedTool.includes(sanitizedServer.slice(0, 7))) {
name = `${sanitizedServer.slice(0, 7) || ''}-${sanitizedTool || ''}`
}
// Replace invalid characters with underscores or dashes
// Keep a-z, A-Z, 0-9, underscores and dashes
name = name.replace(/[^a-zA-Z0-9_-]/g, '_')
// Ensure name starts with a letter or underscore (for valid JavaScript identifier)
if (!/^[a-zA-Z]/.test(name)) {
name = `tool-${name}`
}
// Remove consecutive underscores/dashes (optional improvement)
name = name.replace(/[_-]{2,}/g, '_')
// Truncate to 63 characters maximum
if (name.length > 63) {
name = name.slice(0, 63)
}
// Handle edge case: ensure we still have a valid name if truncation left invalid chars at edges
if (name.endsWith('_') || name.endsWith('-')) {
name = name.slice(0, -1)
}
return name
}

View File

@ -18,3 +18,32 @@ vi.mock('electron-log/renderer', () => {
} }
} }
}) })
vi.stubGlobal('window', {
electron: {
ipcRenderer: {
on: vi.fn(), // Mocking ipcRenderer.on
send: vi.fn() // Mocking ipcRenderer.send
}
},
api: {
file: {
read: vi.fn().mockResolvedValue('[]'), // Mock file.read to return an empty array (you can customize this)
writeWithId: vi.fn().mockResolvedValue(undefined) // Mock file.writeWithId to do nothing
}
}
})
vi.mock('axios', () => ({
default: {
get: vi.fn().mockResolvedValue({ data: {} }), // Mocking axios GET request
post: vi.fn().mockResolvedValue({ data: {} }) // Mocking axios POST request
// You can add other axios methods like put, delete etc. as needed
}
}))
vi.stubGlobal('window', {
...global.window, // Copy other global properties
addEventListener: vi.fn(), // Mock addEventListener
removeEventListener: vi.fn() // You can also mock removeEventListener if needed
})

View File

@ -5,7 +5,7 @@
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="initial-scale=1, width=device-width" /> <meta name="viewport" content="initial-scale=1, width=device-width" />
<meta http-equiv="Content-Security-Policy" <meta http-equiv="Content-Security-Policy"
content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" /> content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' 'unsafe-inline' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" />
<title>Cherry Studio</title> <title>Cherry Studio</title>
<style> <style>
@ -21,7 +21,7 @@
flex-direction: row; flex-direction: row;
justify-content: center; justify-content: center;
align-items: center; align-items: center;
display: none; display: flex;
} }
#spinner img { #spinner img {
@ -36,6 +36,9 @@
<div id="spinner"> <div id="spinner">
<img src="/src/assets/images/logo.png" /> <img src="/src/assets/images/logo.png" />
</div> </div>
<script>
console.time('init')
</script>
<script type="module" src="/src/init.ts"></script> <script type="module" src="/src/init.ts"></script>
<script type="module" src="/src/entryPoint.tsx"></script> <script type="module" src="/src/entryPoint.tsx"></script>
</body> </body>

View File

@ -8,13 +8,6 @@
@import '../fonts/icon-fonts/iconfont.css'; @import '../fonts/icon-fonts/iconfont.css';
@import '../fonts/ubuntu/ubuntu.css'; @import '../fonts/ubuntu/ubuntu.css';
body {
-webkit-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
}
*, *,
*::before, *::before,
*::after { *::after {
@ -31,8 +24,18 @@ body {
-webkit-tap-highlight-color: transparent; -webkit-tap-highlight-color: transparent;
} }
ul { html,
list-style: none; body,
#root {
height: 100%;
width: 100%;
margin: 0;
}
#root {
display: flex;
flex-direction: row;
flex: 1;
} }
body { body {
@ -44,9 +47,15 @@ body {
overflow: hidden; overflow: hidden;
font-family: var(--font-family); font-family: var(--font-family);
text-rendering: optimizeLegibility; text-rendering: optimizeLegibility;
transition: background-color 0.3s linear;
-webkit-font-smoothing: antialiased; -webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale; -moz-osx-font-smoothing: grayscale;
transition: background-color 0.3s linear;
-webkit-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
} }
input, input,
@ -67,20 +76,8 @@ a {
-webkit-user-drag: none; -webkit-user-drag: none;
} }
html, ul {
body, list-style: none;
#root {
height: 100%;
width: 100%;
margin: 0;
}
#root {
width: 100%;
height: 100%;
display: flex;
flex-direction: row;
flex: 1;
} }
.loader { .loader {

View File

@ -228,7 +228,7 @@ const ContentContainer = styled.div<{
$isCodeWrappable: boolean $isCodeWrappable: boolean
}>` }>`
position: relative; position: relative;
border: 0.5px solid var(--color-code-background); border: 0.5px solid transparent;
border-radius: 5px; border-radius: 5px;
margin-top: 0; margin-top: 0;
transition: opacity 0.3s ease; transition: opacity 0.3s ease;

View File

@ -219,7 +219,7 @@ const CodeEditor = ({ children, language, onSave, onChange, maxHeight, options,
fontSize: `${fontSize - 1}px`, fontSize: `${fontSize - 1}px`,
overflow: collapsible && !isExpanded ? 'auto' : 'visible', overflow: collapsible && !isExpanded ? 'auto' : 'visible',
position: 'relative', position: 'relative',
border: '0.5px solid var(--color-code-background)', border: '0.5px solid transparent',
borderRadius: '5px', borderRadius: '5px',
marginTop: 0, marginTop: 0,
...style ...style

View File

@ -236,8 +236,8 @@ export const ContentSearch = React.forwardRef<ContentSearchRef, Props>(
if (shouldScroll) { if (shouldScroll) {
highlightTextNode.scrollIntoView({ highlightTextNode.scrollIntoView({
behavior: 'smooth', behavior: 'smooth',
block: 'center', block: 'center'
inline: 'center' // inline: 'center' 水平方向居中可能会导致 content 页面整体偏右, 使得左半部的内容被遮挡. 因此先注释掉该代码
}) })
} }
} }

View File

@ -0,0 +1,90 @@
import HomeTabs from '@renderer/pages/home/Tabs/index'
import { Assistant, Topic } from '@renderer/types'
import { Popover } from 'antd'
import { FC, useEffect, useState } from 'react'
import { useHotkeys } from 'react-hotkeys-hook'
import styled from 'styled-components'
import Scrollbar from '../Scrollbar'
interface Props {
children: React.ReactNode
activeAssistant: Assistant
setActiveAssistant: (assistant: Assistant) => void
activeTopic: Topic
setActiveTopic: (topic: Topic) => void
position: 'left' | 'right'
}
const FloatingSidebar: FC<Props> = ({
children,
activeAssistant,
setActiveAssistant,
activeTopic,
setActiveTopic,
position = 'left'
}) => {
const [open, setOpen] = useState(false)
useHotkeys('esc', () => {
setOpen(false)
})
const [maxHeight, setMaxHeight] = useState(Math.floor(window.innerHeight * 0.75))
useEffect(() => {
const handleResize = () => {
setMaxHeight(Math.floor(window.innerHeight * 0.75))
}
window.addEventListener('resize', handleResize)
return () => {
window.removeEventListener('resize', handleResize)
}
}, [])
const content = (
<PopoverContent maxHeight={maxHeight}>
<HomeTabs
activeAssistant={activeAssistant}
activeTopic={activeTopic}
setActiveAssistant={setActiveAssistant}
setActiveTopic={setActiveTopic}
position={position}
forceToSeeAllTab={true}></HomeTabs>
</PopoverContent>
)
return (
<Popover
open={open}
onOpenChange={(visible) => {
setOpen(visible)
}}
content={content}
trigger={['hover', 'click']}
placement="bottomRight"
arrow={false}
mouseEnterDelay={0.8} // 800ms delay before showing
mouseLeaveDelay={20}
styles={{
body: {
padding: 0,
background: 'var(--color-background)',
border: '1px solid var(--color-border)',
borderRadius: '8px',
boxShadow: '0 6px 16px 0 rgba(0, 0, 0, 0.08), 0 3px 6px -4px rgba(0, 0, 0, 0.12)'
}
}}>
{children}
</Popover>
)
}
const PopoverContent = styled(Scrollbar)<{ maxHeight: number }>`
max-height: ${(props) => props.maxHeight}px;
overflow-y: auto;
`
export default FloatingSidebar

View File

@ -2661,7 +2661,7 @@ export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> =
'qwen-turbo-.*$': { min: 0, max: 38912 }, 'qwen-turbo-.*$': { min: 0, max: 38912 },
'qwen3-0\\.6b$': { min: 0, max: 30720 }, 'qwen3-0\\.6b$': { min: 0, max: 30720 },
'qwen3-1\\.7b$': { min: 0, max: 30720 }, 'qwen3-1\\.7b$': { min: 0, max: 30720 },
'qwen3-.*$': { min: 0, max: 38912 }, 'qwen3-.*$': { min: 1024, max: 38912 },
// Claude models // Claude models
'claude-3[.-]7.*sonnet.*$': { min: 0, max: 64000 } 'claude-3[.-]7.*sonnet.*$': { min: 0, max: 64000 }

View File

@ -0,0 +1,33 @@
import { createContext, ReactNode, use, useState } from 'react'
interface MessageEditingContextType {
editingMessageId: string | null
startEditing: (messageId: string) => void
stopEditing: () => void
}
const MessageEditingContext = createContext<MessageEditingContextType | null>(null)
export function MessageEditingProvider({ children }: { children: ReactNode }) {
const [editingMessageId, setEditingMessageId] = useState<string | null>(null)
const startEditing = (messageId: string) => {
setEditingMessageId(messageId)
}
const stopEditing = () => {
setEditingMessageId(null)
}
return (
<MessageEditingContext value={{ editingMessageId, startEditing, stopEditing }}>{children}</MessageEditingContext>
)
}
export function useMessageEditing() {
const context = use(MessageEditingContext)
if (!context) {
throw new Error('useMessageEditing must be used within a MessageEditingProvider')
}
return context
}

View File

@ -24,6 +24,11 @@ export function useAppInit() {
const avatar = useLiveQuery(() => db.settings.get('image://avatar')) const avatar = useLiveQuery(() => db.settings.get('image://avatar'))
const { theme } = useTheme() const { theme } = useTheme()
useEffect(() => {
document.getElementById('spinner')?.remove()
console.timeEnd('init')
}, [])
useUpdateHandler() useUpdateHandler()
useFullScreenNotice() useFullScreenNotice()
@ -32,7 +37,6 @@ export function useAppInit() {
}, [avatar, dispatch]) }, [avatar, dispatch])
useEffect(() => { useEffect(() => {
document.getElementById('spinner')?.remove()
runAsyncFunction(async () => { runAsyncFunction(async () => {
const { isPackaged } = await window.api.getAppInfo() const { isPackaged } = await window.api.getAppInfo()
if (isPackaged && autoCheckUpdate) { if (isPackaged && autoCheckUpdate) {

View File

@ -1,22 +1,25 @@
import { createSelector } from '@reduxjs/toolkit'
import store, { useAppDispatch, useAppSelector } from '@renderer/store' import store, { useAppDispatch, useAppSelector } from '@renderer/store'
import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp' import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp'
import { MCPServer } from '@renderer/types' import { MCPServer } from '@renderer/types'
import { IpcChannel } from '@shared/IpcChannel' import { IpcChannel } from '@shared/IpcChannel'
import { useMemo } from 'react'
const ipcRenderer = window.electron.ipcRenderer
// Listen for server changes from main process // Listen for server changes from main process
ipcRenderer.on(IpcChannel.Mcp_ServersChanged, (_event, servers) => { window.electron.ipcRenderer.on(IpcChannel.Mcp_ServersChanged, (_event, servers) => {
store.dispatch(setMCPServers(servers)) store.dispatch(setMCPServers(servers))
}) })
ipcRenderer.on(IpcChannel.Mcp_AddServer, (_event, server: MCPServer) => { window.electron.ipcRenderer.on(IpcChannel.Mcp_AddServer, (_event, server: MCPServer) => {
store.dispatch(addMCPServer(server)) store.dispatch(addMCPServer(server))
}) })
const selectMcpServers = (state) => state.mcp.servers
const selectActiveMcpServers = createSelector([selectMcpServers], (servers) =>
servers.filter((server) => server.isActive)
)
export const useMCPServers = () => { export const useMCPServers = () => {
const mcpServers = useAppSelector((state) => state.mcp.servers) const mcpServers = useAppSelector(selectMcpServers)
const activedMcpServers = useMemo(() => mcpServers.filter((server) => server.isActive), [mcpServers]) const activedMcpServers = useAppSelector(selectActiveMcpServers)
const dispatch = useAppDispatch() const dispatch = useAppDispatch()
return { return {

View File

@ -3,7 +3,7 @@ import Logger from '@renderer/config/logger'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import { estimateUserPromptUsage } from '@renderer/services/TokenService' import { estimateUserPromptUsage } from '@renderer/services/TokenService'
import store, { type RootState, useAppDispatch, useAppSelector } from '@renderer/store' import store, { type RootState, useAppDispatch, useAppSelector } from '@renderer/store'
import { messageBlocksSelectors, updateOneBlock } from '@renderer/store/messageBlock' import { updateOneBlock } from '@renderer/store/messageBlock'
import { newMessagesActions, selectMessagesForTopic } from '@renderer/store/newMessage' import { newMessagesActions, selectMessagesForTopic } from '@renderer/store/newMessage'
import { import {
appendAssistantResponseThunk, appendAssistantResponseThunk,
@ -13,6 +13,7 @@ import {
deleteSingleMessageThunk, deleteSingleMessageThunk,
initiateTranslationThunk, initiateTranslationThunk,
regenerateAssistantResponseThunk, regenerateAssistantResponseThunk,
removeBlocksThunk,
resendMessageThunk, resendMessageThunk,
resendUserMessageWithEditThunk, resendUserMessageWithEditThunk,
updateMessageAndBlocksThunk, updateMessageAndBlocksThunk,
@ -22,21 +23,8 @@ import type { Assistant, Model, Topic } from '@renderer/types'
import type { Message, MessageBlock } from '@renderer/types/newMessage' import type { Message, MessageBlock } from '@renderer/types/newMessage'
import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage' import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
import { abortCompletion } from '@renderer/utils/abortController' import { abortCompletion } from '@renderer/utils/abortController'
import { findFileBlocks } from '@renderer/utils/messageUtils/find'
import { useCallback } from 'react' import { useCallback } from 'react'
const findMainTextBlockId = (message: Message): string | undefined => {
if (!message || !message.blocks) return undefined
const state = store.getState()
for (const blockId of message.blocks) {
const block = messageBlocksSelectors.selectById(state, String(blockId))
if (block && block.type === MessageBlockType.MAIN_TEXT) {
return block.id
}
}
return undefined
}
const selectMessagesState = (state: RootState) => state.messages const selectMessagesState = (state: RootState) => state.messages
export const selectNewTopicLoading = createSelector( export const selectNewTopicLoading = createSelector(
@ -113,36 +101,6 @@ export function useMessageOperations(topic: Topic) {
[dispatch, topic.id] [dispatch, topic.id]
) )
/**
* / Resends a user message after its main text block has been edited.
* Dispatches resendUserMessageWithEditThunk.
*/
const resendUserMessageWithEdit = useCallback(
async (message: Message, editedContent: string, assistant: Assistant) => {
const mainTextBlockId = findMainTextBlockId(message)
if (!mainTextBlockId) {
console.error('Cannot resend edited message: Main text block not found.')
return
}
const files = findFileBlocks(message).map((block) => block.file)
const usage = await estimateUserPromptUsage({ content: editedContent, files })
const messageUpdates: Partial<Message> & Pick<Message, 'id'> = {
id: message.id,
updatedAt: new Date().toISOString(),
usage
}
await dispatch(
newMessagesActions.updateMessage({ topicId: topic.id, messageId: message.id, updates: messageUpdates })
)
// 对于message的修改会在下面的thunk中保存
await dispatch(resendUserMessageWithEditThunk(topic.id, message, mainTextBlockId, editedContent, assistant))
},
[dispatch, topic.id]
)
/** /**
* / Clears all messages for the current or specified topic. * / Clears all messages for the current or specified topic.
* Dispatches clearTopicMessagesThunk. * Dispatches clearTopicMessagesThunk.
@ -309,29 +267,127 @@ export function useMessageOperations(topic: Topic) {
) )
/** /**
* Updates properties of specific message blocks (e.g., content). * Updates message blocks by comparing original and edited blocks.
* Uses the generalized thunk for persistence. * Handles adding, updating, and removing blocks in a single operation.
* @param messageId The ID of the message to update
* @param editedBlocks The complete set of blocks after editing
*/ */
const editMessageBlocks = useCallback( const editMessageBlocks = useCallback(
async (messageId: string, updates: Partial<MessageBlock>) => { async (messageId: string, editedBlocks: MessageBlock[]) => {
if (!topic?.id) { if (!topic?.id) {
console.error('[editMessageBlocks] Topic prop is not valid.') console.error('[editMessageBlocks] Topic prop is not valid.')
return return
} }
const blockUpdatesListProcessed = { try {
updatedAt: new Date().toISOString(), // 1. Get the current state of the message and its blocks
...updates const state = store.getState()
} const message = state.messages.entities[messageId]
if (!message) {
console.error('[editMessageBlocks] Message not found:', messageId)
return
}
const messageUpdates: Partial<Message> & Pick<Message, 'id'> = { // 2. Get all original blocks
id: messageId, const originalBlocks = message.blocks
updatedAt: new Date().toISOString() ? (message.blocks
} .map((blockId) => state.messageBlocks.entities[blockId])
.filter((block) => block !== undefined) as MessageBlock[])
: []
await dispatch(updateMessageAndBlocksThunk(topic.id, messageUpdates, [blockUpdatesListProcessed])) // 3. Create sets for efficient comparison
const originalBlockIds = new Set(originalBlocks.map((block) => block.id))
const editedBlockIds = new Set(editedBlocks.map((block) => block.id))
// 4. Identify blocks to remove, update, and add
const blockIdsToRemove = originalBlocks
.filter((block) => !editedBlockIds.has(block.id))
.map((block) => block.id)
const blocksToUpdate = editedBlocks
.filter((block) => originalBlockIds.has(block.id))
.map((block) => ({
...block,
updatedAt: new Date().toISOString()
}))
const blocksToAdd = editedBlocks
.filter((block) => !originalBlockIds.has(block.id))
.map((block) => ({
...block,
updatedAt: new Date().toISOString()
}))
// 5. Prepare message update with new block IDs
const updatedBlockIds = editedBlocks.map((block) => block.id)
const messageUpdates: Partial<Message> & Pick<Message, 'id'> = {
id: messageId,
updatedAt: new Date().toISOString(),
blocks: updatedBlockIds
}
// 6. Log operations for debugging
console.log('[editMessageBlocks] Operations:', {
blocksToRemove: blockIdsToRemove.length,
blocksToUpdate: blocksToUpdate.length,
blocksToAdd: blocksToAdd.length
})
// 7. Update Redux state and database
// First update message and add/update blocks
if (blocksToAdd.length > 0) {
await dispatch(updateMessageAndBlocksThunk(topic.id, messageUpdates, blocksToAdd))
}
if (blocksToUpdate.length > 0) {
await dispatch(updateMessageAndBlocksThunk(topic.id, messageUpdates, blocksToUpdate))
}
// Then remove blocks if needed
if (blockIdsToRemove.length > 0) {
await dispatch(removeBlocksThunk(topic.id, messageId, blockIdsToRemove))
}
} catch (error) {
console.error('[editMessageBlocks] Failed to update message blocks:', error)
}
}, },
[dispatch, topic.id] [dispatch, topic?.id]
)
/**
* / Resends a user message after its main text block has been edited.
* Dispatches resendUserMessageWithEditThunk.
*/
const resendUserMessageWithEdit = useCallback(
async (message: Message, editedBlocks: MessageBlock[], assistant: Assistant) => {
await editMessageBlocks(message.id, editedBlocks)
const mainTextBlock = editedBlocks.find((block) => block.type === MessageBlockType.MAIN_TEXT)
if (!mainTextBlock) {
console.error('[resendUserMessageWithEdit] Main text block not found in edited blocks')
return
}
const fileBlocks = editedBlocks.filter(
(block) => block.type === MessageBlockType.FILE || block.type === MessageBlockType.IMAGE
)
const files = fileBlocks.map((block) => block.file).filter((file) => file !== undefined)
const usage = await estimateUserPromptUsage({ content: mainTextBlock.content, files })
const messageUpdates: Partial<Message> & Pick<Message, 'id'> = {
id: message.id,
updatedAt: new Date().toISOString(),
usage
}
await dispatch(
newMessagesActions.updateMessage({ topicId: topic.id, messageId: message.id, updates: messageUpdates })
)
// 对于message的修改会在下面的thunk中保存
await dispatch(resendUserMessageWithEditThunk(topic.id, message, assistant))
},
[dispatch, editMessageBlocks, topic.id]
) )
/** /**

View File

@ -181,6 +181,7 @@
"input.upload": "Upload image or document file", "input.upload": "Upload image or document file",
"input.upload.document": "Upload document file (model does not support images)", "input.upload.document": "Upload document file (model does not support images)",
"input.web_search": "Web search", "input.web_search": "Web search",
"input.web_search.settings": "Web Search Settings",
"input.web_search.button.ok": "Go to Settings", "input.web_search.button.ok": "Go to Settings",
"input.web_search.enable": "Enable web search", "input.web_search.enable": "Enable web search",
"input.web_search.enable_content": "Need to check web search connectivity in settings first", "input.web_search.enable_content": "Need to check web search connectivity in settings first",

View File

@ -181,6 +181,7 @@
"input.upload": "画像またはドキュメントをアップロード", "input.upload": "画像またはドキュメントをアップロード",
"input.upload.document": "ドキュメントをアップロード(モデルは画像をサポートしません)", "input.upload.document": "ドキュメントをアップロード(モデルは画像をサポートしません)",
"input.web_search": "ウェブ検索", "input.web_search": "ウェブ検索",
"input.web_search.settings": "ウェブ検索設定",
"input.web_search.button.ok": "設定に移動", "input.web_search.button.ok": "設定に移動",
"input.web_search.enable": "ウェブ検索を有効にする", "input.web_search.enable": "ウェブ検索を有効にする",
"input.web_search.enable_content": "ウェブ検索の接続性を先に設定で確認する必要があります", "input.web_search.enable_content": "ウェブ検索の接続性を先に設定で確認する必要があります",

View File

@ -181,6 +181,7 @@
"input.upload": "Загрузить изображение или документ", "input.upload": "Загрузить изображение или документ",
"input.upload.document": "Загрузить документ (модель не поддерживает изображения)", "input.upload.document": "Загрузить документ (модель не поддерживает изображения)",
"input.web_search": "Веб-поиск", "input.web_search": "Веб-поиск",
"input.web_search.settings": "Настройки веб-поиска",
"input.web_search.button.ok": "Перейти в Настройки", "input.web_search.button.ok": "Перейти в Настройки",
"input.web_search.enable": "Включить веб-поиск", "input.web_search.enable": "Включить веб-поиск",
"input.web_search.enable_content": "Необходимо предварительно проверить подключение к веб-поиску в настройках", "input.web_search.enable_content": "Необходимо предварительно проверить подключение к веб-поиску в настройках",

View File

@ -190,6 +190,7 @@
"input.upload.upload_from_local": "上传本地文件...", "input.upload.upload_from_local": "上传本地文件...",
"input.upload.document": "上传文档(模型不支持图片)", "input.upload.document": "上传文档(模型不支持图片)",
"input.web_search": "网络搜索", "input.web_search": "网络搜索",
"input.web_search.settings": "网络搜索设置",
"input.web_search.button.ok": "去设置", "input.web_search.button.ok": "去设置",
"input.web_search.enable": "开启网络搜索", "input.web_search.enable": "开启网络搜索",
"input.web_search.enable_content": "需要先在设置中检查网络搜索连通性", "input.web_search.enable_content": "需要先在设置中检查网络搜索连通性",

View File

@ -181,6 +181,7 @@
"input.upload": "上傳圖片或文件", "input.upload": "上傳圖片或文件",
"input.upload.document": "上傳文件(模型不支援圖片)", "input.upload.document": "上傳文件(模型不支援圖片)",
"input.web_search": "網路搜尋", "input.web_search": "網路搜尋",
"input.web_search.settings": "網路搜尋設定",
"input.web_search.button.ok": "去設定", "input.web_search.button.ok": "去設定",
"input.web_search.enable": "開啟網路搜尋", "input.web_search.enable": "開啟網路搜尋",
"input.web_search.enable_content": "需要先在設定中開啟網路搜尋", "input.web_search.enable_content": "需要先在設定中開啟網路搜尋",

View File

@ -5,13 +5,6 @@ import { startNutstoreAutoSync } from './services/NutstoreService'
import storeSyncService from './services/StoreSyncService' import storeSyncService from './services/StoreSyncService'
import store from './store' import store from './store'
function initSpinner() {
const spinner = document.getElementById('spinner')
if (spinner) {
spinner.style.display = 'flex'
}
}
function initKeyv() { function initKeyv() {
window.keyv = new KeyvStorage() window.keyv = new KeyvStorage()
window.keyv.init() window.keyv.init()
@ -34,7 +27,6 @@ function initStoreSync() {
storeSyncService.subscribe() storeSyncService.subscribe()
} }
initSpinner()
initKeyv() initKeyv()
initAutoSync() initAutoSync()
initStoreSync() initStoreSync()

View File

@ -1,6 +1,7 @@
import { useAssistants } from '@renderer/hooks/useAssistant' import { useAssistants } from '@renderer/hooks/useAssistant'
import { useSettings } from '@renderer/hooks/useSettings' import { useSettings } from '@renderer/hooks/useSettings'
import { useActiveTopic } from '@renderer/hooks/useTopic' import { useActiveTopic } from '@renderer/hooks/useTopic'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import NavigationService from '@renderer/services/NavigationService' import NavigationService from '@renderer/services/NavigationService'
import { Assistant } from '@renderer/types' import { Assistant } from '@renderer/types'
import { FC, useEffect, useState } from 'react' import { FC, useEffect, useState } from 'react'
@ -36,6 +37,19 @@ const HomePage: FC = () => {
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [state]) }, [state])
useEffect(() => {
const unsubscribe = EventEmitter.on(EVENT_NAMES.SWITCH_ASSISTANT, (assistantId: string) => {
const newAssistant = assistants.find((a) => a.id === assistantId)
if (newAssistant) {
setActiveAssistant(newAssistant)
}
})
return () => {
unsubscribe()
}
}, [assistants, setActiveAssistant])
useEffect(() => { useEffect(() => {
const canMinimize = topicPosition == 'left' ? !showAssistants : !showAssistants && !showTopics const canMinimize = topicPosition == 'left' ? !showAssistants : !showAssistants && !showTopics
window.api.window.setMinimumSize(canMinimize ? 520 : 1080, 600) window.api.window.setMinimumSize(canMinimize ? 520 : 1080, 600)
@ -47,7 +61,13 @@ const HomePage: FC = () => {
return ( return (
<Container id="home-page"> <Container id="home-page">
<Navbar activeAssistant={activeAssistant} activeTopic={activeTopic} setActiveTopic={setActiveTopic} /> <Navbar
activeAssistant={activeAssistant}
activeTopic={activeTopic}
setActiveTopic={setActiveTopic}
setActiveAssistant={setActiveAssistant}
position="left"
/>
<ContentContainer id="content-container"> <ContentContainer id="content-container">
{showAssistants && ( {showAssistants && (
<HomeTabs <HomeTabs

View File

@ -32,7 +32,55 @@ function truncateFileName(name: string, maxLength: number = MAX_FILENAME_DISPLAY
return name.slice(0, maxLength - 3) + '...' return name.slice(0, maxLength - 3) + '...'
} }
const FileNameRender: FC<{ file: FileType }> = ({ file }) => { export const getFileIcon = (type?: string) => {
if (!type) return <FileUnknownFilled />
const ext = type.toLowerCase()
if (['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'].includes(ext)) {
return <FileImageFilled />
}
if (['.doc', '.docx'].includes(ext)) {
return <FileWordFilled />
}
if (['.xls', '.xlsx'].includes(ext)) {
return <FileExcelFilled />
}
if (['.ppt', '.pptx'].includes(ext)) {
return <FilePptFilled />
}
if (ext === '.pdf') {
return <FilePdfFilled />
}
if (['.md', '.markdown'].includes(ext)) {
return <FileMarkdownFilled />
}
if (['.zip', '.rar', '.7z', '.tar', '.gz'].includes(ext)) {
return <FileZipFilled />
}
if (['.txt', '.json', '.log', '.yml', '.yaml', '.xml', '.csv'].includes(ext)) {
return <FileTextFilled />
}
if (['.url'].includes(ext)) {
return <LinkOutlined />
}
if (['.sitemap'].includes(ext)) {
return <GlobalOutlined />
}
if (['.folder'].includes(ext)) {
return <FolderOpenFilled />
}
return <FileUnknownFilled />
}
export const FileNameRender: FC<{ file: FileType }> = ({ file }) => {
const [visible, setVisible] = useState<boolean>(false) const [visible, setVisible] = useState<boolean>(false)
const isImage = (ext: string) => { const isImage = (ext: string) => {
return ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'].includes(ext) return ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'].includes(ext)
@ -85,54 +133,6 @@ const FileNameRender: FC<{ file: FileType }> = ({ file }) => {
} }
const AttachmentPreview: FC<Props> = ({ files, setFiles }) => { const AttachmentPreview: FC<Props> = ({ files, setFiles }) => {
const getFileIcon = (type?: string) => {
if (!type) return <FileUnknownFilled />
const ext = type.toLowerCase()
if (['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'].includes(ext)) {
return <FileImageFilled />
}
if (['.doc', '.docx'].includes(ext)) {
return <FileWordFilled />
}
if (['.xls', '.xlsx'].includes(ext)) {
return <FileExcelFilled />
}
if (['.ppt', '.pptx'].includes(ext)) {
return <FilePptFilled />
}
if (ext === '.pdf') {
return <FilePdfFilled />
}
if (['.md', '.markdown'].includes(ext)) {
return <FileMarkdownFilled />
}
if (['.zip', '.rar', '.7z', '.tar', '.gz'].includes(ext)) {
return <FileZipFilled />
}
if (['.txt', '.json', '.log', '.yml', '.yaml', '.xml', '.csv'].includes(ext)) {
return <FileTextFilled />
}
if (['.url'].includes(ext)) {
return <LinkOutlined />
}
if (['.sitemap'].includes(ext)) {
return <GlobalOutlined />
}
if (['.folder'].includes(ext)) {
return <FolderOpenFilled />
}
return <FileUnknownFilled />
}
if (isEmpty(files)) { if (isEmpty(files)) {
return null return null
} }

View File

@ -581,7 +581,27 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
const onPaste = useCallback( const onPaste = useCallback(
async (event: ClipboardEvent) => { async (event: ClipboardEvent) => {
// 1. 文件/图片粘贴 // 优先处理文本粘贴
const clipboardText = event.clipboardData?.getData('text')
if (clipboardText) {
// 1. 文本粘贴
if (pasteLongTextAsFile && clipboardText.length > pasteLongTextThreshold) {
// 长文本直接转文件,阻止默认粘贴
event.preventDefault()
const tempFilePath = await window.api.file.create('pasted_text.txt')
await window.api.file.write(tempFilePath, clipboardText)
const selectedFile = await window.api.file.get(tempFilePath)
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
setText(text) // 保持输入框内容不变
setTimeout(() => resizeTextArea(), 50)
return
}
// 短文本走默认粘贴行为,直接返回
return
}
// 2. 文件/图片粘贴(仅在无文本时处理)
if (event.clipboardData?.files && event.clipboardData.files.length > 0) { if (event.clipboardData?.files && event.clipboardData.files.length > 0) {
event.preventDefault() event.preventDefault()
for (const file of event.clipboardData.files) { for (const file of event.clipboardData.files) {
@ -626,43 +646,11 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
} }
return return
} }
// 其他情况默认粘贴
// 2. 文本粘贴
const clipboardText = event.clipboardData?.getData('text')
if (pasteLongTextAsFile && clipboardText && clipboardText.length > pasteLongTextThreshold) {
// 长文本直接转文件,阻止默认粘贴
event.preventDefault()
const tempFilePath = await window.api.file.create('pasted_text.txt')
await window.api.file.write(tempFilePath, clipboardText)
const selectedFile = await window.api.file.get(tempFilePath)
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
setText(text) // 保持输入框内容不变
setTimeout(() => resizeTextArea(), 50)
return
}
// 短文本走默认粘贴行为
}, },
[model, pasteLongTextAsFile, pasteLongTextThreshold, resizeTextArea, supportExts, t, text] [model, pasteLongTextAsFile, pasteLongTextThreshold, resizeTextArea, supportExts, t, text]
) )
// 添加全局粘贴事件处理
useEffect(() => {
const handleGlobalPaste = (event: ClipboardEvent) => {
if (document.activeElement === textareaRef.current?.resizableTextArea?.textArea) {
return
}
onPaste(event)
}
document.addEventListener('paste', handleGlobalPaste)
return () => {
document.removeEventListener('paste', handleGlobalPaste)
}
}, [onPaste])
const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => { const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => {
e.preventDefault() e.preventDefault()
e.stopPropagation() e.stopPropagation()

View File

@ -3,6 +3,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant'
import { useMCPServers } from '@renderer/hooks/useMCPServers' import { useMCPServers } from '@renderer/hooks/useMCPServers'
import { EventEmitter } from '@renderer/services/EventService' import { EventEmitter } from '@renderer/services/EventService'
import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types' import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
import { delay, runAsyncFunction } from '@renderer/utils'
import { Form, Input, Tooltip } from 'antd' import { Form, Input, Tooltip } from 'antd'
import { Plus, SquareTerminal } from 'lucide-react' import { Plus, SquareTerminal } from 'lucide-react'
import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react' import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'
@ -109,6 +110,11 @@ const extractPromptContent = (response: any): string | null => {
return null return null
} }
// Add static variable before component definition
let isFirstResourcesListCall = true
let isFirstPromptListCall = true
const initMcpDelay = 3
const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => { const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => {
const { activedMcpServers } = useMCPServers() const { activedMcpServers } = useMCPServers()
const { t } = useTranslation() const { t } = useTranslation()
@ -308,6 +314,11 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
const promptList = useMemo(async () => { const promptList = useMemo(async () => {
const prompts: MCPPrompt[] = [] const prompts: MCPPrompt[] = []
if (isFirstPromptListCall) {
await delay(initMcpDelay)
isFirstPromptListCall = false
}
for (const server of activedMcpServers) { for (const server of activedMcpServers) {
const serverPrompts = await window.api.mcp.listPrompts(server) const serverPrompts = await window.api.mcp.listPrompts(server)
prompts.push(...serverPrompts) prompts.push(...serverPrompts)
@ -319,7 +330,8 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
icon: <SquareTerminal />, icon: <SquareTerminal />,
action: () => handlePromptSelect(prompt as MCPPromptWithArgs) action: () => handlePromptSelect(prompt as MCPPromptWithArgs)
})) }))
}, [handlePromptSelect, activedMcpServers]) // eslint-disable-next-line react-hooks/exhaustive-deps
}, [activedMcpServers])
const openPromptList = useCallback(async () => { const openPromptList = useCallback(async () => {
const prompts = await promptList const prompts = await promptList
@ -380,33 +392,42 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([]) const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([])
useEffect(() => { useEffect(() => {
let isMounted = true runAsyncFunction(async () => {
let isMounted = true
const fetchResources = async () => { const fetchResources = async () => {
const resources: MCPResource[] = [] const resources: MCPResource[] = []
for (const server of activedMcpServers) {
const serverResources = await window.api.mcp.listResources(server) for (const server of activedMcpServers) {
resources.push(...serverResources) const serverResources = await window.api.mcp.listResources(server)
resources.push(...serverResources)
}
if (isMounted) {
setResourcesList(
resources.map((resource) => ({
label: resource.name,
description: resource.description,
icon: <SquareTerminal />,
action: () => handleResourceSelect(resource)
}))
)
}
} }
if (isMounted) { // Avoid mcp following the software startup, affecting the startup speed
setResourcesList( if (isFirstResourcesListCall) {
resources.map((resource) => ({ await delay(initMcpDelay)
label: resource.name, isFirstResourcesListCall = false
description: resource.description, fetchResources()
icon: <SquareTerminal />,
action: () => handleResourceSelect(resource)
}))
)
} }
}
fetchResources() return () => {
isMounted = false
return () => { }
isMounted = false })
} // eslint-disable-next-line react-hooks/exhaustive-deps
}, [activedMcpServers, handleResourceSelect]) }, [activedMcpServers])
const openResourcesList = useCallback(async () => { const openResourcesList = useCallback(async () => {
const resources = resourcesList const resources = resourcesList

View File

@ -22,18 +22,6 @@ const TokenCount: FC<Props> = ({ estimateTokenCount, inputTokenCount, contextCou
} }
const formatMaxCount = (max: number) => { const formatMaxCount = (max: number) => {
if (max == 100) {
return (
<span
style={{
fontSize: '16px',
position: 'relative',
top: '1px'
}}>
</span>
)
}
return max.toString() return max.toString()
} }
@ -43,7 +31,7 @@ const TokenCount: FC<Props> = ({ estimateTokenCount, inputTokenCount, contextCou
<HStack justifyContent="space-between" w="100%"> <HStack justifyContent="space-between" w="100%">
<Text>{t('chat.input.context_count.tip')}</Text> <Text>{t('chat.input.context_count.tip')}</Text>
<Text> <Text>
{contextCount.current} / {contextCount.max == 20 ? '∞' : contextCount.max} {contextCount.current} / {contextCount.max}
</Text> </Text>
</HStack> </HStack>
<Divider style={{ margin: '5px 0' }} /> <Divider style={{ margin: '5px 0' }} />

View File

@ -79,7 +79,7 @@ const WebSearchButton: FC<Props> = ({ ref, assistant, ToolbarButton }) => {
} }
items.push({ items.push({
label: '前往设置' + '...', label: t('chat.input.web_search.settings'),
icon: <Settings />, icon: <Settings />,
action: () => navigate('/settings/web-search') action: () => navigate('/settings/web-search')
}) })

View File

@ -38,13 +38,14 @@ const MainTextBlock: React.FC<Props> = ({ block, citationBlockId, role, mentions
// Use the passed citationBlockId directly in the selector // Use the passed citationBlockId directly in the selector
const { renderInputMessageAsMarkdown } = useSettings() const { renderInputMessageAsMarkdown } = useSettings()
const formattedCitations = useSelector((state: RootState) => { const rawCitations = useSelector((state: RootState) => selectFormattedCitationsByBlockId(state, citationBlockId))
const citations = selectFormattedCitationsByBlockId(state, citationBlockId)
return citations.map((citation) => ({ const formattedCitations = useMemo(() => {
return rawCitations.map((citation) => ({
...citation, ...citation,
content: citation.content ? cleanMarkdownContent(citation.content) : citation.content content: citation.content ? cleanMarkdownContent(citation.content) : citation.content
})) }))
}) }, [rawCitations])
const processedContent = useMemo(() => { const processedContent = useMemo(() => {
let content = block.content let content = block.content

View File

@ -23,12 +23,6 @@ const ThinkingBlock: React.FC<Props> = ({ block }) => {
const isThinking = useMemo(() => block.status === MessageBlockStatus.STREAMING, [block.status]) const isThinking = useMemo(() => block.status === MessageBlockStatus.STREAMING, [block.status])
const fontFamily = useMemo(() => {
return messageFont === 'serif'
? 'serif'
: '-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, Cantarell, "Open Sans","Helvetica Neue", sans-serif'
}, [messageFont])
useEffect(() => { useEffect(() => {
if (!isThinking && thoughtAutoCollapse) { if (!isThinking && thoughtAutoCollapse) {
setActiveKey('') setActiveKey('')
@ -98,7 +92,11 @@ const ThinkingBlock: React.FC<Props> = ({ block }) => {
), ),
children: ( children: (
// FIXME: 临时兼容 // FIXME: 临时兼容
<div style={{ fontFamily, fontSize }}> <div
style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
fontSize
}}>
<Markdown block={block} /> <Markdown block={block} />
</div> </div>
) )

View File

@ -1,12 +1,14 @@
import ContextMenu from '@renderer/components/ContextMenu' import ContextMenu from '@renderer/components/ContextMenu'
import { useMessageEditing } from '@renderer/context/MessageEditingContext'
import { useAssistant } from '@renderer/hooks/useAssistant' import { useAssistant } from '@renderer/hooks/useAssistant'
import { useMessageOperations } from '@renderer/hooks/useMessageOperations'
import { useModel } from '@renderer/hooks/useModel' import { useModel } from '@renderer/hooks/useModel'
import { useMessageStyle, useSettings } from '@renderer/hooks/useSettings' import { useMessageStyle, useSettings } from '@renderer/hooks/useSettings'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import { getMessageModelId } from '@renderer/services/MessagesService' import { getMessageModelId } from '@renderer/services/MessagesService'
import { getModelUniqId } from '@renderer/services/ModelService' import { getModelUniqId } from '@renderer/services/ModelService'
import { Assistant, Topic } from '@renderer/types' import { Assistant, Topic } from '@renderer/types'
import type { Message } from '@renderer/types/newMessage' import type { Message, MessageBlock } from '@renderer/types/newMessage'
import { classNames } from '@renderer/utils' import { classNames } from '@renderer/utils'
import { Divider } from 'antd' import { Divider } from 'antd'
import React, { Dispatch, FC, memo, SetStateAction, useCallback, useEffect, useRef } from 'react' import React, { Dispatch, FC, memo, SetStateAction, useCallback, useEffect, useRef } from 'react'
@ -14,6 +16,7 @@ import { useTranslation } from 'react-i18next'
import styled from 'styled-components' import styled from 'styled-components'
import MessageContent from './MessageContent' import MessageContent from './MessageContent'
import MessageEditor from './MessageEditor'
import MessageErrorBoundary from './MessageErrorBoundary' import MessageErrorBoundary from './MessageErrorBoundary'
import MessageHeader from './MessageHeader' import MessageHeader from './MessageHeader'
import MessageMenubar from './MessageMenubar' import MessageMenubar from './MessageMenubar'
@ -47,11 +50,54 @@ const MessageItem: FC<Props> = ({
const model = useModel(getMessageModelId(message), message.model?.provider) || message.model const model = useModel(getMessageModelId(message), message.model?.provider) || message.model
const { isBubbleStyle } = useMessageStyle() const { isBubbleStyle } = useMessageStyle()
const { showMessageDivider, messageFont, fontSize } = useSettings() const { showMessageDivider, messageFont, fontSize } = useSettings()
const { editMessageBlocks, resendUserMessageWithEdit } = useMessageOperations(topic)
const messageContainerRef = useRef<HTMLDivElement>(null) const messageContainerRef = useRef<HTMLDivElement>(null)
const { editingMessageId, stopEditing } = useMessageEditing()
const isEditing = editingMessageId === message.id
useEffect(() => {
if (isEditing && messageContainerRef.current) {
messageContainerRef.current.scrollIntoView({
behavior: 'smooth',
block: 'center'
})
}
}, [isEditing])
const handleEditSave = useCallback(
async (blocks: MessageBlock[]) => {
try {
console.log('after save blocks', blocks)
await editMessageBlocks(message.id, blocks)
stopEditing()
} catch (error) {
console.error('Failed to save message blocks:', error)
}
},
[message, editMessageBlocks, stopEditing]
)
const handleEditResend = useCallback(
async (blocks: MessageBlock[]) => {
try {
// 编辑后重新发送消息
console.log('after resend blocks', blocks)
await resendUserMessageWithEdit(message, blocks, assistant)
stopEditing()
} catch (error) {
console.error('Failed to resend message:', error)
}
},
[message, resendUserMessageWithEdit, assistant, stopEditing]
)
const handleEditCancel = useCallback(() => {
stopEditing()
}, [stopEditing])
const isLastMessage = index === 0 const isLastMessage = index === 0
const isAssistantMessage = message.role === 'assistant' const isAssistantMessage = message.role === 'assistant'
const showMenubar = !isStreaming && !message.status.includes('ing') const showMenubar = !isStreaming && !message.status.includes('ing') && !isEditing
const messageBorder = showMessageDivider ? undefined : 'none' const messageBorder = showMessageDivider ? undefined : 'none'
const messageBackground = getMessageBackground(isBubbleStyle, isAssistantMessage) const messageBackground = getMessageBackground(isBubbleStyle, isAssistantMessage)
@ -114,9 +160,18 @@ const MessageItem: FC<Props> = ({
background: messageBackground, background: messageBackground,
overflowY: 'visible' overflowY: 'visible'
}}> }}>
<MessageErrorBoundary> {isEditing ? (
<MessageContent message={message} /> <MessageEditor
</MessageErrorBoundary> message={message}
onSave={handleEditSave}
onResend={handleEditResend}
onCancel={handleEditCancel}
/>
) : (
<MessageErrorBoundary>
<MessageContent message={message} />
</MessageErrorBoundary>
)}
{showMenubar && ( {showMenubar && (
<MessageFooter <MessageFooter
className="MessageFooter" className="MessageFooter"

View File

@ -20,17 +20,6 @@ const MessageContent: React.FC<Props> = ({ message }) => {
) )
} }
// const SearchingContainer = styled.div`
// display: flex;
// flex-direction: row;
// align-items: center;
// background-color: var(--color-background-mute);
// padding: 10px;
// border-radius: 10px;
// margin-bottom: 10px;
// gap: 10px;
// `
const MentionTag = styled.span` const MentionTag = styled.span`
color: var(--color-link); color: var(--color-link);
` `

View File

@ -0,0 +1,367 @@
import CustomTag from '@renderer/components/CustomTag'
import TranslateButton from '@renderer/components/TranslateButton'
import { isGenerateImageModel, isVisionModel } from '@renderer/config/models'
import { useAssistant } from '@renderer/hooks/useAssistant'
import { useSettings } from '@renderer/hooks/useSettings'
import FileManager from '@renderer/services/FileManager'
import { FileType, FileTypes } from '@renderer/types'
import { Message, MessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
import { classNames, getFileExtension } from '@renderer/utils'
import { getFilesFromDropEvent } from '@renderer/utils/input'
import { createFileBlock, createImageBlock } from '@renderer/utils/messageUtils/create'
import { findAllBlocks } from '@renderer/utils/messageUtils/find'
import { documentExts, imageExts, textExts } from '@shared/config/constant'
import { Tooltip } from 'antd'
import TextArea, { TextAreaRef } from 'antd/es/input/TextArea'
import { Save, Send, X } from 'lucide-react'
import { FC, memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import AttachmentButton, { AttachmentButtonRef } from '../Inputbar/AttachmentButton'
import { FileNameRender, getFileIcon } from '../Inputbar/AttachmentPreview'
import { ToolbarButton } from '../Inputbar/Inputbar'
interface Props {
message: Message
onSave: (blocks: MessageBlock[]) => void
onResend: (blocks: MessageBlock[]) => void
onCancel: () => void
}
const MessageBlockEditor: FC<Props> = ({ message, onSave, onResend, onCancel }) => {
const allBlocks = findAllBlocks(message)
const [editedBlocks, setEditedBlocks] = useState<MessageBlock[]>(allBlocks)
const [files, setFiles] = useState<FileType[]>([])
const [isProcessing, setIsProcessing] = useState(false)
const [isFileDragging, setIsFileDragging] = useState(false)
const { assistant } = useAssistant(message.assistantId)
const model = assistant.model || assistant.defaultModel
const isVision = useMemo(() => isVisionModel(model), [model])
const supportExts = useMemo(() => [...textExts, ...documentExts, ...(isVision ? imageExts : [])], [isVision])
const { pasteLongTextAsFile, pasteLongTextThreshold, fontSize } = useSettings()
const { t } = useTranslation()
const textareaRef = useRef<TextAreaRef>(null)
const attachmentButtonRef = useRef<AttachmentButtonRef>(null)
useEffect(() => {
setTimeout(() => {
resizeTextArea()
if (textareaRef.current) {
textareaRef.current.focus({ cursor: 'end' })
}
}, 0)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
const resizeTextArea = useCallback(() => {
const textArea = textareaRef.current?.resizableTextArea?.textArea
if (textArea) {
textArea.style.height = 'auto'
textArea.style.height = textArea?.scrollHeight > 400 ? '400px' : `${textArea?.scrollHeight}px`
}
}, [])
const handleTextChange = (blockId: string, content: string) => {
setEditedBlocks((prev) => prev.map((block) => (block.id === blockId ? { ...block, content } : block)))
}
const onTranslated = (translatedText: string) => {
const mainTextBlock = editedBlocks.find((b) => b.type === MessageBlockType.MAIN_TEXT)
if (mainTextBlock) {
handleTextChange(mainTextBlock.id, translatedText)
}
setTimeout(() => resizeTextArea(), 0)
}
// 处理文件删除
const handleFileRemove = async (blockId: string) => {
setEditedBlocks((prev) => prev.filter((block) => block.id !== blockId))
}
// 处理拖拽上传
const handleDrop = async (e: React.DragEvent<HTMLDivElement>) => {
e.preventDefault()
e.stopPropagation()
setIsFileDragging(false)
const files = await getFilesFromDropEvent(e).catch((err) => {
console.error('[src/renderer/src/pages/home/Inputbar/Inputbar.tsx] handleDrop:', err)
return null
})
if (files) {
let supportedFiles = 0
files.forEach((file) => {
if (supportExts.includes(getFileExtension(file.path))) {
setFiles((prevFiles) => [...prevFiles, file])
supportedFiles++
}
})
// 如果有文件,但都不支持
if (files.length > 0 && supportedFiles === 0) {
window.message.info({
key: 'file_not_supported',
content: t('chat.input.file_not_supported')
})
}
}
}
const handleClick = async (withResend?: boolean) => {
if (isProcessing) return
setIsProcessing(true)
const updatedBlocks = [...editedBlocks]
if (files && files.length) {
const uploadedFiles = await FileManager.uploadFiles(files)
uploadedFiles.forEach((file) => {
if (file.type === FileTypes.IMAGE) {
const imgBlock = createImageBlock(message.id, { file, status: MessageBlockStatus.SUCCESS })
updatedBlocks.push(imgBlock)
} else {
const fileBlock = createFileBlock(message.id, file, { status: MessageBlockStatus.SUCCESS })
updatedBlocks.push(fileBlock)
}
})
}
if (withResend) {
onResend(updatedBlocks)
} else {
onSave(updatedBlocks)
}
}
const onPaste = useCallback(
async (event: ClipboardEvent) => {
// 1. 文本粘贴
const clipboardText = event.clipboardData?.getData('text')
if (clipboardText) {
if (pasteLongTextAsFile && clipboardText.length > pasteLongTextThreshold) {
// 长文本直接转文件,阻止默认粘贴
event.preventDefault()
const tempFilePath = await window.api.file.create('pasted_text.txt')
await window.api.file.write(tempFilePath, clipboardText)
const selectedFile = await window.api.file.get(tempFilePath)
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
setTimeout(() => resizeTextArea(), 50)
return
}
// 短文本走默认粘贴行为,直接返回
return
}
// 2. 文件/图片粘贴
if (event.clipboardData?.files && event.clipboardData.files.length > 0) {
event.preventDefault()
for (const file of event.clipboardData.files) {
const filePath = window.api.file.getPathForFile(file)
if (!filePath) {
// 图像生成也支持图像编辑
if (file.type.startsWith('image/') && (isVisionModel(model) || isGenerateImageModel(model))) {
const tempFilePath = await window.api.file.create(file.name)
const arrayBuffer = await file.arrayBuffer()
const uint8Array = new Uint8Array(arrayBuffer)
await window.api.file.write(tempFilePath, uint8Array)
const selectedFile = await window.api.file.get(tempFilePath)
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
break
} else {
window.message.info({
key: 'file_not_supported',
content: t('chat.input.file_not_supported')
})
}
}
if (supportExts.includes(getFileExtension(filePath))) {
const selectedFile = await window.api.file.get(filePath)
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
} else {
window.message.info({
key: 'file_not_supported',
content: t('chat.input.file_not_supported')
})
}
}
return
}
// 短文本走默认粘贴行为
},
[model, pasteLongTextAsFile, pasteLongTextThreshold, resizeTextArea, supportExts, t]
)
const autoResizeTextArea = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
const textarea = e.target
textarea.style.height = 'auto'
textarea.style.height = `${textarea.scrollHeight}px`
}
return (
<>
<EditorContainer onDragOver={(e) => e.preventDefault()} onDrop={handleDrop}>
{editedBlocks
.filter((block) => block.type === MessageBlockType.MAIN_TEXT)
.map((block) => (
<Textarea
className={classNames(isFileDragging && 'file-dragging')}
key={block.id}
ref={textareaRef}
variant="borderless"
value={block.content}
onChange={(e) => {
handleTextChange(block.id, e.target.value)
autoResizeTextArea(e)
}}
autoFocus
contextMenu="true"
spellCheck={false}
onPaste={(e) => onPaste(e.nativeEvent)}
style={{
fontSize,
padding: '0px 15px 8px 15px'
}}>
<TranslateButton onTranslated={onTranslated} />
</Textarea>
))}
{(editedBlocks.some((block) => block.type === MessageBlockType.FILE || block.type === MessageBlockType.IMAGE) ||
files.length > 0) && (
<FileBlocksContainer>
{editedBlocks
.filter((block) => block.type === MessageBlockType.FILE || block.type === MessageBlockType.IMAGE)
.map(
(block) =>
block.file && (
<CustomTag
key={block.id}
icon={getFileIcon(block.file.ext)}
color="#37a5aa"
closable
onClose={() => handleFileRemove(block.id)}>
<FileNameRender file={block.file} />
</CustomTag>
)
)}
{files.map((file) => (
<CustomTag
key={file.id}
icon={getFileIcon(file.ext)}
color="#37a5aa"
closable
onClose={() => setFiles((prevFiles) => prevFiles.filter((f) => f.id !== file.id))}>
<FileNameRender file={file} />
</CustomTag>
))}
</FileBlocksContainer>
)}
<ActionBar>
<ActionBarLeft>
<AttachmentButton
ref={attachmentButtonRef}
model={model}
files={files}
setFiles={setFiles}
ToolbarButton={ToolbarButton}
/>
</ActionBarLeft>
<ActionBarMiddle />
<ActionBarRight>
<Tooltip title={t('common.cancel')}>
<ToolbarButton type="text" onClick={onCancel}>
<X size={16} />
</ToolbarButton>
</Tooltip>
<Tooltip title={t('common.save')}>
<ToolbarButton type="text" onClick={() => handleClick()}>
<Save size={16} />
</ToolbarButton>
</Tooltip>
<Tooltip title={t('chat.resend')}>
<ToolbarButton type="text" onClick={() => handleClick(true)}>
<Send size={16} />
</ToolbarButton>
</Tooltip>
</ActionBarRight>
</ActionBar>
</EditorContainer>
</>
)
}
const FileBlocksContainer = styled.div`
display: flex;
flex-wrap: wrap;
gap: 8px;
padding: 0 15px;
margin: 8px 0;
background: transplant;
border-radius: 4px;
`
const EditorContainer = styled.div`
padding: 8px 0;
border: 1px solid var(--color-border);
transition: all 0.2s ease;
border-radius: 15px;
margin-top: 0;
background-color: var(--color-background-opacity);
&.file-dragging {
border: 2px dashed #2ecc71;
&::before {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: rgba(46, 204, 113, 0.03);
border-radius: 14px;
z-index: 5;
pointer-events: none;
}
}
`
const Textarea = styled(TextArea)`
padding: 0;
border-radius: 0;
display: flex;
flex: 1;
font-family: Ubuntu;
resize: none !important;
overflow: auto;
width: 100%;
box-sizing: border-box;
&.ant-input {
line-height: 1.4;
}
`
const ActionBar = styled.div`
display: flex;
padding: 0 8px;
justify-content: space-between;
margin-top: 8px;
`
const ActionBarLeft = styled.div`
display: flex;
align-items: center;
`
const ActionBarMiddle = styled.div`
flex: 1;
`
const ActionBarRight = styled.div`
display: flex;
align-items: center;
gap: 8px;
`
export default memo(MessageBlockEditor)

View File

@ -1,4 +1,5 @@
import Scrollbar from '@renderer/components/Scrollbar' import Scrollbar from '@renderer/components/Scrollbar'
import { MessageEditingProvider } from '@renderer/context/MessageEditingContext'
import { useMessageOperations } from '@renderer/hooks/useMessageOperations' import { useMessageOperations } from '@renderer/hooks/useMessageOperations'
import { useSettings } from '@renderer/hooks/useSettings' import { useSettings } from '@renderer/hooks/useSettings'
import { MultiModelMessageStyle } from '@renderer/store/settings' import { MultiModelMessageStyle } from '@renderer/store/settings'
@ -6,7 +7,7 @@ import type { Topic } from '@renderer/types'
import type { Message } from '@renderer/types/newMessage' import type { Message } from '@renderer/types/newMessage'
import { classNames } from '@renderer/utils' import { classNames } from '@renderer/utils'
import { Popover } from 'antd' import { Popover } from 'antd'
import { memo, useCallback, useEffect, useRef, useState } from 'react' import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import styled, { css } from 'styled-components' import styled, { css } from 'styled-components'
import { useChatContext } from './ChatContext' import { useChatContext } from './ChatContext'
@ -45,7 +46,8 @@ const MessageGroup = ({
const prevMessageLengthRef = useRef(messageLength) const prevMessageLengthRef = useRef(messageLength)
const [selectedIndex, setSelectedIndex] = useState(messageLength - 1) const [selectedIndex, setSelectedIndex] = useState(messageLength - 1)
const getSelectedMessageId = useCallback(() => { const selectedMessageId = useMemo(() => {
if (messages.length === 1) return messages[0]?.id
const selectedMessage = messages.find((message) => message.foldSelected) const selectedMessage = messages.find((message) => message.foldSelected)
if (selectedMessage) { if (selectedMessage) {
return selectedMessage.id return selectedMessage.id
@ -55,9 +57,10 @@ const MessageGroup = ({
const setSelectedMessage = useCallback( const setSelectedMessage = useCallback(
(message: Message) => { (message: Message) => {
messages.forEach(async (m) => { // 前一个
await editMessage(m.id, { foldSelected: m.id === message.id }) editMessage(selectedMessageId, { foldSelected: false })
}) // 当前选中的消息
editMessage(message.id, { foldSelected: true })
setTimeout(() => { setTimeout(() => {
const messageElement = document.getElementById(`message-${message.id}`) const messageElement = document.getElementById(`message-${message.id}`)
@ -66,7 +69,7 @@ const MessageGroup = ({
} }
}, 200) }, 200)
}, },
[editMessage, messages] [editMessage, selectedMessageId]
) )
const isGrouped = messageLength > 1 && messages.every((m) => m.role === 'assistant') const isGrouped = messageLength > 1 && messages.every((m) => m.role === 'assistant')
@ -81,8 +84,7 @@ const MessageGroup = ({
setSelectedMessage(lastMessage) setSelectedMessage(lastMessage)
} }
} else { } else {
const selectedId = getSelectedMessageId() const newIndex = messages.findIndex((msg) => msg.id === selectedMessageId)
const newIndex = messages.findIndex((msg) => msg.id === selectedId)
if (newIndex !== -1) { if (newIndex !== -1) {
setSelectedIndex(newIndex) setSelectedIndex(newIndex)
} }
@ -140,7 +142,7 @@ const MessageGroup = ({
}, [messages, contextRegisterMessageElement]) }, [messages, contextRegisterMessageElement])
const renderMessage = useCallback( const renderMessage = useCallback(
(message: Message & { index: number }, index: number) => { (message: Message & { index: number }) => {
const isGridGroupMessage = isGrid && message.role === 'assistant' && isGrouped const isGridGroupMessage = isGrid && message.role === 'assistant' && isGrouped
const messageProps = { const messageProps = {
isGrouped, isGrouped,
@ -157,13 +159,13 @@ const MessageGroup = ({
<MessageWrapper <MessageWrapper
id={`message-${message.id}`} id={`message-${message.id}`}
$layout={multiModelMessageStyle} $layout={multiModelMessageStyle}
$selected={index === selectedIndex} // $selected={index === selectedIndex}
$isGrouped={isGrouped} $isGrouped={isGrouped}
key={message.id} key={message.id}
className={classNames({ className={classNames({
'group-message-wrapper': message.role === 'assistant' && isHorizontal && isGrouped, 'group-message-wrapper': message.role === 'assistant' && isHorizontal && isGrouped,
[multiModelMessageStyle]: isGrouped, [multiModelMessageStyle]: isGrouped,
selected: message.id === getSelectedMessageId() selected: message.id === selectedMessageId
})}> })}>
<MessageItem {...messageProps} /> <MessageItem {...messageProps} />
</MessageWrapper> </MessageWrapper>
@ -188,7 +190,7 @@ const MessageGroup = ({
content={ content={
<MessageWrapper <MessageWrapper
$layout={multiModelMessageStyle} $layout={multiModelMessageStyle}
$selected={index === selectedIndex} // $selected={index === selectedIndex}
$isGrouped={isGrouped} $isGrouped={isGrouped}
$isInPopover={true}> $isInPopover={true}>
<MessageItem {...messageProps} /> <MessageItem {...messageProps} />
@ -222,34 +224,36 @@ const MessageGroup = ({
) )
return ( return (
<GroupContainer <MessageEditingProvider>
id={`message-group-${messages[0].askId}`} <GroupContainer
$isGrouped={isGrouped} id={`message-group-${messages[0].askId}`}
$layout={multiModelMessageStyle} $isGrouped={isGrouped}
className={classNames([isGrouped && 'group-container', isHorizontal && 'horizontal', isGrid && 'grid'])}>
<GridContainer
$count={messageLength}
$layout={multiModelMessageStyle} $layout={multiModelMessageStyle}
$gridColumns={gridColumns} className={classNames([isGrouped && 'group-container', isHorizontal && 'horizontal', isGrid && 'grid'])}>
className={classNames([isGrouped && 'group-grid-container', isHorizontal && 'horizontal', isGrid && 'grid'])}> <GridContainer
{messages.map(renderMessage)} $count={messageLength}
</GridContainer> $layout={multiModelMessageStyle}
{isGrouped && ( $gridColumns={gridColumns}
<MessageGroupMenuBar className={classNames([isGrouped && 'group-grid-container', isHorizontal && 'horizontal', isGrid && 'grid'])}>
multiModelMessageStyle={multiModelMessageStyle} {messages.map(renderMessage)}
setMultiModelMessageStyle={(style) => { </GridContainer>
setMultiModelMessageStyle(style) {isGrouped && (
messages.forEach((message) => { <MessageGroupMenuBar
editMessage(message.id, { multiModelMessageStyle: style }) multiModelMessageStyle={multiModelMessageStyle}
}) setMultiModelMessageStyle={(style) => {
}} setMultiModelMessageStyle(style)
messages={messages} messages.forEach((message) => {
selectMessageId={getSelectedMessageId()} editMessage(message.id, { multiModelMessageStyle: style })
setSelectedMessage={setSelectedMessage} })
topic={topic} }}
/> messages={messages}
)} selectMessageId={selectedMessageId}
</GroupContainer> setSelectedMessage={setSelectedMessage}
topic={topic}
/>
)}
</GroupContainer>
</MessageEditingProvider>
) )
} }
@ -306,7 +310,7 @@ const GridContainer = styled.div<{ $count: number; $layout: MultiModelMessageSty
interface MessageWrapperProps { interface MessageWrapperProps {
$layout: 'fold' | 'horizontal' | 'vertical' | 'grid' $layout: 'fold' | 'horizontal' | 'vertical' | 'grid'
$selected: boolean // $selected: boolean
$isGrouped: boolean $isGrouped: boolean
$isInPopover?: boolean $isInPopover?: boolean
} }

View File

@ -1,8 +1,8 @@
import { CheckOutlined, EditOutlined, MenuOutlined, QuestionCircleOutlined, SyncOutlined } from '@ant-design/icons' import { CheckOutlined, EditOutlined, MenuOutlined, QuestionCircleOutlined, SyncOutlined } from '@ant-design/icons'
import ObsidianExportPopup from '@renderer/components/Popups/ObsidianExportPopup' import ObsidianExportPopup from '@renderer/components/Popups/ObsidianExportPopup'
import SelectModelPopup from '@renderer/components/Popups/SelectModelPopup' import SelectModelPopup from '@renderer/components/Popups/SelectModelPopup'
import TextEditPopup from '@renderer/components/Popups/TextEditPopup'
import { TranslateLanguageOptions } from '@renderer/config/translate' import { TranslateLanguageOptions } from '@renderer/config/translate'
import { useMessageEditing } from '@renderer/context/MessageEditingContext'
import { useMessageOperations, useTopicLoading } from '@renderer/hooks/useMessageOperations' import { useMessageOperations, useTopicLoading } from '@renderer/hooks/useMessageOperations'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import { getMessageTitle } from '@renderer/services/MessagesService' import { getMessageTitle } from '@renderer/services/MessagesService'
@ -23,13 +23,8 @@ import {
} from '@renderer/utils/export' } from '@renderer/utils/export'
// import { withMessageThought } from '@renderer/utils/formats' // import { withMessageThought } from '@renderer/utils/formats'
import { removeTrailingDoubleSpaces } from '@renderer/utils/markdown' import { removeTrailingDoubleSpaces } from '@renderer/utils/markdown'
import { import { findMainTextBlocks, findTranslationBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
findImageBlocks, import { Dropdown, Popconfirm, Tooltip } from 'antd'
findMainTextBlocks,
findTranslationBlocks,
getMainTextContent
} from '@renderer/utils/messageUtils/find'
import { Button, Dropdown, Popconfirm, Tooltip } from 'antd'
import dayjs from 'dayjs' import dayjs from 'dayjs'
import { AtSign, Copy, Languages, Menu, RefreshCw, Save, Share, Split, ThumbsUp, Trash } from 'lucide-react' import { AtSign, Copy, Languages, Menu, RefreshCw, Save, Share, Split, ThumbsUp, Trash } from 'lucide-react'
import { FilePenLine } from 'lucide-react' import { FilePenLine } from 'lucide-react'
@ -67,10 +62,8 @@ const MessageMenubar: FC<Props> = (props) => {
deleteMessage, deleteMessage,
resendMessage, resendMessage,
regenerateAssistantMessage, regenerateAssistantMessage,
resendUserMessageWithEdit,
getTranslationUpdater, getTranslationUpdater,
appendAssistantResponse, appendAssistantResponse,
editMessageBlocks,
removeMessageBlock removeMessageBlock
} = useMessageOperations(topic) } = useMessageOperations(topic)
const loading = useTopicLoading(topic) const loading = useTopicLoading(topic)
@ -121,92 +114,11 @@ const MessageMenubar: FC<Props> = (props) => {
[assistant, loading, message, resendMessage] [assistant, loading, message, resendMessage]
) )
const { startEditing } = useMessageEditing()
const onEdit = useCallback(async () => { const onEdit = useCallback(async () => {
// 禁用了助手消息的编辑,现在都是用户消息的编辑 startEditing(message.id)
let resendMessage = false }, [message.id, startEditing])
let textToEdit = ''
const imageBlocks = findImageBlocks(message)
// 如果是包含图片的消息,添加图片的 markdown 格式
if (imageBlocks.length > 0) {
const imageMarkdown = imageBlocks
.map((image, index) => `![image-${index}](file://${image?.file?.path})`)
.join('\n')
textToEdit = `${textToEdit}\n\n${imageMarkdown}`
}
textToEdit += mainTextContent
// if (message.role === 'assistant' && message.model && isReasoningModel(message.model)) {
// // const processedMessage = withMessageThought(clone(message))
// // textToEdit = getMainTextContent(processedMessage)
// textToEdit = mainTextContent
// }
const editedText = await TextEditPopup.show({
text: textToEdit,
children: (props) => {
const onPress = () => {
props.onOk?.()
resendMessage = true
}
return message.role === 'user' ? (
<ReSendButton
icon={<i className="iconfont icon-ic_send" style={{ color: 'var(--color-primary)' }} />}
onClick={onPress}>
{t('chat.resend')}
</ReSendButton>
) : null
}
})
if (editedText && editedText !== textToEdit) {
// 解析编辑后的文本,提取图片 URL
// const imageRegex = /!\[image-\d+\]\((.*?)\)/g
// const imageUrls: string[] = []
// let match
// let content = editedText
// TODO 按理说图片应该走上传,不应该在这改
// while ((match = imageRegex.exec(editedText)) !== null) {
// imageUrls.push(match[1])
// content = content.replace(match[0], '')
// }
if (resendMessage) {
resendUserMessageWithEdit(message, editedText, assistant)
} else {
editMessageBlocks(message.id, { id: findMainTextBlocks(message)[0].id, content: editedText })
}
// // 更新消息内容,保留图片信息
// await editMessage(message.id, {
// content: content.trim(),
// metadata: {
// ...message.metadata,
// generateImage:
// imageUrls.length > 0
// ? {
// type: 'url',
// images: imageUrls
// }
// : undefined
// }
// })
// resendMessage &&
// handleResendUserMessage({
// ...message,
// content: content.trim(),
// metadata: {
// ...message.metadata,
// generateImage:
// imageUrls.length > 0
// ? {
// type: 'url',
// images: imageUrls
// }
// : undefined
// }
// })
}
}, [resendUserMessageWithEdit, editMessageBlocks, assistant, mainTextContent, message, t])
const handleTranslate = useCallback( const handleTranslate = useCallback(
async (language: string) => { async (language: string) => {
@ -594,10 +506,10 @@ const ActionButton = styled.div`
} }
` `
const ReSendButton = styled(Button)` // const ReSendButton = styled(Button)`
position: absolute; // position: absolute;
top: 10px; // top: 10px;
left: 0; // left: 0;
` // `
export default memo(MessageMenubar) export default memo(MessageMenubar)

View File

@ -17,11 +17,6 @@ const MessageTools: FC<Props> = ({ blocks }) => {
const [expandedResponse, setExpandedResponse] = useState<{ content: string; title: string } | null>(null) const [expandedResponse, setExpandedResponse] = useState<{ content: string; title: string } | null>(null)
const { t } = useTranslation() const { t } = useTranslation()
const { messageFont, fontSize } = useSettings() const { messageFont, fontSize } = useSettings()
const fontFamily = useMemo(() => {
return messageFont === 'serif'
? 'serif'
: '-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, Cantarell, "Open Sans","Helvetica Neue", sans-serif'
}, [messageFont])
const toolResponse = blocks.metadata?.rawMcpToolResponse const toolResponse = blocks.metadata?.rawMcpToolResponse
@ -39,7 +34,6 @@ const MessageTools: FC<Props> = ({ blocks }) => {
return 'Invalid Result' return 'Invalid Result'
} }
}, [toolResponse]) }, [toolResponse])
const { renderedMarkdown: styledResult } = useShikiWithMarkdownIt(`\`\`\`json\n${resultString}\n\`\`\``)
if (!toolResponse) { if (!toolResponse) {
return null return null
@ -59,8 +53,6 @@ const MessageTools: FC<Props> = ({ blocks }) => {
// Format tool responses for collapse items // Format tool responses for collapse items
const getCollapseItems = () => { const getCollapseItems = () => {
const items: { key: string; label: React.ReactNode; children: React.ReactNode }[] = [] const items: { key: string; label: React.ReactNode; children: React.ReactNode }[] = []
// Add tool responses
// for (const toolResponse of toolResponses) {
const { id, tool, status, response } = toolResponse const { id, tool, status, response } = toolResponse
const isInvoking = status === 'invoking' const isInvoking = status === 'invoking'
const isDone = status === 'done' const isDone = status === 'done'
@ -122,12 +114,15 @@ const MessageTools: FC<Props> = ({ blocks }) => {
</MessageTitleLabel> </MessageTitleLabel>
), ),
children: isDone && result && ( children: isDone && result && (
<ToolResponseContainer style={{ fontFamily, fontSize: '12px' }}> <ToolResponseContainer
<div className="markdown" dangerouslySetInnerHTML={{ __html: styledResult }} /> style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
fontSize: '12px'
}}>
<CollapsedContent isExpanded={activeKeys.includes(id)} resultString={resultString} />
</ToolResponseContainer> </ToolResponseContainer>
) )
}) })
// }
return items return items
} }
@ -140,7 +135,6 @@ const MessageTools: FC<Props> = ({ blocks }) => {
switch (parsedResult.content[0]?.type) { switch (parsedResult.content[0]?.type) {
case 'text': case 'text':
return <PreviewBlock>{parsedResult.content[0].text}</PreviewBlock> return <PreviewBlock>{parsedResult.content[0].text}</PreviewBlock>
// TODO: support other types
default: default:
return <PreviewBlock>{content}</PreviewBlock> return <PreviewBlock>{content}</PreviewBlock>
} }
@ -173,8 +167,11 @@ const MessageTools: FC<Props> = ({ blocks }) => {
transitionName="animation-move-down" transitionName="animation-move-down"
styles={{ body: { maxHeight: '80vh', overflow: 'auto' } }}> styles={{ body: { maxHeight: '80vh', overflow: 'auto' } }}>
{expandedResponse && ( {expandedResponse && (
<ExpandedResponseContainer style={{ fontFamily, fontSize }}> <ExpandedResponseContainer
{/* mode swtich tabs */} style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
fontSize
}}>
<Tabs <Tabs
tabBarExtraContent={ tabBarExtraContent={
<ActionButton <ActionButton
@ -200,7 +197,16 @@ const MessageTools: FC<Props> = ({ blocks }) => {
{ {
key: 'raw', key: 'raw',
label: t('message.tools.raw'), label: t('message.tools.raw'),
children: <div className="markdown" dangerouslySetInnerHTML={{ __html: styledResult }} /> children: (
<CollapsedContent
isExpanded={true}
resultString={
typeof expandedResponse.content === 'string'
? expandedResponse.content
: JSON.stringify(expandedResponse.content, null, 2)
}
/>
)
} }
]} ]}
/> />
@ -211,6 +217,19 @@ const MessageTools: FC<Props> = ({ blocks }) => {
) )
} }
// New component to handle collapsed content
const CollapsedContent: FC<{ isExpanded: boolean; resultString: string }> = ({ isExpanded, resultString }) => {
const { renderedMarkdown: styledResult } = useShikiWithMarkdownIt(
isExpanded ? `\`\`\`json\n${resultString}\n\`\`\`` : ''
)
if (!isExpanded) {
return null
}
return <div className="markdown" dangerouslySetInnerHTML={{ __html: styledResult }} />
}
const CollapseContainer = styled(Collapse)` const CollapseContainer = styled(Collapse)`
margin-top: 10px; margin-top: 10px;
margin-bottom: 12px; margin-bottom: 12px;

View File

@ -1,5 +1,6 @@
import { Navbar, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar' import { Navbar, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar'
import { HStack } from '@renderer/components/Layout' import { HStack } from '@renderer/components/Layout'
import FloatingSidebar from '@renderer/components/Popups/FloatingSidebar'
import MinAppsPopover from '@renderer/components/Popups/MinAppsPopover' import MinAppsPopover from '@renderer/components/Popups/MinAppsPopover'
import SearchPopup from '@renderer/components/Popups/SearchPopup' import SearchPopup from '@renderer/components/Popups/SearchPopup'
import { isMac } from '@renderer/config/constant' import { isMac } from '@renderer/config/constant'
@ -15,7 +16,7 @@ import { Assistant, Topic } from '@renderer/types'
import { Tooltip } from 'antd' import { Tooltip } from 'antd'
import { t } from 'i18next' import { t } from 'i18next'
import { LayoutGrid, MessageSquareDiff, PanelLeftClose, PanelRightClose, Search } from 'lucide-react' import { LayoutGrid, MessageSquareDiff, PanelLeftClose, PanelRightClose, Search } from 'lucide-react'
import { FC } from 'react' import { FC, useCallback, useState } from 'react'
import styled from 'styled-components' import styled from 'styled-components'
import SelectModelButton from './components/SelectModelButton' import SelectModelButton from './components/SelectModelButton'
@ -25,18 +26,47 @@ interface Props {
activeAssistant: Assistant activeAssistant: Assistant
activeTopic: Topic activeTopic: Topic
setActiveTopic: (topic: Topic) => void setActiveTopic: (topic: Topic) => void
setActiveAssistant: (assistant: Assistant) => void
position: 'left' | 'right'
} }
const HeaderNavbar: FC<Props> = ({ activeAssistant }) => { const HeaderNavbar: FC<Props> = ({ activeAssistant, setActiveAssistant, activeTopic, setActiveTopic }) => {
const { assistant } = useAssistant(activeAssistant.id) const { assistant } = useAssistant(activeAssistant.id)
const { showAssistants, toggleShowAssistants } = useShowAssistants() const { showAssistants, toggleShowAssistants } = useShowAssistants()
const { topicPosition, sidebarIcons, narrowMode } = useSettings() const { topicPosition, sidebarIcons, narrowMode } = useSettings()
const { showTopics, toggleShowTopics } = useShowTopics() const { showTopics, toggleShowTopics } = useShowTopics()
const dispatch = useAppDispatch() const dispatch = useAppDispatch()
const [sidebarHideCooldown, setSidebarHideCooldown] = useState(false)
useShortcut('toggle_show_assistants', () => { // Function to toggle assistants with cooldown
toggleShowAssistants() const handleToggleShowAssistants = useCallback(() => {
}) if (showAssistants) {
// When hiding sidebar, set cooldown
toggleShowAssistants()
setSidebarHideCooldown(true)
// setTimeout(() => {
// setSidebarHideCooldown(false)
// }, 10000) // 10 seconds cooldown
} else {
// When showing sidebar, no cooldown needed
toggleShowAssistants()
}
}, [showAssistants, toggleShowAssistants])
const handleToggleShowTopics = useCallback(() => {
if (showTopics) {
// When hiding sidebar, set cooldown
toggleShowTopics()
setSidebarHideCooldown(true)
// setTimeout(() => {
// setSidebarHideCooldown(false)
// }, 10000) // 10 seconds cooldown
} else {
// When showing sidebar, no cooldown needed
toggleShowTopics()
}
}, [showTopics, toggleShowTopics])
useShortcut('toggle_show_assistants', handleToggleShowAssistants)
useShortcut('toggle_show_topics', () => { useShortcut('toggle_show_topics', () => {
if (topicPosition === 'right') { if (topicPosition === 'right') {
@ -60,7 +90,7 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant }) => {
{showAssistants && ( {showAssistants && (
<NavbarLeft style={{ justifyContent: 'space-between', borderRight: 'none', padding: 0 }}> <NavbarLeft style={{ justifyContent: 'space-between', borderRight: 'none', padding: 0 }}>
<Tooltip title={t('navbar.hide_sidebar')} mouseEnterDelay={0.8}> <Tooltip title={t('navbar.hide_sidebar')} mouseEnterDelay={0.8}>
<NavbarIcon onClick={toggleShowAssistants} style={{ marginLeft: isMac ? 16 : 0 }}> <NavbarIcon onClick={handleToggleShowAssistants} style={{ marginLeft: isMac ? 16 : 0 }}>
<PanelLeftClose size={18} /> <PanelLeftClose size={18} />
</NavbarIcon> </NavbarIcon>
</Tooltip> </Tooltip>
@ -73,11 +103,28 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant }) => {
)} )}
<NavbarRight style={{ justifyContent: 'space-between', flex: 1 }} className="home-navbar-right"> <NavbarRight style={{ justifyContent: 'space-between', flex: 1 }} className="home-navbar-right">
<HStack alignItems="center"> <HStack alignItems="center">
{!showAssistants && ( {!showAssistants && !sidebarHideCooldown && (
<FloatingSidebar
activeAssistant={assistant}
setActiveAssistant={setActiveAssistant}
activeTopic={activeTopic}
setActiveTopic={setActiveTopic}
position={'left'}>
<Tooltip title={t('navbar.show_sidebar')} mouseEnterDelay={2}>
<NavbarIcon
onClick={() => toggleShowAssistants()}
style={{ marginRight: 8, marginLeft: isMac ? 4 : -12 }}>
<PanelRightClose size={18} />
</NavbarIcon>
</Tooltip>
</FloatingSidebar>
)}
{!showAssistants && sidebarHideCooldown && (
<Tooltip title={t('navbar.show_sidebar')} mouseEnterDelay={0.8}> <Tooltip title={t('navbar.show_sidebar')} mouseEnterDelay={0.8}>
<NavbarIcon <NavbarIcon
onClick={() => toggleShowAssistants()} onClick={() => toggleShowAssistants()}
style={{ marginRight: 8, marginLeft: isMac ? 4 : -12 }}> style={{ marginRight: 8, marginLeft: isMac ? 4 : -12 }}
onMouseOut={() => setSidebarHideCooldown(false)}>
<PanelRightClose size={18} /> <PanelRightClose size={18} />
</NavbarIcon> </NavbarIcon>
</Tooltip> </Tooltip>
@ -105,10 +152,33 @@ const HeaderNavbar: FC<Props> = ({ activeAssistant }) => {
</Tooltip> </Tooltip>
</MinAppsPopover> </MinAppsPopover>
)} )}
{topicPosition === 'right' && ( {topicPosition === 'right' && !showTopics && !sidebarHideCooldown && (
<NarrowIcon onClick={toggleShowTopics}> <FloatingSidebar
{showTopics ? <PanelRightClose size={18} /> : <PanelLeftClose size={18} />} activeAssistant={assistant}
</NarrowIcon> setActiveAssistant={setActiveAssistant}
activeTopic={activeTopic}
setActiveTopic={setActiveTopic}
position={'right'}>
<Tooltip title={t('navbar.show_sidebar')} mouseEnterDelay={2}>
<NavbarIcon onClick={() => toggleShowTopics()}>
<PanelLeftClose size={18} />
</NavbarIcon>
</Tooltip>
</FloatingSidebar>
)}
{topicPosition === 'right' && !showTopics && sidebarHideCooldown && (
<Tooltip title={t('navbar.show_sidebar')} mouseEnterDelay={2}>
<NavbarIcon onClick={() => toggleShowTopics()} onMouseOut={() => setSidebarHideCooldown(false)}>
<PanelLeftClose size={18} />
</NavbarIcon>
</Tooltip>
)}
{topicPosition === 'right' && showTopics && (
<Tooltip title={t('navbar.hide_sidebar')} mouseEnterDelay={2}>
<NavbarIcon onClick={() => handleToggleShowTopics()}>
<PanelRightClose size={18} />
</NavbarIcon>
</Tooltip>
)} )}
</HStack> </HStack>
</NavbarRight> </NavbarRight>

View File

@ -20,18 +20,26 @@ interface Props {
setActiveAssistant: (assistant: Assistant) => void setActiveAssistant: (assistant: Assistant) => void
setActiveTopic: (topic: Topic) => void setActiveTopic: (topic: Topic) => void
position: 'left' | 'right' position: 'left' | 'right'
forceToSeeAllTab?: boolean
} }
type Tab = 'assistants' | 'topic' | 'settings' type Tab = 'assistants' | 'topic' | 'settings'
let _tab: any = '' let _tab: any = ''
const HomeTabs: FC<Props> = ({ activeAssistant, activeTopic, setActiveAssistant, setActiveTopic, position }) => { const HomeTabs: FC<Props> = ({
activeAssistant,
activeTopic,
setActiveAssistant,
setActiveTopic,
position,
forceToSeeAllTab
}) => {
const { addAssistant } = useAssistants() const { addAssistant } = useAssistants()
const [tab, setTab] = useState<Tab>(position === 'left' ? _tab || 'assistants' : 'topic') const [tab, setTab] = useState<Tab>(position === 'left' ? _tab || 'assistants' : 'topic')
const { topicPosition } = useSettings() const { topicPosition } = useSettings()
const { defaultAssistant } = useDefaultAssistant() const { defaultAssistant } = useDefaultAssistant()
const { toggleShowTopics } = useShowTopics() const { showTopics, toggleShowTopics } = useShowTopics()
const { t } = useTranslation() const { t } = useTranslation()
@ -86,20 +94,22 @@ const HomeTabs: FC<Props> = ({ activeAssistant, activeTopic, setActiveAssistant,
if (position === 'right' && topicPosition === 'right' && tab === 'assistants') { if (position === 'right' && topicPosition === 'right' && tab === 'assistants') {
setTab('topic') setTab('topic')
} }
if (position === 'left' && topicPosition === 'right' && tab !== 'assistants') { if (position === 'left' && topicPosition === 'right' && forceToSeeAllTab != true && tab !== 'assistants') {
setTab('assistants') setTab('assistants')
} }
}, [position, tab, topicPosition]) }, [position, tab, topicPosition, forceToSeeAllTab])
return ( return (
<Container style={border} className="home-tabs"> <Container style={border} className="home-tabs">
{showTab && ( {(showTab || (forceToSeeAllTab == true && !showTopics)) && (
<Segmented <Segmented
value={tab} value={tab}
style={{ borderRadius: 16, paddingTop: 10, margin: '0 10px', gap: 2 }} style={{ borderRadius: 16, paddingTop: 10, margin: '0 10px', gap: 2 }}
options={ options={
[ [
position === 'left' && topicPosition === 'left' ? assistantTab : undefined, (position === 'left' && topicPosition === 'left') || (forceToSeeAllTab == true && position === 'left')
? assistantTab
: undefined,
{ {
label: t('common.topics'), label: t('common.topics'),
value: 'topic' value: 'topic'
@ -137,7 +147,6 @@ const Container = styled.div`
flex-direction: column; flex-direction: column;
max-width: var(--assistants-width); max-width: var(--assistants-width);
min-width: var(--assistants-width); min-width: var(--assistants-width);
height: calc(100vh - var(--navbar-height));
background-color: var(--color-background); background-color: var(--color-background);
overflow: hidden; overflow: hidden;
.collapsed { .collapsed {

View File

@ -1,8 +1,7 @@
import { CopyOutlined } from '@ant-design/icons' import { CopyOutlined } from '@ant-design/icons'
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { TopView } from '@renderer/components/TopView' import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant' import { searchKnowledgeBase } from '@renderer/services/KnowledgeService'
import { getFileFromUrl, getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
import { FileType, KnowledgeBase } from '@renderer/types' import { FileType, KnowledgeBase } from '@renderer/types'
import { Input, List, message, Modal, Spin, Tooltip, Typography } from 'antd' import { Input, List, message, Modal, Spin, Tooltip, Typography } from 'antd'
import { useRef, useState } from 'react' import { useRef, useState } from 'react'
@ -38,29 +37,8 @@ const PopupContainer: React.FC<Props> = ({ base, resolve }) => {
setSearchKeyword(value.trim()) setSearchKeyword(value.trim())
setLoading(true) setLoading(true)
try { try {
const searchResults = await window.api.knowledgeBase.search({ const searchResults = await searchKnowledgeBase(value, base)
search: value, setResults(searchResults)
base: getKnowledgeBaseParams(base)
})
let rerankResult = searchResults
if (base.rerankModel) {
rerankResult = await window.api.knowledgeBase.rerank({
search: value,
base: getKnowledgeBaseParams(base),
results: searchResults
})
}
const results = await Promise.all(
rerankResult.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source)
return { ...item, file }
})
)
const filteredResults = results.filter((item) => {
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
return item.score >= threshold
})
setResults(filteredResults)
} catch (error) { } catch (error) {
console.error('Search failed:', error) console.error('Search failed:', error)
} finally { } finally {

View File

@ -29,7 +29,6 @@ interface FormData {
chunkOverlap?: number chunkOverlap?: number
threshold?: number threshold?: number
rerankModel?: string rerankModel?: string
topN?: number
} }
interface Props extends ShowParams { interface Props extends ShowParams {
@ -95,8 +94,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
threshold: values.threshold ?? undefined, threshold: values.threshold ?? undefined,
rerankModel: values.rerankModel rerankModel: values.rerankModel
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel) ? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
: undefined, : undefined
topN: values.topN
} }
updateKnowledgeBase(newBase) updateKnowledgeBase(newBase)
setOpen(false) setOpen(false)
@ -283,23 +281,6 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} /> <InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
</Form.Item> </Form.Item>
<Form.Item
name="topN"
label={t('knowledge.topN')}
layout="horizontal"
initialValue={base.topN}
rules={[
{
validator(_, value) {
if (value && (value < 0 || value > 30)) {
return Promise.reject(new Error(t('knowledge.topN_too_large_or_small')))
}
return Promise.resolve()
}
}
]}>
<InputNumber placeholder={t('knowledge.topN_placeholder')} style={{ width: '100%' }} />
</Form.Item>
<Alert <Alert
message={t('knowledge.chunk_size_change_warning')} message={t('knowledge.chunk_size_change_warning')}
type="warning" type="warning"

View File

@ -1,4 +1,5 @@
import { useTheme } from '@renderer/context/ThemeProvider' import { useTheme } from '@renderer/context/ThemeProvider'
import { runAsyncFunction } from '@renderer/utils'
import { getShikiInstance } from '@renderer/utils/shiki' import { getShikiInstance } from '@renderer/utils/shiki'
import { Card } from 'antd' import { Card } from 'antd'
import MarkdownIt from 'markdown-it' import MarkdownIt from 'markdown-it'
@ -30,9 +31,11 @@ const MCPDescription = ({ searchKey }: McpDescriptionProps) => {
}, [md, searchKey]) }, [md, searchKey])
useEffect(() => { useEffect(() => {
const sk = getShikiInstance(theme) runAsyncFunction(async () => {
md.current.use(sk) const sk = await getShikiInstance(theme)
getMcpInfo() md.current.use(sk)
getMcpInfo()
})
}, [getMcpInfo, theme]) }, [getMcpInfo, theme])
return ( return (

View File

@ -1136,13 +1136,16 @@ export default class OpenAIProvider extends BaseOpenAIProvider {
return { valid: false, error: new Error('No model found') } return { valid: false, error: new Error('No model found') }
} }
const body = { const body: any = {
model: model.id, model: model.id,
messages: [{ role: 'user', content: 'hi' }], messages: [{ role: 'user', content: 'hi' }],
enable_thinking: false, // qwen3
stream stream
} }
if (this.provider.id !== 'github') {
body.enable_thinking = false; // qwen3
}
try { try {
await this.checkIsCopilot() await this.checkIsCopilot()
if (!stream) { if (!stream) {

View File

@ -18,6 +18,7 @@ export const EVENT_NAMES = {
SHOW_CHAT_SETTINGS: 'SHOW_CHAT_SETTINGS', SHOW_CHAT_SETTINGS: 'SHOW_CHAT_SETTINGS',
SHOW_TOPIC_SIDEBAR: 'SHOW_TOPIC_SIDEBAR', SHOW_TOPIC_SIDEBAR: 'SHOW_TOPIC_SIDEBAR',
SWITCH_TOPIC_SIDEBAR: 'SWITCH_TOPIC_SIDEBAR', SWITCH_TOPIC_SIDEBAR: 'SWITCH_TOPIC_SIDEBAR',
SWITCH_ASSISTANT: 'SWITCH_ASSISTANT',
NEW_CONTEXT: 'NEW_CONTEXT', NEW_CONTEXT: 'NEW_CONTEXT',
NEW_BRANCH: 'NEW_BRANCH', NEW_BRANCH: 'NEW_BRANCH',
COPY_TOPIC_IMAGE: 'COPY_TOPIC_IMAGE', COPY_TOPIC_IMAGE: 'COPY_TOPIC_IMAGE',

View File

@ -47,8 +47,8 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
rerankBaseURL: rerankHost, rerankBaseURL: rerankHost,
rerankApiKey: rerankAiProvider.getApiKey() || 'secret', rerankApiKey: rerankAiProvider.getApiKey() || 'secret',
rerankModel: base.rerankModel?.id, rerankModel: base.rerankModel?.id,
rerankModelProvider: base.rerankModel?.provider, rerankModelProvider: base.rerankModel?.provider
topN: base.topN // topN: base.topN
} }
} }
@ -88,6 +88,51 @@ export const getKnowledgeSourceUrl = async (item: ExtractChunkData & { file: Fil
return item.metadata.source return item.metadata.source
} }
export const searchKnowledgeBase = async (
query: string,
base: KnowledgeBase,
rewrite?: string
): Promise<Array<ExtractChunkData & { file: FileType | null }>> => {
try {
const baseParams = getKnowledgeBaseParams(base)
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
// 执行搜索
const searchResults = await window.api.knowledgeBase.search({
search: query,
base: baseParams
})
// 过滤阈值不达标的结果
const filteredResults = searchResults.filter((item) => item.score >= threshold)
// 如果有rerank模型执行重排
let rerankResults = filteredResults
if (base.rerankModel && filteredResults.length > 0) {
rerankResults = await window.api.knowledgeBase.rerank({
search: rewrite || query,
base: baseParams,
results: filteredResults
})
}
// 限制文档数量
const limitedResults = rerankResults.slice(0, documentCount)
// 处理文件信息
return await Promise.all(
limitedResults.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source)
return { ...item, file }
})
)
} catch (error) {
Logger.error(`Error searching knowledge base ${base.name}:`, error)
return []
}
}
export const processKnowledgeSearch = async ( export const processKnowledgeSearch = async (
extractResults: ExtractResults, extractResults: ExtractResults,
knowledgeBaseIds: string[] | undefined knowledgeBaseIds: string[] | undefined
@ -100,6 +145,7 @@ export const processKnowledgeSearch = async (
Logger.log('No valid question found in extractResults.knowledge') Logger.log('No valid question found in extractResults.knowledge')
return [] return []
} }
const questions = extractResults.knowledge.question const questions = extractResults.knowledge.question
const rewrite = extractResults.knowledge.rewrite const rewrite = extractResults.knowledge.rewrite
@ -109,73 +155,35 @@ export const processKnowledgeSearch = async (
return [] return []
} }
const referencesPromises = bases.map(async (base) => { // 为每个知识库执行多问题搜索
try { const baseSearchPromises = bases.map(async (base) => {
const baseParams = getKnowledgeBaseParams(base) // 为每个问题搜索并合并结果
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT const allResults = await Promise.all(questions.map((question) => searchKnowledgeBase(question, base, rewrite)))
const allSearchResultsPromises = questions.map((question) => // 合并结果并去重
window.api.knowledgeBase const flatResults = allResults.flat()
.search({ const uniqueResults = Array.from(
search: question, new Map(flatResults.map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
base: baseParams ).sort((a, b) => b.score - a.score)
})
.then((results) =>
results.filter((item) => {
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
return item.score >= threshold
})
)
)
const allSearchResults = await Promise.all(allSearchResultsPromises) // 转换为引用格式
return await Promise.all(
const searchResults = Array.from( uniqueResults.map(
new Map(allSearchResults.flat().map((item) => [item.metadata.uniqueId || item.pageContent, item])).values() async (item, index) =>
).sort((a, b) => b.score - a.score) ({
id: index + 1,
Logger.log(`Knowledge base ${base.name} search results:`, searchResults)
let rerankResults = searchResults
if (base.rerankModel && searchResults.length > 0) {
rerankResults = await window.api.knowledgeBase.rerank({
search: rewrite,
base: baseParams,
results: searchResults
})
}
if (rerankResults.length > 0) {
rerankResults = rerankResults.slice(0, documentCount)
}
const processdResults = await Promise.all(
rerankResults.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source)
return { ...item, file }
})
)
return await Promise.all(
processdResults.map(async (item, index) => {
// const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
return {
id: index + 1, // 搜索多个库会导致ID重复
content: item.pageContent, content: item.pageContent,
sourceUrl: await getKnowledgeSourceUrl(item), sourceUrl: await getKnowledgeSourceUrl(item),
type: 'file' // 需要映射 baseItem.type是'localPathLoader' -> 'file' type: 'file'
} as KnowledgeReference }) as KnowledgeReference
})
) )
} catch (error) { )
Logger.error(`Error searching knowledge base ${base.name}:`, error)
return []
}
}) })
const resultsPerBase = await Promise.all(referencesPromises) // 汇总所有知识库的结果
const resultsPerBase = await Promise.all(baseSearchPromises)
const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref) const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref)
// 重新为引用分配ID // 重新为引用分配ID
return allReferencesRaw.map((ref, index) => ({ return allReferencesRaw.map((ref, index) => ({
...ref, ...ref,

View File

@ -971,22 +971,7 @@ export const resendMessageThunk =
* of its associated assistant responses using resendMessageThunk. * of its associated assistant responses using resendMessageThunk.
*/ */
export const resendUserMessageWithEditThunk = export const resendUserMessageWithEditThunk =
( (topicId: Topic['id'], originalMessage: Message, assistant: Assistant) => async (dispatch: AppDispatch) => {
topicId: Topic['id'],
originalMessage: Message,
mainTextBlockId: string,
editedContent: string,
assistant: Assistant
) =>
async (dispatch: AppDispatch) => {
const blockChanges = {
content: editedContent,
updatedAt: new Date().toISOString()
}
// Update block in Redux and DB
dispatch(updateOneBlock({ id: mainTextBlockId, changes: blockChanges }))
await db.message_blocks.update(mainTextBlockId, blockChanges)
// Trigger the regeneration logic for associated assistant messages // Trigger the regeneration logic for associated assistant messages
dispatch(resendMessageThunk(topicId, originalMessage, assistant)) dispatch(resendMessageThunk(topicId, originalMessage, assistant))
} }
@ -1411,14 +1396,14 @@ export const updateMessageAndBlocksThunk =
topicId: string, topicId: string,
// Allow messageUpdates to be optional or just contain the ID if only blocks are updated // Allow messageUpdates to be optional or just contain the ID if only blocks are updated
messageUpdates: (Partial<Message> & Pick<Message, 'id'>) | null, // ID is always required for context messageUpdates: (Partial<Message> & Pick<Message, 'id'>) | null, // ID is always required for context
blockUpdatesList: Partial<MessageBlock>[] // Block updates remain required for this thunk's purpose blockUpdatesList: MessageBlock[] // Block updates remain required for this thunk's purpose
) => ) =>
async (dispatch: AppDispatch): Promise<boolean> => { async (dispatch: AppDispatch): Promise<void> => {
const messageId = messageUpdates?.id const messageId = messageUpdates?.id
if (messageUpdates && !messageId) { if (messageUpdates && !messageId) {
console.error('[updateMessageAndBlocksThunk] Message ID is required.') console.error('[updateMessageAndUpdateBlocksThunk] Message ID is required.')
return false return
} }
try { try {
@ -1434,14 +1419,7 @@ export const updateMessageAndBlocksThunk =
} }
if (blockUpdatesList.length > 0) { if (blockUpdatesList.length > 0) {
blockUpdatesList.forEach((blockUpdate) => { dispatch(upsertManyBlocks(blockUpdatesList))
const { id: blockId, ...blockChanges } = blockUpdate
if (blockId && Object.keys(blockChanges).length > 0) {
dispatch(updateOneBlock({ id: blockId, changes: blockChanges }))
} else if (!blockId) {
console.warn('[updateMessageAndBlocksThunk] Skipping block update due to missing block ID:', blockUpdate)
}
})
} }
// 2. 更新数据库 (在事务中) // 2. 更新数据库 (在事务中)
@ -1468,27 +1446,57 @@ export const updateMessageAndBlocksThunk =
} }
} }
// Always process block updates if the list is provided and not empty
if (blockUpdatesList.length > 0) { if (blockUpdatesList.length > 0) {
const validBlockUpdatesForDb = blockUpdatesList await db.message_blocks.bulkPut(blockUpdatesList)
.map((bu) => { }
const { id, ...changes } = bu })
if (id && Object.keys(changes).length > 0) { } catch (error) {
return { key: id, changes: changes } console.error(`[updateMessageAndBlocksThunk] Failed to process updates for message ${messageId}:`, error)
} }
return null }
})
.filter((bu) => bu !== null) as { key: string; changes: Partial<MessageBlock> }[]
if (validBlockUpdatesForDb.length > 0) { export const removeBlocksThunk =
await db.message_blocks.bulkUpdate(validBlockUpdatesForDb) (topicId: string, messageId: string, blockIdsToRemove: string[]) =>
} async (dispatch: AppDispatch, getState: () => RootState): Promise<void> => {
if (!blockIdsToRemove.length) {
console.warn('[removeBlocksFromMessageThunk] No block IDs provided to remove.')
return
}
try {
const state = getState()
const message = state.messages.entities[messageId]
if (!message) {
console.error(`[removeBlocksFromMessageThunk] Message ${messageId} not found in state.`)
return
}
const blockIdsToRemoveSet = new Set(blockIdsToRemove)
const updatedBlockIds = (message.blocks || []).filter((id) => !blockIdsToRemoveSet.has(id))
// 1. Update Redux state
dispatch(newMessagesActions.updateMessage({ topicId, messageId, updates: { blocks: updatedBlockIds } }))
if (blockIdsToRemove.length > 0) {
dispatch(removeManyBlocks(blockIdsToRemove))
}
const finalMessagesToSave = selectMessagesForTopic(getState(), topicId)
// 2. Update database (in a transaction)
await db.transaction('rw', db.topics, db.message_blocks, async () => {
// Update the message in the topic
await db.topics.update(topicId, { messages: finalMessagesToSave })
// Delete the blocks from the database
if (blockIdsToRemove.length > 0) {
await db.message_blocks.bulkDelete(blockIdsToRemove)
} }
}) })
return true return
} catch (error) { } catch (error) {
console.error(`[updateMessageAndBlocksThunk] Failed to process updates for message ${messageId}:`, error) console.error(`[removeBlocksFromMessageThunk] Failed to remove blocks from message ${messageId}:`, error)
return false throw error
} }
} }

View File

@ -372,7 +372,7 @@ export interface KnowledgeBase {
chunkOverlap?: number chunkOverlap?: number
threshold?: number threshold?: number
rerankModel?: Model rerankModel?: Model
topN?: number // topN?: number
} }
export type KnowledgeBaseParams = { export type KnowledgeBaseParams = {
@ -388,7 +388,7 @@ export type KnowledgeBaseParams = {
rerankBaseURL?: string rerankBaseURL?: string
rerankModel?: string rerankModel?: string
rerankModelProvider?: string rerankModelProvider?: string
topN?: number documentCount?: number
} }
export type GenerateImageParams = { export type GenerateImageParams = {

View File

@ -0,0 +1,32 @@
import { bundledLanguages, bundledThemes, createHighlighter, type Highlighter } from 'shiki'
let highlighterPromise: Promise<Highlighter> | null = null
export async function getHighlighter() {
if (!highlighterPromise) {
highlighterPromise = createHighlighter({
langs: ['javascript', 'typescript', 'python', 'java', 'markdown', 'json'],
themes: ['one-light', 'material-theme-darker']
})
}
return await highlighterPromise
}
export async function loadLanguageIfNeeded(highlighter: Highlighter, language: string) {
if (!highlighter.getLoadedLanguages().includes(language)) {
const languageImportFn = bundledLanguages[language]
if (languageImportFn) {
await highlighter.loadLanguage(await languageImportFn())
}
}
}
export async function loadThemeIfNeeded(highlighter: Highlighter, theme: string) {
if (!highlighter.getLoadedThemes().includes(theme)) {
const themeImportFn = bundledThemes[theme]
if (themeImportFn) {
await highlighter.loadTheme(await themeImportFn())
}
}
}

View File

@ -1,16 +1,33 @@
import store from '@renderer/store' import store from '@renderer/store'
import { messageBlocksSelectors } from '@renderer/store/messageBlock' import { messageBlocksSelectors } from '@renderer/store/messageBlock'
import { FileType } from '@renderer/types'
import type { import type {
CitationMessageBlock, CitationMessageBlock,
FileMessageBlock, FileMessageBlock,
ImageMessageBlock, ImageMessageBlock,
MainTextMessageBlock, MainTextMessageBlock,
Message, Message,
MessageBlock,
ThinkingMessageBlock, ThinkingMessageBlock,
TranslationMessageBlock TranslationMessageBlock
} from '@renderer/types/newMessage' } from '@renderer/types/newMessage'
import { MessageBlockType } from '@renderer/types/newMessage' import { MessageBlockType } from '@renderer/types/newMessage'
export const findAllBlocks = (message: Message): MessageBlock[] => {
if (!message || !message.blocks || message.blocks.length === 0) {
return []
}
const state = store.getState()
const allBlocks: MessageBlock[] = []
for (const blockId of message.blocks) {
const block = messageBlocksSelectors.selectById(state, blockId)
if (block) {
allBlocks.push(block)
}
}
return allBlocks
}
/** /**
* Finds all MainTextMessageBlocks associated with a given message, in order. * Finds all MainTextMessageBlocks associated with a given message, in order.
* @param message - The message object. * @param message - The message object.
@ -122,6 +139,28 @@ export const getKnowledgeBaseIds = (message: Message): string[] | undefined => {
return firstTextBlock?.flatMap((block) => block.knowledgeBaseIds).filter((id): id is string => Boolean(id)) return firstTextBlock?.flatMap((block) => block.knowledgeBaseIds).filter((id): id is string => Boolean(id))
} }
/**
* Gets the file content from all FileMessageBlocks and ImageMessageBlocks of a message.
* @param message - The message object.
* @returns The file content or an empty string if no file blocks are found.
*/
export const getFileContent = (message: Message): FileType[] => {
const files: FileType[] = []
const fileBlocks = findFileBlocks(message)
for (const block of fileBlocks) {
if (block.file) {
files.push(block.file)
}
}
const imageBlocks = findImageBlocks(message)
for (const block of imageBlocks) {
if (block.file) {
files.push(block.file)
}
}
return files
}
/** /**
* Finds all CitationBlocks associated with a given message. * Finds all CitationBlocks associated with a given message.
* @param message - The message object. * @param message - The message object.

View File

@ -1,11 +1,13 @@
import { useTheme } from '@renderer/context/ThemeProvider' import { useTheme } from '@renderer/context/ThemeProvider'
import { ThemeMode } from '@renderer/types' import { ThemeMode } from '@renderer/types'
import { MarkdownItShikiOptions, setupMarkdownIt } from '@shikijs/markdown-it' import { setupMarkdownIt } from '@shikijs/markdown-it'
import MarkdownIt from 'markdown-it' import MarkdownIt from 'markdown-it'
import { useEffect, useRef, useState } from 'react' import { useEffect, useRef, useState } from 'react'
import { BuiltinLanguage, BuiltinTheme, bundledLanguages, createHighlighter } from 'shiki'
import { getTokenStyleObject, ThemedToken } from 'shiki/core' import { getTokenStyleObject, ThemedToken } from 'shiki/core'
import { runAsyncFunction } from '.'
import { getHighlighter } from './highlighter'
/** /**
* Shiki token React * Shiki token React
* *
@ -44,19 +46,9 @@ const defaultOptions = {
defaultColor: 'light' defaultColor: 'light'
} }
const initHighlighter = async (options: MarkdownItShikiOptions) => { export async function getShikiInstance(theme: ThemeMode) {
const themeNames = ('themes' in options ? Object.values(options.themes) : [options.theme]).filter( const highlighter = await getHighlighter()
Boolean
) as BuiltinTheme[]
return await createHighlighter({
themes: themeNames,
langs: options.langs || (Object.keys(bundledLanguages) as BuiltinLanguage[])
})
}
const highlighter = await initHighlighter(defaultOptions)
export function getShikiInstance(theme: ThemeMode) {
const options = { const options = {
...defaultOptions, ...defaultOptions,
defaultColor: theme defaultColor: theme
@ -77,9 +69,11 @@ export function useShikiWithMarkdownIt(content: string) {
) )
const { theme } = useTheme() const { theme } = useTheme()
useEffect(() => { useEffect(() => {
const sk = getShikiInstance(theme) runAsyncFunction(async () => {
md.current.use(sk) const sk = await getShikiInstance(theme)
setRenderedMarkdown(md.current.render(content)) md.current.use(sk)
setRenderedMarkdown(md.current.render(content))
})
}, [content, theme]) }, [content, theme])
return { return {
renderedMarkdown renderedMarkdown