Merge remote-tracking branch 'origin/main' into feat/cherry-store

This commit is contained in:
MyPrototypeWhat 2025-08-01 15:54:12 +08:00
commit 9f944ff42c
105 changed files with 3974 additions and 15668 deletions

View File

@ -5,8 +5,8 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
## Development Commands
### Environment Setup
- **Prerequisites**: Node.js v20.x.x, Yarn 4.6.0
- **Setup Yarn**: `corepack enable && corepack prepare yarn@4.6.0 --activate`
- **Prerequisites**: Node.js v22.x.x or higher, Yarn 4.9.1
- **Setup Yarn**: `corepack enable && corepack prepare yarn@4.9.1 --activate`
- **Install Dependencies**: `yarn install`
### Development
@ -61,7 +61,8 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
- **Loaders**: Support for various file formats (PDF, DOCX, EPUB, etc.)
### Build System
- **Electron-Vite**: Development and build tooling
- **Electron-Vite**: Development and build tooling (v4.0.0)
- **Rolldown-Vite**: Using experimental rolldown-vite instead of standard vite
- **Workspaces**: Monorepo structure with `packages/` directory
- **Multiple Entry Points**: Main app, mini window, selection toolbar
- **Styled Components**: CSS-in-JS styling with SWC optimization

Binary file not shown.

Before

Width:  |  Height:  |  Size: 38 KiB

After

Width:  |  Height:  |  Size: 40 KiB

View File

@ -50,11 +50,8 @@ files:
- '!node_modules/rollup-plugin-visualizer'
- '!node_modules/js-tiktoken'
- '!node_modules/@tavily/core/node_modules/js-tiktoken'
- '!node_modules/pdf-parse/lib/pdf.js/{v1.9.426,v1.10.88,v2.0.550}'
- '!node_modules/mammoth/{mammoth.browser.js,mammoth.browser.min.js}'
- '!node_modules/selection-hook/prebuilds/**/*' # we rebuild .node, don't use prebuilds
- '!node_modules/pdfjs-dist/web/**/*'
- '!node_modules/pdfjs-dist/legacy/**/*'
- '!node_modules/selection-hook/node_modules' # we don't need what in the node_modules dir
- '!node_modules/selection-hook/src' # we don't need source files
- '!**/*.{h,iobj,ipdb,tlog,recipe,vcxproj,vcxproj.filters,Makefile,*.Makefile}' # filter .node build files

View File

@ -26,13 +26,11 @@ export default defineConfig({
},
build: {
rollupOptions: {
external: ['@libsql/client', 'bufferutil', 'utf-8-validate', '@cherrystudio/mac-system-ocr'],
output: isProd
? {
manualChunks: undefined, // 彻底禁用代码分割 - 返回 null 强制单文件打包
inlineDynamicImports: true // 内联所有动态导入,这是关键配置
}
: undefined
external: ['@libsql/client', 'bufferutil', 'utf-8-validate'],
output: {
manualChunks: undefined, // 彻底禁用代码分割 - 返回 null 强制单文件打包
inlineDynamicImports: true // 内联所有动态导入,这是关键配置
}
},
sourcemap: isDev
},

View File

@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.5.4-rc.1",
"version": "1.5.4-rc.2",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@ -70,20 +70,15 @@
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky"
},
"dependencies": {
"@cherrystudio/pdf-to-img-napi": "^0.0.1",
"@libsql/client": "0.14.0",
"@libsql/win32-x64-msvc": "^0.4.7",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"express": "^5.1.0",
"graceful-fs": "^4.2.11",
"jsdom": "26.1.0",
"node-stream-zip": "^1.15.0",
"officeparser": "^4.2.0",
"os-proxy-config": "^1.1.2",
"pdfjs-dist": "4.10.38",
"selection-hook": "^1.0.8",
"swagger-jsdoc": "^6.2.8",
"swagger-ui-express": "^5.0.1",
"turndown": "7.2.0"
},
"devDependencies": {
@ -152,10 +147,7 @@
"@testing-library/user-event": "^14.6.1",
"@tryfabric/martian": "^1.2.4",
"@types/cli-progress": "^3",
"@types/content-type": "^1.1.9",
"@types/cors": "^2.8.19",
"@types/diff": "^7",
"@types/express": "^5",
"@types/fs-extra": "^11",
"@types/lodash": "^4.17.5",
"@types/markdown-it": "^14",
@ -166,8 +158,6 @@
"@types/react-dom": "^19.0.4",
"@types/react-infinite-scroll-component": "^5.0.0",
"@types/react-window": "^1",
"@types/swagger-jsdoc": "^6",
"@types/swagger-ui-express": "^4.1.8",
"@types/tinycolor2": "^1",
"@types/word-extractor": "^1",
"@uiw/codemirror-extensions-langs": "^4.23.14",
@ -240,6 +230,7 @@
"npx-scope-finder": "^1.2.0",
"openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
"p-queue": "^8.1.0",
"pdf-lib": "^1.17.1",
"playwright": "^1.52.0",
"prettier": "^3.5.3",
"prettier-plugin-sort-json": "^4.1.1",
@ -295,11 +286,7 @@
"zipread": "^1.3.3",
"zod": "^3.25.74"
},
"optionalDependencies": {
"@cherrystudio/mac-system-ocr": "^0.2.2"
},
"resolutions": {
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
"@langchain/openai@npm:>=0.1.0 <0.4.0": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",

View File

@ -34,6 +34,7 @@ export enum IpcChannel {
App_InstallUvBinary = 'app:install-uv-binary',
App_InstallBunBinary = 'app:install-bun-binary',
App_LogToMain = 'app:log-to-main',
App_SaveData = 'app:save-data',
App_MacIsProcessTrusted = 'app:mac-is-process-trusted',
App_MacRequestProcessTrust = 'app:mac-request-process-trust',
@ -273,11 +274,5 @@ export enum IpcChannel {
TRACE_SET_TITLE = 'trace:setTitle',
TRACE_ADD_END_MESSAGE = 'trace:addEndMessage',
TRACE_CLEAN_LOCAL_DATA = 'trace:cleanLocalData',
TRACE_ADD_STREAM_MESSAGE = 'trace:addStreamMessage',
// API Server
ApiServer_Start = 'api-server:start',
ApiServer_Stop = 'api-server:stop',
ApiServer_Restart = 'api-server:restart',
ApiServer_GetStatus = 'api-server:get-status',
ApiServer_GetConfig = 'api-server:get-config'
TRACE_ADD_STREAM_MESSAGE = 'trace:addStreamMessage'
}

File diff suppressed because one or more lines are too long

View File

@ -53,7 +53,7 @@ exports.default = async function (context) {
* @param {string} nodeModulesPath
*/
function removeMacOnlyPackages(nodeModulesPath) {
const macOnlyPackages = ['@cherrystudio/mac-system-ocr']
const macOnlyPackages = []
macOnlyPackages.forEach((packageName) => {
const packagePath = path.join(nodeModulesPath, packageName)

View File

@ -1,128 +0,0 @@
import { loggerService } from '@main/services/LoggerService'
import cors from 'cors'
import express from 'express'
import { v4 as uuidv4 } from 'uuid'
import { authMiddleware } from './middleware/auth'
import { errorHandler } from './middleware/error'
import { setupOpenAPIDocumentation } from './middleware/openapi'
import { chatRoutes } from './routes/chat'
import { mcpRoutes } from './routes/mcp'
import { modelsRoutes } from './routes/models'
const logger = loggerService.withContext('ApiServer')
const app = express()
// Global middleware
app.use((req, res, next) => {
const start = Date.now()
res.on('finish', () => {
const duration = Date.now() - start
logger.info(`${req.method} ${req.path} - ${res.statusCode} - ${duration}ms`)
})
next()
})
app.use((_req, res, next) => {
res.setHeader('X-Request-ID', uuidv4())
next()
})
app.use(
cors({
origin: '*',
allowedHeaders: ['Content-Type', 'Authorization'],
methods: ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS']
})
)
/**
* @swagger
* /health:
* get:
* summary: Health check endpoint
* description: Check server status (no authentication required)
* tags: [Health]
* security: []
* responses:
* 200:
* description: Server is healthy
* content:
* application/json:
* schema:
* type: object
* properties:
* status:
* type: string
* example: ok
* timestamp:
* type: string
* format: date-time
* version:
* type: string
* example: 1.0.0
*/
app.get('/health', (_req, res) => {
res.json({
status: 'ok',
timestamp: new Date().toISOString(),
version: process.env.npm_package_version || '1.0.0'
})
})
/**
* @swagger
* /:
* get:
* summary: API information
* description: Get basic API information and available endpoints
* tags: [General]
* security: []
* responses:
* 200:
* description: API information
* content:
* application/json:
* schema:
* type: object
* properties:
* name:
* type: string
* example: Cherry Studio API
* version:
* type: string
* example: 1.0.0
* endpoints:
* type: object
*/
app.get('/', (_req, res) => {
res.json({
name: 'Cherry Studio API',
version: '1.0.0',
endpoints: {
health: 'GET /health',
models: 'GET /v1/models',
chat: 'POST /v1/chat/completions',
mcp: 'GET /v1/mcps'
}
})
})
// API v1 routes with auth
const apiRouter = express.Router()
apiRouter.use(authMiddleware)
apiRouter.use(express.json())
// Mount routes
apiRouter.use('/chat', chatRoutes)
apiRouter.use('/mcps', mcpRoutes)
apiRouter.use('/models', modelsRoutes)
app.use('/v1', apiRouter)
// Setup OpenAPI documentation
setupOpenAPIDocumentation(app)
// Error handling (must be last)
app.use(errorHandler)
export { app }

View File

@ -1,67 +0,0 @@
import { ApiServerConfig } from '@types'
import { v4 as uuidv4 } from 'uuid'
import { loggerService } from '../services/LoggerService'
import { reduxService } from '../services/ReduxService'
const logger = loggerService.withContext('ApiServerConfig')
class ConfigManager {
private _config: ApiServerConfig | null = null
async load(): Promise<ApiServerConfig> {
try {
const settings = await reduxService.select('state.settings')
// Auto-generate API key if not set
if (!settings?.apiServer?.apiKey) {
const generatedKey = `cs-sk-${uuidv4()}`
await reduxService.dispatch({
type: 'settings/setApiServerApiKey',
payload: generatedKey
})
this._config = {
enabled: settings?.apiServer?.enabled ?? false,
port: settings?.apiServer?.port ?? 23333,
host: 'localhost',
apiKey: generatedKey
}
} else {
this._config = {
enabled: settings?.apiServer?.enabled ?? false,
port: settings?.apiServer?.port ?? 23333,
host: 'localhost',
apiKey: settings.apiServer.apiKey
}
}
return this._config
} catch (error: any) {
logger.warn('Failed to load config from Redux, using defaults:', error)
this._config = {
enabled: false,
port: 23333,
host: 'localhost',
apiKey: `cs-sk-${uuidv4()}`
}
return this._config
}
}
async get(): Promise<ApiServerConfig> {
if (!this._config) {
await this.load()
}
if (!this._config) {
throw new Error('Failed to load API server configuration')
}
return this._config
}
async reload(): Promise<ApiServerConfig> {
return await this.load()
}
}
export const config = new ConfigManager()

View File

@ -1,2 +0,0 @@
export { config } from './config'
export { apiServer } from './server'

View File

@ -1,25 +0,0 @@
import { NextFunction, Request, Response } from 'express'
import { config } from '../config'
export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => {
const auth = req.header('Authorization')
if (!auth || !auth.startsWith('Bearer ')) {
return res.status(401).json({ error: 'Unauthorized' })
}
const token = auth.slice(7) // Remove 'Bearer ' prefix
if (!token) {
return res.status(401).json({ error: 'Unauthorized, Bearer token is empty' })
}
const { apiKey } = await config.get()
if (token !== apiKey) {
return res.status(403).json({ error: 'Forbidden' })
}
return next()
}

View File

@ -1,21 +0,0 @@
import { NextFunction, Request, Response } from 'express'
import { loggerService } from '../../services/LoggerService'
const logger = loggerService.withContext('ApiServerErrorHandler')
// eslint-disable-next-line @typescript-eslint/no-unused-vars
export const errorHandler = (err: Error, _req: Request, res: Response, _next: NextFunction) => {
logger.error('API Server Error:', err)
// Don't expose internal errors in production
const isDev = process.env.NODE_ENV === 'development'
res.status(500).json({
error: {
message: isDev ? err.message : 'Internal server error',
type: 'server_error',
...(isDev && { stack: err.stack })
}
})
}

View File

@ -1,206 +0,0 @@
import { Express } from 'express'
import swaggerJSDoc from 'swagger-jsdoc'
import swaggerUi from 'swagger-ui-express'
import { loggerService } from '../../services/LoggerService'
const logger = loggerService.withContext('OpenAPIMiddleware')
const swaggerOptions: swaggerJSDoc.Options = {
definition: {
openapi: '3.0.0',
info: {
title: 'Cherry Studio API',
version: '1.0.0',
description: 'OpenAI-compatible API for Cherry Studio with additional Cherry-specific endpoints',
contact: {
name: 'Cherry Studio',
url: 'https://github.com/CherryHQ/cherry-studio'
}
},
servers: [
{
url: 'http://localhost:23333',
description: 'Local development server'
}
],
components: {
securitySchemes: {
BearerAuth: {
type: 'http',
scheme: 'bearer',
bearerFormat: 'JWT',
description: 'Use the API key from Cherry Studio settings'
}
},
schemas: {
Error: {
type: 'object',
properties: {
error: {
type: 'object',
properties: {
message: { type: 'string' },
type: { type: 'string' },
code: { type: 'string' }
}
}
}
},
ChatMessage: {
type: 'object',
properties: {
role: {
type: 'string',
enum: ['system', 'user', 'assistant', 'tool']
},
content: {
oneOf: [
{ type: 'string' },
{
type: 'array',
items: {
type: 'object',
properties: {
type: { type: 'string' },
text: { type: 'string' },
image_url: {
type: 'object',
properties: {
url: { type: 'string' }
}
}
}
}
}
]
},
name: { type: 'string' },
tool_calls: {
type: 'array',
items: {
type: 'object',
properties: {
id: { type: 'string' },
type: { type: 'string' },
function: {
type: 'object',
properties: {
name: { type: 'string' },
arguments: { type: 'string' }
}
}
}
}
}
}
},
ChatCompletionRequest: {
type: 'object',
required: ['model', 'messages'],
properties: {
model: {
type: 'string',
description: 'The model to use for completion, in format provider:model-id'
},
messages: {
type: 'array',
items: { $ref: '#/components/schemas/ChatMessage' }
},
temperature: {
type: 'number',
minimum: 0,
maximum: 2,
default: 1
},
max_tokens: {
type: 'integer',
minimum: 1
},
stream: {
type: 'boolean',
default: false
},
tools: {
type: 'array',
items: {
type: 'object',
properties: {
type: { type: 'string' },
function: {
type: 'object',
properties: {
name: { type: 'string' },
description: { type: 'string' },
parameters: { type: 'object' }
}
}
}
}
}
}
},
Model: {
type: 'object',
properties: {
id: { type: 'string' },
object: { type: 'string', enum: ['model'] },
created: { type: 'integer' },
owned_by: { type: 'string' }
}
},
MCPServer: {
type: 'object',
properties: {
id: { type: 'string' },
name: { type: 'string' },
command: { type: 'string' },
args: {
type: 'array',
items: { type: 'string' }
},
env: { type: 'object' },
disabled: { type: 'boolean' }
}
}
}
},
security: [
{
BearerAuth: []
}
]
},
apis: ['./src/main/apiServer/routes/*.ts', './src/main/apiServer/app.ts']
}
export function setupOpenAPIDocumentation(app: Express) {
try {
const specs = swaggerJSDoc(swaggerOptions)
// Serve OpenAPI JSON
app.get('/api-docs.json', (_req, res) => {
res.setHeader('Content-Type', 'application/json')
res.send(specs)
})
// Serve Swagger UI
app.use(
'/api-docs',
swaggerUi.serve,
swaggerUi.setup(specs, {
customCss: `
.swagger-ui .topbar { display: none; }
.swagger-ui .info .title { color: #1890ff; }
`,
customSiteTitle: 'Cherry Studio API Documentation'
})
)
logger.info('OpenAPI documentation setup complete')
logger.info('Documentation available at /api-docs')
logger.info('OpenAPI spec available at /api-docs.json')
} catch (error) {
logger.error('Failed to setup OpenAPI documentation:', error as Error)
}
}

View File

@ -1,225 +0,0 @@
import express, { Request, Response } from 'express'
import OpenAI from 'openai'
import { ChatCompletionCreateParams } from 'openai/resources'
import { loggerService } from '../../services/LoggerService'
import { chatCompletionService } from '../services/chat-completion'
import { getProviderByModel, getRealProviderModel } from '../utils'
const logger = loggerService.withContext('ApiServerChatRoutes')
const router = express.Router()
/**
* @swagger
* /v1/chat/completions:
* post:
* summary: Create chat completion
* description: Create a chat completion response, compatible with OpenAI API
* tags: [Chat]
* requestBody:
* required: true
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/ChatCompletionRequest'
* responses:
* 200:
* description: Chat completion response
* content:
* application/json:
* schema:
* type: object
* properties:
* id:
* type: string
* object:
* type: string
* example: chat.completion
* created:
* type: integer
* model:
* type: string
* choices:
* type: array
* items:
* type: object
* properties:
* index:
* type: integer
* message:
* $ref: '#/components/schemas/ChatMessage'
* finish_reason:
* type: string
* usage:
* type: object
* properties:
* prompt_tokens:
* type: integer
* completion_tokens:
* type: integer
* total_tokens:
* type: integer
* text/plain:
* schema:
* type: string
* description: Server-sent events stream (when stream=true)
* 400:
* description: Bad request
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
* 401:
* description: Unauthorized
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
* 429:
* description: Rate limit exceeded
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
* 500:
* description: Internal server error
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
*/
router.post('/completions', async (req: Request, res: Response) => {
try {
const request: ChatCompletionCreateParams = req.body
if (!request) {
return res.status(400).json({
error: {
message: 'Request body is required',
type: 'invalid_request_error',
code: 'missing_body'
}
})
}
logger.info('Chat completion request:', {
model: request.model,
messageCount: request.messages?.length || 0,
stream: request.stream
})
// Validate request
const validation = chatCompletionService.validateRequest(request)
if (!validation.isValid) {
return res.status(400).json({
error: {
message: validation.errors.join('; '),
type: 'invalid_request_error',
code: 'validation_failed'
}
})
}
// Get provider
const provider = await getProviderByModel(request.model)
if (!provider) {
return res.status(400).json({
error: {
message: `Model "${request.model}" not found`,
type: 'invalid_request_error',
code: 'model_not_found'
}
})
}
// Validate model availability
const modelId = getRealProviderModel(request.model)
const model = provider.models?.find((m) => m.id === modelId)
if (!model) {
return res.status(400).json({
error: {
message: `Model "${modelId}" not available in provider "${provider.id}"`,
type: 'invalid_request_error',
code: 'model_not_available'
}
})
}
// Create OpenAI client
const client = new OpenAI({
baseURL: provider.apiHost,
apiKey: provider.apiKey
})
request.model = modelId
// Handle streaming
if (request.stream) {
const streamResponse = await client.chat.completions.create(request)
res.setHeader('Content-Type', 'text/plain; charset=utf-8')
res.setHeader('Cache-Control', 'no-cache')
res.setHeader('Connection', 'keep-alive')
try {
for await (const chunk of streamResponse as any) {
res.write(`data: ${JSON.stringify(chunk)}\n\n`)
}
res.write('data: [DONE]\n\n')
res.end()
} catch (streamError: any) {
logger.error('Stream error:', streamError)
res.write(
`data: ${JSON.stringify({
error: {
message: 'Stream processing error',
type: 'server_error',
code: 'stream_error'
}
})}\n\n`
)
res.end()
}
return
}
// Handle non-streaming
const response = await client.chat.completions.create(request)
return res.json(response)
} catch (error: any) {
logger.error('Chat completion error:', error)
let statusCode = 500
let errorType = 'server_error'
let errorCode = 'internal_error'
let errorMessage = 'Internal server error'
if (error instanceof Error) {
errorMessage = error.message
if (error.message.includes('API key') || error.message.includes('authentication')) {
statusCode = 401
errorType = 'authentication_error'
errorCode = 'invalid_api_key'
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
statusCode = 429
errorType = 'rate_limit_error'
errorCode = 'rate_limit_exceeded'
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
statusCode = 502
errorType = 'server_error'
errorCode = 'upstream_error'
}
}
return res.status(statusCode).json({
error: {
message: errorMessage,
type: errorType,
code: errorCode
}
})
}
})
export { router as chatRoutes }

View File

@ -1,153 +0,0 @@
import express, { Request, Response } from 'express'
import { loggerService } from '../../services/LoggerService'
import { mcpApiService } from '../services/mcp'
const logger = loggerService.withContext('ApiServerMCPRoutes')
const router = express.Router()
/**
* @swagger
* /v1/mcps:
* get:
* summary: List MCP servers
* description: Get a list of all configured Model Context Protocol servers
* tags: [MCP]
* responses:
* 200:
* description: List of MCP servers
* content:
* application/json:
* schema:
* type: object
* properties:
* success:
* type: boolean
* data:
* type: array
* items:
* $ref: '#/components/schemas/MCPServer'
* 503:
* description: Service unavailable
* content:
* application/json:
* schema:
* type: object
* properties:
* success:
* type: boolean
* example: false
* error:
* $ref: '#/components/schemas/Error'
*/
router.get('/', async (req: Request, res: Response) => {
try {
logger.info('Get all MCP servers request received')
const servers = await mcpApiService.getAllServers(req)
return res.json({
success: true,
data: servers
})
} catch (error: any) {
logger.error('Error fetching MCP servers:', error)
return res.status(503).json({
success: false,
error: {
message: `Failed to retrieve MCP servers: ${error.message}`,
type: 'service_unavailable',
code: 'servers_unavailable'
}
})
}
})
/**
* @swagger
* /v1/mcps/{server_id}:
* get:
* summary: Get MCP server info
* description: Get detailed information about a specific MCP server
* tags: [MCP]
* parameters:
* - in: path
* name: server_id
* required: true
* schema:
* type: string
* description: MCP server ID
* responses:
* 200:
* description: MCP server information
* content:
* application/json:
* schema:
* type: object
* properties:
* success:
* type: boolean
* data:
* $ref: '#/components/schemas/MCPServer'
* 404:
* description: MCP server not found
* content:
* application/json:
* schema:
* type: object
* properties:
* success:
* type: boolean
* example: false
* error:
* $ref: '#/components/schemas/Error'
*/
router.get('/:server_id', async (req: Request, res: Response) => {
try {
logger.info('Get MCP server info request received')
const server = await mcpApiService.getServerInfo(req.params.server_id)
if (!server) {
logger.warn('MCP server not found')
return res.status(404).json({
success: false,
error: {
message: 'MCP server not found',
type: 'not_found',
code: 'server_not_found'
}
})
}
return res.json({
success: true,
data: server
})
} catch (error: any) {
logger.error('Error fetching MCP server info:', error)
return res.status(503).json({
success: false,
error: {
message: `Failed to retrieve MCP server info: ${error.message}`,
type: 'service_unavailable',
code: 'server_info_unavailable'
}
})
}
})
// Connect to MCP server
router.all('/:server_id/mcp', async (req: Request, res: Response) => {
const server = await mcpApiService.getServerById(req.params.server_id)
if (!server) {
logger.warn('MCP server not found')
return res.status(404).json({
success: false,
error: {
message: 'MCP server not found',
type: 'not_found',
code: 'server_not_found'
}
})
}
return await mcpApiService.handleRequest(req, res, server)
})
export { router as mcpRoutes }

View File

@ -1,66 +0,0 @@
import express, { Request, Response } from 'express'
import { loggerService } from '../../services/LoggerService'
import { chatCompletionService } from '../services/chat-completion'
const logger = loggerService.withContext('ApiServerModelsRoutes')
const router = express.Router()
/**
* @swagger
* /v1/models:
* get:
* summary: List available models
* description: Returns a list of available AI models from all configured providers
* tags: [Models]
* responses:
* 200:
* description: List of available models
* content:
* application/json:
* schema:
* type: object
* properties:
* object:
* type: string
* example: list
* data:
* type: array
* items:
* $ref: '#/components/schemas/Model'
* 503:
* description: Service unavailable
* content:
* application/json:
* schema:
* $ref: '#/components/schemas/Error'
*/
router.get('/', async (_req: Request, res: Response) => {
try {
logger.info('Models list request received')
const models = await chatCompletionService.getModels()
if (models.length === 0) {
logger.warn('No models available from providers')
}
logger.info(`Returning ${models.length} models`)
return res.json({
object: 'list',
data: models
})
} catch (error: any) {
logger.error('Error fetching models:', error)
return res.status(503).json({
error: {
message: 'Failed to retrieve models',
type: 'service_unavailable',
code: 'models_unavailable'
}
})
}
})
export { router as modelsRoutes }

View File

@ -1,65 +0,0 @@
import { createServer } from 'node:http'
import { loggerService } from '../services/LoggerService'
import { app } from './app'
import { config } from './config'
const logger = loggerService.withContext('ApiServer')
export class ApiServer {
private server: ReturnType<typeof createServer> | null = null
async start(): Promise<void> {
if (this.server) {
logger.warn('Server already running')
return
}
// Load config
const { port, host, apiKey } = await config.load()
// Create server with Express app
this.server = createServer(app)
// Start server
return new Promise((resolve, reject) => {
this.server!.listen(port, host, () => {
logger.info(`API Server started at http://${host}:${port}`)
logger.info(`API Key: ${apiKey}`)
resolve()
})
this.server!.on('error', reject)
})
}
async stop(): Promise<void> {
if (!this.server) return
return new Promise((resolve) => {
this.server!.close(() => {
logger.info('API Server stopped')
this.server = null
resolve()
})
})
}
async restart(): Promise<void> {
await this.stop()
await config.reload()
await this.start()
}
isRunning(): boolean {
const hasServer = this.server !== null
const isListening = this.server?.listening || false
const result = hasServer && isListening
logger.debug('isRunning check:', { hasServer, isListening, result })
return result
}
}
export const apiServer = new ApiServer()

View File

@ -1,222 +0,0 @@
import OpenAI from 'openai'
import { ChatCompletionCreateParams } from 'openai/resources'
import { loggerService } from '../../services/LoggerService'
import {
getProviderByModel,
getRealProviderModel,
listAllAvailableModels,
OpenAICompatibleModel,
transformModelToOpenAI,
validateProvider
} from '../utils'
const logger = loggerService.withContext('ChatCompletionService')
export interface ModelData extends OpenAICompatibleModel {
provider_id: string
model_id: string
name: string
}
export interface ValidationResult {
isValid: boolean
errors: string[]
}
export class ChatCompletionService {
async getModels(): Promise<ModelData[]> {
try {
logger.info('Getting available models from providers')
const models = await listAllAvailableModels()
const modelData: ModelData[] = models.map((model) => {
const openAIModel = transformModelToOpenAI(model)
return {
...openAIModel,
provider_id: model.provider,
model_id: model.id,
name: model.name
}
})
logger.info(`Successfully retrieved ${modelData.length} models`)
return modelData
} catch (error: any) {
logger.error('Error getting models:', error)
return []
}
}
validateRequest(request: ChatCompletionCreateParams): ValidationResult {
const errors: string[] = []
// Validate model
if (!request.model) {
errors.push('Model is required')
} else if (typeof request.model !== 'string') {
errors.push('Model must be a string')
} else if (!request.model.includes(':')) {
errors.push('Model must be in format "provider:model_id"')
}
// Validate messages
if (!request.messages) {
errors.push('Messages array is required')
} else if (!Array.isArray(request.messages)) {
errors.push('Messages must be an array')
} else if (request.messages.length === 0) {
errors.push('Messages array cannot be empty')
} else {
// Validate each message
request.messages.forEach((message, index) => {
if (!message.role) {
errors.push(`Message ${index}: role is required`)
}
if (!message.content) {
errors.push(`Message ${index}: content is required`)
}
})
}
// Validate optional parameters
if (request.temperature !== undefined) {
if (typeof request.temperature !== 'number' || request.temperature < 0 || request.temperature > 2) {
errors.push('Temperature must be a number between 0 and 2')
}
}
if (request.max_tokens !== undefined) {
if (typeof request.max_tokens !== 'number' || request.max_tokens < 1) {
errors.push('max_tokens must be a positive number')
}
}
return {
isValid: errors.length === 0,
errors
}
}
async processCompletion(request: ChatCompletionCreateParams): Promise<OpenAI.Chat.Completions.ChatCompletion> {
try {
logger.info('Processing chat completion request:', {
model: request.model,
messageCount: request.messages.length,
stream: request.stream
})
// Validate request
const validation = this.validateRequest(request)
if (!validation.isValid) {
throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
}
// Get provider for the model
const provider = await getProviderByModel(request.model!)
if (!provider) {
throw new Error(`Provider not found for model: ${request.model}`)
}
// Validate provider
if (!validateProvider(provider)) {
throw new Error(`Provider validation failed for: ${provider.id}`)
}
// Extract model ID from the full model string
const modelId = getRealProviderModel(request.model)
// Create OpenAI client for the provider
const client = new OpenAI({
baseURL: provider.apiHost,
apiKey: provider.apiKey
})
// Prepare request with the actual model ID
const providerRequest = {
...request,
model: modelId,
stream: false
}
logger.debug('Sending request to provider:', {
provider: provider.id,
model: modelId,
apiHost: provider.apiHost
})
const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion
logger.info('Successfully processed chat completion')
return response
} catch (error: any) {
logger.error('Error processing chat completion:', error)
throw error
}
}
async *processStreamingCompletion(
request: ChatCompletionCreateParams
): AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> {
try {
logger.info('Processing streaming chat completion request:', {
model: request.model,
messageCount: request.messages.length
})
// Validate request
const validation = this.validateRequest(request)
if (!validation.isValid) {
throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
}
// Get provider for the model
const provider = await getProviderByModel(request.model!)
if (!provider) {
throw new Error(`Provider not found for model: ${request.model}`)
}
// Validate provider
if (!validateProvider(provider)) {
throw new Error(`Provider validation failed for: ${provider.id}`)
}
// Extract model ID from the full model string
const modelId = getRealProviderModel(request.model)
// Create OpenAI client for the provider
const client = new OpenAI({
baseURL: provider.apiHost,
apiKey: provider.apiKey
})
// Prepare streaming request
const streamingRequest = {
...request,
model: modelId,
stream: true as const
}
logger.debug('Sending streaming request to provider:', {
provider: provider.id,
model: modelId,
apiHost: provider.apiHost
})
const stream = await client.chat.completions.create(streamingRequest)
for await (const chunk of stream) {
yield chunk
}
logger.info('Successfully completed streaming chat completion')
} catch (error: any) {
logger.error('Error processing streaming chat completion:', error)
throw error
}
}
}
// Export singleton instance
export const chatCompletionService = new ChatCompletionService()

View File

@ -1,251 +0,0 @@
import mcpService from '@main/services/MCPService'
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp'
import {
isJSONRPCRequest,
JSONRPCMessage,
JSONRPCMessageSchema,
MessageExtraInfo
} from '@modelcontextprotocol/sdk/types'
import { MCPServer } from '@types'
import { randomUUID } from 'crypto'
import { EventEmitter } from 'events'
import { Request, Response } from 'express'
import { IncomingMessage, ServerResponse } from 'http'
import { loggerService } from '../../services/LoggerService'
import { reduxService } from '../../services/ReduxService'
import { getMcpServerById } from '../utils/mcp'
const logger = loggerService.withContext('MCPApiService')
const transports: Record<string, StreamableHTTPServerTransport> = {}
interface McpServerDTO {
id: MCPServer['id']
name: MCPServer['name']
type: MCPServer['type']
description: MCPServer['description']
url: string
}
interface McpServersResp {
servers: Record<string, McpServerDTO>
}
/**
* MCPApiService - API layer for MCP server management
*
* This service provides a REST API interface for MCP servers while integrating
* with the existing application architecture:
*
* 1. Uses ReduxService to access the renderer's Redux store directly
* 2. Syncs changes back to the renderer via Redux actions
* 3. Leverages existing MCPService for actual server connections
* 4. Provides session management for API clients
*/
class MCPApiService extends EventEmitter {
private transport: StreamableHTTPServerTransport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID()
})
constructor() {
super()
this.initMcpServer()
logger.silly('MCPApiService initialized')
}
private initMcpServer() {
this.transport.onmessage = this.onMessage
}
/**
* Get servers directly from Redux store
*/
private async getServersFromRedux(): Promise<MCPServer[]> {
try {
logger.silly('Getting servers from Redux store')
// Try to get from cache first (faster)
const cachedServers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
if (cachedServers && Array.isArray(cachedServers)) {
logger.silly(`Found ${cachedServers.length} servers in Redux cache`)
return cachedServers
}
// If cache is not available, get fresh data
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
return servers || []
} catch (error: any) {
logger.error('Failed to get servers from Redux:', error)
return []
}
}
// get all activated servers
async getAllServers(req: Request): Promise<McpServersResp> {
try {
const servers = await this.getServersFromRedux()
logger.silly(`Returning ${servers.length} servers`)
const resp: McpServersResp = {
servers: {}
}
for (const server of servers) {
if (server.isActive) {
resp.servers[server.id] = {
id: server.id,
name: server.name,
type: 'streamableHttp',
description: server.description,
url: `${req.protocol}://${req.host}/v1/mcps/${server.id}/mcp`
}
}
}
return resp
} catch (error: any) {
logger.error('Failed to get all servers:', error)
throw new Error('Failed to retrieve servers')
}
}
// get server by id
async getServerById(id: string): Promise<MCPServer | null> {
try {
logger.silly(`getServerById called with id: ${id}`)
const servers = await this.getServersFromRedux()
const server = servers.find((s) => s.id === id)
if (!server) {
logger.warn(`Server with id ${id} not found`)
return null
}
logger.silly(`Returning server with id ${id}`)
return server
} catch (error: any) {
logger.error(`Failed to get server with id ${id}:`, error)
throw new Error('Failed to retrieve server')
}
}
async getServerInfo(id: string): Promise<any> {
try {
logger.silly(`getServerInfo called with id: ${id}`)
const server = await this.getServerById(id)
if (!server) {
logger.warn(`Server with id ${id} not found`)
return null
}
logger.silly(`Returning server info for id ${id}`)
const client = await mcpService.initClient(server)
const tools = await client.listTools()
logger.info(`Server with id ${id} info:`, { tools: JSON.stringify(tools) })
// const [version, tools, prompts, resources] = await Promise.all([
// () => {
// try {
// return client.getServerVersion()
// } catch (error) {
// logger.error(`Failed to get server version for id ${id}:`, { error: error })
// return '1.0.0'
// }
// },
// (() => {
// try {
// return client.listTools()
// } catch (error) {
// logger.error(`Failed to list tools for id ${id}:`, { error: error })
// return []
// }
// })(),
// (() => {
// try {
// return client.listPrompts()
// } catch (error) {
// logger.error(`Failed to list prompts for id ${id}:`, { error: error })
// return []
// }
// })(),
// (() => {
// try {
// return client.listResources()
// } catch (error) {
// logger.error(`Failed to list resources for id ${id}:`, { error: error })
// return []
// }
// })()
// ])
return {
id: server.id,
name: server.name,
type: server.type,
description: server.description,
tools
}
} catch (error: any) {
logger.error(`Failed to get server info with id ${id}:`, error)
throw new Error('Failed to retrieve server info')
}
}
async handleRequest(req: Request, res: Response, server: MCPServer) {
const sessionId = req.headers['mcp-session-id'] as string | undefined
logger.silly(`Handling request for server with sessionId ${sessionId}`)
let transport: StreamableHTTPServerTransport
if (sessionId && transports[sessionId]) {
transport = transports[sessionId]
} else {
transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: (sessionId) => {
transports[sessionId] = transport
}
})
transport.onclose = () => {
logger.info(`Transport for sessionId ${sessionId} closed`)
if (transport.sessionId) {
delete transports[transport.sessionId]
}
}
const mcpServer = await getMcpServerById(server.id)
if (mcpServer) {
await mcpServer.connect(transport)
}
}
const jsonpayload = req.body
const messages: JSONRPCMessage[] = []
if (Array.isArray(jsonpayload)) {
for (const payload of jsonpayload) {
const message = JSONRPCMessageSchema.parse(payload)
messages.push(message)
}
} else {
const message = JSONRPCMessageSchema.parse(jsonpayload)
messages.push(message)
}
for (const message of messages) {
if (isJSONRPCRequest(message)) {
if (!message.params) {
message.params = {}
}
if (!message.params._meta) {
message.params._meta = {}
}
message.params._meta.serverId = server.id
}
}
logger.info(`Request body`, { rawBody: req.body, messages: JSON.stringify(messages) })
await transport.handleRequest(req as IncomingMessage, res as ServerResponse, messages)
}
private onMessage(message: JSONRPCMessage, extra?: MessageExtraInfo) {
logger.info(`Received message: ${JSON.stringify(message)}`, extra)
// Handle message here
}
}
export const mcpApiService = new MCPApiService()

View File

@ -1,111 +0,0 @@
import { loggerService } from '@main/services/LoggerService'
import { reduxService } from '@main/services/ReduxService'
import { Model, Provider } from '@types'
const logger = loggerService.withContext('ApiServerUtils')
// OpenAI compatible model format
export interface OpenAICompatibleModel {
id: string
object: 'model'
created: number
owned_by: string
}
export async function getAvailableProviders(): Promise<Provider[]> {
try {
// Wait for store to be ready before accessing providers
const providers = await reduxService.select('state.llm.providers')
if (!providers || !Array.isArray(providers)) {
logger.warn('No providers found in Redux store, returning empty array')
return []
}
return providers.filter((p: Provider) => p.enabled)
} catch (error: any) {
logger.error('Failed to get providers from Redux store:', error)
return []
}
}
export async function listAllAvailableModels(): Promise<Model[]> {
try {
const providers = await getAvailableProviders()
return providers.map((p: Provider) => p.models || []).flat() as Model[]
} catch (error: any) {
logger.error('Failed to list available models:', error)
return []
}
}
export async function getProviderByModel(model: string): Promise<Provider | undefined> {
try {
if (!model || typeof model !== 'string') {
logger.warn(`Invalid model parameter: ${model}`)
return undefined
}
const providers = await getAvailableProviders()
const modelInfo = model.split(':')
if (modelInfo.length < 2) {
logger.warn(`Invalid model format, expected "provider:model": ${model}`)
return undefined
}
const providerId = modelInfo[0]
const provider = providers.find((p: Provider) => p.id === providerId)
if (!provider) {
logger.warn(`Provider not found for model: ${model}`)
return undefined
}
return provider
} catch (error: any) {
logger.error('Failed to get provider by model:', error)
return undefined
}
}
export function getRealProviderModel(modelStr: string): string {
return modelStr.split(':').slice(1).join(':')
}
export function transformModelToOpenAI(model: Model): OpenAICompatibleModel {
return {
id: `${model.provider}:${model.id}`,
object: 'model',
created: Math.floor(Date.now() / 1000),
owned_by: model.owned_by || model.provider
}
}
export function validateProvider(provider: Provider): boolean {
try {
if (!provider) {
return false
}
// Check required fields
if (!provider.id || !provider.type || !provider.apiKey || !provider.apiHost) {
logger.warn('Provider missing required fields:', {
id: !!provider.id,
type: !!provider.type,
apiKey: !!provider.apiKey,
apiHost: !!provider.apiHost
})
return false
}
// Check if provider is enabled
if (!provider.enabled) {
logger.debug(`Provider is disabled: ${provider.id}`)
return false
}
return true
} catch (error: any) {
logger.error('Error validating provider:', error)
return false
}
}

View File

@ -1,76 +0,0 @@
import mcpService from '@main/services/MCPService'
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ListToolsRequestSchema, ListToolsResult } from '@modelcontextprotocol/sdk/types.js'
import { MCPServer } from '@types'
import { loggerService } from '../../services/LoggerService'
import { reduxService } from '../../services/ReduxService'
const logger = loggerService.withContext('MCPApiService')
const cachedServers: Record<string, Server> = {}
async function handleListToolsRequest(request: any, extra: any): Promise<ListToolsResult> {
logger.debug('Handling list tools request', { request: request, extra: extra })
const serverId: string = request.params._meta.serverId
const serverConfig = await getMcpServerConfigById(serverId)
if (!serverConfig) {
throw new Error(`Server not found: ${serverId}`)
}
const client = await mcpService.initClient(serverConfig)
return await client.listTools()
}
async function handleCallToolRequest(request: any, extra: any): Promise<any> {
logger.debug('Handling call tool request', { request: request, extra: extra })
const serverId: string = request.params._meta.serverId
const serverConfig = await getMcpServerConfigById(serverId)
if (!serverConfig) {
throw new Error(`Server not found: ${serverId}`)
}
const client = await mcpService.initClient(serverConfig)
return client.callTool(request.params)
}
async function getMcpServerConfigById(id: string): Promise<MCPServer | undefined> {
const servers = await getServersFromRedux()
return servers.find((s) => s.id === id || s.name === id)
}
/**
* Get servers directly from Redux store
*/
async function getServersFromRedux(): Promise<MCPServer[]> {
try {
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
return servers || []
} catch (error: any) {
logger.error('Failed to get servers from Redux:', error)
return []
}
}
export async function getMcpServerById(id: string): Promise<Server> {
const server = cachedServers[id]
if (!server) {
const servers = await getServersFromRedux()
const mcpServer = servers.find((s) => s.id === id || s.name === id)
if (!mcpServer) {
throw new Error(`Server not found: ${id}`)
}
const createMcpServer = (name: string, version: string): Server => {
const server = new Server({ name: name, version }, { capabilities: { tools: {} } })
server.setRequestHandler(ListToolsRequestSchema, handleListToolsRequest)
server.setRequestHandler(CallToolRequestSchema, handleCallToolRequest)
return server
}
const newServer = createMcpServer(mcpServer.name, '0.1.0')
cachedServers[id] = newServer
return newServer
}
logger.silly('getMcpServer ', { server: server })
return server
}

View File

@ -27,7 +27,6 @@ import { registerShortcuts } from './services/ShortcutService'
import { TrayService } from './services/TrayService'
import { windowService } from './services/WindowService'
import process from 'node:process'
import { apiServerService } from './services/ApiServerService'
const logger = loggerService.withContext('MainEntry')
@ -140,17 +139,6 @@ if (!app.requestSingleInstanceLock()) {
//start selection assistant service
initSelectionService()
// Start API server if enabled
try {
const config = await apiServerService.getCurrentConfig()
logger.info('API server config:', config)
if (config.enabled) {
await apiServerService.start()
}
} catch (error: any) {
logger.error('Failed to check/start API server:', error)
}
})
registerProtocolClient(app)
@ -196,7 +184,6 @@ if (!app.requestSingleInstanceLock()) {
// 简单的资源清理,不阻塞退出流程
try {
await mcpService.cleanup()
await apiServerService.stop()
} catch (error) {
logger.warn('Error cleaning up MCP service:', error as Error)
}

View File

@ -13,7 +13,6 @@ import { FileMetadata, Provider, Shortcut, ThemeMode } from '@types'
import { BrowserWindow, dialog, ipcMain, ProxyConfig, session, shell, systemPreferences, webContents } from 'electron'
import { Notification } from 'src/renderer/src/types/notification'
import { apiServerService } from './services/ApiServerService'
import appService from './services/AppService'
import AppUpdater from './services/AppUpdater'
import BackupManager from './services/BackupManager'
@ -696,7 +695,4 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
(_, spanId: string, modelName: string, context: string, msg: any) =>
addStreamMessage(spanId, modelName, context, msg)
)
// API Server
apiServerService.registerIpcHandlers()
}

View File

@ -1,122 +0,0 @@
import fs from 'node:fs'
import path from 'node:path'
import { windowService } from '@main/services/WindowService'
import { getFileExt } from '@main/utils/file'
import { FileMetadata, OcrProvider } from '@types'
import { app } from 'electron'
import pdfjs from 'pdfjs-dist'
import { TypedArray } from 'pdfjs-dist/types/src/display/api'
export default abstract class BaseOcrProvider {
protected provider: OcrProvider
public storageDir = path.join(app.getPath('userData'), 'Data', 'Files')
constructor(provider: OcrProvider) {
if (!provider) {
throw new Error('OCR provider is not set')
}
this.provider = provider
}
abstract parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata; quota?: number }>
/**
*
* Data/Files/{file.id}
* @param file
* @returns null
*/
public async checkIfAlreadyProcessed(file: FileMetadata): Promise<FileMetadata | null> {
try {
// 检查 Data/Files/{file.id} 是否是目录
const preprocessDirPath = path.join(this.storageDir, file.id)
if (fs.existsSync(preprocessDirPath)) {
const stats = await fs.promises.stat(preprocessDirPath)
// 如果是目录,说明已经被预处理过
if (stats.isDirectory()) {
// 查找目录中的处理结果文件
const files = await fs.promises.readdir(preprocessDirPath)
// 查找主要的处理结果文件(.md 或 .txt
const processedFile = files.find((fileName) => fileName.endsWith('.md') || fileName.endsWith('.txt'))
if (processedFile) {
const processedFilePath = path.join(preprocessDirPath, processedFile)
const processedStats = await fs.promises.stat(processedFilePath)
const ext = getFileExt(processedFile)
return {
...file,
name: file.name.replace(file.ext, ext),
path: processedFilePath,
ext: ext,
size: processedStats.size,
created_at: processedStats.birthtime.toISOString()
}
}
}
}
return null
} catch (error) {
// 如果检查过程中出现错误返回null表示未处理
return null
}
}
/**
*
*/
public delay = (ms: number): Promise<void> => {
return new Promise((resolve) => setTimeout(resolve, ms))
}
public async readPdf(
source: string | URL | TypedArray,
passwordCallback?: (fn: (password: string) => void, reason: string) => string
) {
const documentLoadingTask = pdfjs.getDocument(source)
if (passwordCallback) {
documentLoadingTask.onPassword = passwordCallback
}
const document = await documentLoadingTask.promise
return document
}
public async sendOcrProgress(sourceId: string, progress: number): Promise<void> {
const mainWindow = windowService.getMainWindow()
mainWindow?.webContents.send('file-ocr-progress', {
itemId: sourceId,
progress: progress
})
}
/**
*
* @param fileId id
* @param filePaths
* @returns
*/
public moveToAttachmentsDir(fileId: string, filePaths: string[]): string[] {
const attachmentsPath = path.join(this.storageDir, fileId)
if (!fs.existsSync(attachmentsPath)) {
fs.mkdirSync(attachmentsPath, { recursive: true })
}
const movedPaths: string[] = []
for (const filePath of filePaths) {
if (fs.existsSync(filePath)) {
const fileName = path.basename(filePath)
const destPath = path.join(attachmentsPath, fileName)
fs.copyFileSync(filePath, destPath)
fs.unlinkSync(filePath) // 删除原文件,实现"移动"
movedPaths.push(destPath)
}
}
return movedPaths
}
}

View File

@ -1,12 +0,0 @@
import { FileMetadata, OcrProvider } from '@types'
import BaseOcrProvider from './BaseOcrProvider'
export default class DefaultOcrProvider extends BaseOcrProvider {
constructor(provider: OcrProvider) {
super(provider)
}
public parseFile(): Promise<{ processedFile: FileMetadata }> {
throw new Error('Method not implemented.')
}
}

View File

@ -1,130 +0,0 @@
import { loggerService } from '@logger'
import { isMac } from '@main/constant'
import { FileMetadata, OcrProvider } from '@types'
import * as fs from 'fs'
import * as path from 'path'
import { TextItem } from 'pdfjs-dist/types/src/display/api'
import BaseOcrProvider from './BaseOcrProvider'
const logger = loggerService.withContext('MacSysOcrProvider')
export default class MacSysOcrProvider extends BaseOcrProvider {
private readonly MIN_TEXT_LENGTH = 1000
private MacOCR: any
private async initMacOCR() {
if (!isMac) {
throw new Error('MacSysOcrProvider is only available on macOS')
}
if (!this.MacOCR) {
try {
// @ts-ignore This module is optional and only installed/available on macOS. Runtime checks prevent execution on other platforms.
const module = await import('@cherrystudio/mac-system-ocr')
this.MacOCR = module.default
} catch (error) {
logger.error('Failed to load mac-system-ocr:', error as Error)
throw error
}
}
return this.MacOCR
}
private getRecognitionLevel(level?: number) {
return level === 0 ? this.MacOCR.RECOGNITION_LEVEL_FAST : this.MacOCR.RECOGNITION_LEVEL_ACCURATE
}
constructor(provider: OcrProvider) {
super(provider)
}
private async processPages(
results: any,
totalPages: number,
sourceId: string,
writeStream: fs.WriteStream
): Promise<void> {
await this.initMacOCR()
// TODO: 下个版本后面使用批处理以及p-queue来优化
for (let i = 0; i < totalPages; i++) {
// Convert pages to buffers
const pageNum = i + 1
const pageBuffer = await results.getPage(pageNum)
// Process batch
const ocrResult = await this.MacOCR.recognizeFromBuffer(pageBuffer, {
ocrOptions: {
recognitionLevel: this.getRecognitionLevel(this.provider.options?.recognitionLevel),
minConfidence: this.provider.options?.minConfidence || 0.5
}
})
// Write results in order
writeStream.write(ocrResult.text + '\n')
// Update progress
await this.sendOcrProgress(sourceId, (pageNum / totalPages) * 100)
}
}
public async isScanPdf(buffer: Buffer): Promise<boolean> {
const doc = await this.readPdf(new Uint8Array(buffer))
const pageLength = doc.numPages
let counts = 0
const pagesToCheck = Math.min(pageLength, 10)
for (let i = 0; i < pagesToCheck; i++) {
const page = await doc.getPage(i + 1)
const pageData = await page.getTextContent()
const pageText = pageData.items.map((item) => (item as TextItem).str).join('')
counts += pageText.length
if (counts >= this.MIN_TEXT_LENGTH) {
return false
}
}
return true
}
public async parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }> {
logger.info(`Starting OCR process for file: ${file.name}`)
if (file.ext === '.pdf') {
try {
const { pdf } = await import('@cherrystudio/pdf-to-img-napi')
const pdfBuffer = await fs.promises.readFile(file.path)
const results = await pdf(pdfBuffer, {
scale: 2
})
const totalPages = results.length
const baseDir = path.dirname(file.path)
const baseName = path.basename(file.path, path.extname(file.path))
const txtFileName = `${baseName}.txt`
const txtFilePath = path.join(baseDir, txtFileName)
const writeStream = fs.createWriteStream(txtFilePath)
await this.processPages(results, totalPages, sourceId, writeStream)
await new Promise<void>((resolve, reject) => {
writeStream.end(() => {
logger.info(`OCR process completed successfully for ${file.origin_name}`)
resolve()
})
writeStream.on('error', reject)
})
const movedPaths = this.moveToAttachmentsDir(file.id, [txtFilePath])
return {
processedFile: {
...file,
name: txtFileName,
path: movedPaths[0],
ext: '.txt',
size: fs.statSync(movedPaths[0]).size
}
}
} catch (error) {
logger.error('Error during OCR process:', error as Error)
throw error
}
}
return { processedFile: file }
}
}

View File

@ -1,26 +0,0 @@
import { FileMetadata, OcrProvider as Provider } from '@types'
import BaseOcrProvider from './BaseOcrProvider'
import OcrProviderFactory from './OcrProviderFactory'
export default class OcrProvider {
private sdk: BaseOcrProvider
constructor(provider: Provider) {
this.sdk = OcrProviderFactory.create(provider)
}
public async parseFile(
sourceId: string,
file: FileMetadata
): Promise<{ processedFile: FileMetadata; quota?: number }> {
return this.sdk.parseFile(sourceId, file)
}
/**
*
* @param file
* @returns null
*/
public async checkIfAlreadyProcessed(file: FileMetadata): Promise<FileMetadata | null> {
return this.sdk.checkIfAlreadyProcessed(file)
}
}

View File

@ -1,23 +0,0 @@
import { loggerService } from '@logger'
import { isMac } from '@main/constant'
import { OcrProvider } from '@types'
import BaseOcrProvider from './BaseOcrProvider'
import DefaultOcrProvider from './DefaultOcrProvider'
import MacSysOcrProvider from './MacSysOcrProvider'
const logger = loggerService.withContext('OcrProviderFactory')
export default class OcrProviderFactory {
static create(provider: OcrProvider): BaseOcrProvider {
switch (provider.id) {
case 'system':
if (!isMac) {
logger.warn('System OCR provider is only available on macOS')
}
return new MacSysOcrProvider(provider)
default:
return new DefaultOcrProvider(provider)
}
}
}

View File

@ -1,17 +1,18 @@
import fs from 'node:fs'
import path from 'node:path'
import { loggerService } from '@logger'
import { windowService } from '@main/services/WindowService'
import { getFileExt } from '@main/utils/file'
import { getFileExt, getTempDir } from '@main/utils/file'
import { FileMetadata, PreprocessProvider } from '@types'
import { app } from 'electron'
import pdfjs from 'pdfjs-dist'
import { TypedArray } from 'pdfjs-dist/types/src/display/api'
import { PDFDocument } from 'pdf-lib'
const logger = loggerService.withContext('BasePreprocessProvider')
export default abstract class BasePreprocessProvider {
protected provider: PreprocessProvider
protected userId?: string
public storageDir = path.join(app.getPath('userData'), 'Data', 'Files')
public storageDir = path.join(getTempDir(), 'preprocess')
constructor(provider: PreprocessProvider, userId?: string) {
if (!provider) {
@ -19,7 +20,19 @@ export default abstract class BasePreprocessProvider {
}
this.provider = provider
this.userId = userId
this.ensureDirectories()
}
private ensureDirectories() {
try {
if (!fs.existsSync(this.storageDir)) {
fs.mkdirSync(this.storageDir, { recursive: true })
}
} catch (error) {
logger.error('Failed to create directories:', error as Error)
}
}
abstract parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata; quota?: number }>
abstract checkQuota(): Promise<number>
@ -77,17 +90,11 @@ export default abstract class BasePreprocessProvider {
return new Promise((resolve) => setTimeout(resolve, ms))
}
public async readPdf(
source: string | URL | TypedArray,
passwordCallback?: (fn: (password: string) => void, reason: string) => string
) {
const documentLoadingTask = pdfjs.getDocument(source)
if (passwordCallback) {
documentLoadingTask.onPassword = passwordCallback
public async readPdf(buffer: Buffer) {
const pdfDoc = await PDFDocument.load(buffer)
return {
numPages: pdfDoc.getPageCount()
}
const document = await documentLoadingTask.promise
return document
}
public async sendPreprocessProgress(sourceId: string, progress: number): Promise<void> {

View File

@ -39,7 +39,7 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
private async validateFile(filePath: string): Promise<void> {
const pdfBuffer = await fs.promises.readFile(filePath)
const doc = await this.readPdf(new Uint8Array(pdfBuffer))
const doc = await this.readPdf(pdfBuffer)
// 文件页数小于1000页
if (doc.numPages >= 1000) {

View File

@ -115,7 +115,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
private async validateFile(filePath: string): Promise<void> {
const pdfBuffer = await fs.promises.readFile(filePath)
const doc = await this.readPdf(new Uint8Array(pdfBuffer))
const doc = await this.readPdf(pdfBuffer)
// 文件页数小于600页
if (doc.numPages >= 600) {
@ -178,7 +178,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
try {
// 下载ZIP文件
const response = await axios.get(zipUrl, { responseType: 'arraybuffer' })
fs.writeFileSync(zipPath, response.data)
fs.writeFileSync(zipPath, Buffer.from(response.data))
logger.info(`Downloaded ZIP file: ${zipPath}`)
// 确保提取目录存在

View File

@ -1,108 +0,0 @@
import { IpcChannel } from '@shared/IpcChannel'
import { ApiServerConfig } from '@types'
import { ipcMain } from 'electron'
import { apiServer } from '../apiServer'
import { config } from '../apiServer/config'
import { loggerService } from './LoggerService'
const logger = loggerService.withContext('ApiServerService')
export class ApiServerService {
constructor() {
// Use the new clean implementation
}
async start(): Promise<void> {
try {
await apiServer.start()
logger.info('API Server started successfully')
} catch (error: any) {
logger.error('Failed to start API Server:', error)
throw error
}
}
async stop(): Promise<void> {
try {
await apiServer.stop()
logger.info('API Server stopped successfully')
} catch (error: any) {
logger.error('Failed to stop API Server:', error)
throw error
}
}
async restart(): Promise<void> {
try {
await apiServer.restart()
logger.info('API Server restarted successfully')
} catch (error: any) {
logger.error('Failed to restart API Server:', error)
throw error
}
}
isRunning(): boolean {
return apiServer.isRunning()
}
async getCurrentConfig(): Promise<ApiServerConfig> {
return await config.get()
}
registerIpcHandlers(): void {
// API Server
ipcMain.handle(IpcChannel.ApiServer_Start, async () => {
try {
await this.start()
return { success: true }
} catch (error: any) {
return { success: false, error: error instanceof Error ? error.message : 'Unknown error' }
}
})
ipcMain.handle(IpcChannel.ApiServer_Stop, async () => {
try {
await this.stop()
return { success: true }
} catch (error: any) {
return { success: false, error: error instanceof Error ? error.message : 'Unknown error' }
}
})
ipcMain.handle(IpcChannel.ApiServer_Restart, async () => {
try {
await this.restart()
return { success: true }
} catch (error: any) {
return { success: false, error: error instanceof Error ? error.message : 'Unknown error' }
}
})
ipcMain.handle(IpcChannel.ApiServer_GetStatus, async () => {
try {
const config = await this.getCurrentConfig()
return {
running: this.isRunning(),
config
}
} catch (error: any) {
return {
running: this.isRunning(),
config: null
}
}
})
ipcMain.handle(IpcChannel.ApiServer_GetConfig, async () => {
try {
return await this.getCurrentConfig()
} catch (error: any) {
return null
}
})
}
}
// Export singleton instance
export const apiServerService = new ApiServerService()

View File

@ -16,7 +16,7 @@ import { writeFileSync } from 'fs'
import { readFile } from 'fs/promises'
import officeParser from 'officeparser'
import * as path from 'path'
import pdfjs from 'pdfjs-dist'
import { PDFDocument } from 'pdf-lib'
import { chdir } from 'process'
import { v4 as uuidv4 } from 'uuid'
import WordExtractor from 'word-extractor'
@ -367,10 +367,8 @@ class FileStorage {
const filePath = path.join(this.storageDir, id)
const buffer = await fs.promises.readFile(filePath)
const doc = await pdfjs.getDocument({ data: buffer }).promise
const pages = doc.numPages
await doc.destroy()
return pages
const pdfDoc = await PDFDocument.load(buffer)
return pdfDoc.getPageCount()
}
public binaryImage = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => {

View File

@ -25,7 +25,6 @@ import { loggerService } from '@logger'
import Embeddings from '@main/knowledge/embeddings/Embeddings'
import { addFileLoader } from '@main/knowledge/loader'
import { NoteLoader } from '@main/knowledge/loader/noteLoader'
import OcrProvider from '@main/knowledge/ocr/OcrProvider'
import PreprocessProvider from '@main/knowledge/preprocess/PreprocessProvider'
import Reranker from '@main/knowledge/reranker/Reranker'
import { windowService } from '@main/services/WindowService'
@ -687,14 +686,9 @@ class KnowledgeService {
userId: string
): Promise<FileMetadata> => {
let fileToProcess: FileMetadata = file
if (base.preprocessOrOcrProvider && file.ext.toLowerCase() === '.pdf') {
if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') {
try {
let provider: PreprocessProvider | OcrProvider
if (base.preprocessOrOcrProvider.type === 'preprocess') {
provider = new PreprocessProvider(base.preprocessOrOcrProvider.provider, userId)
} else {
provider = new OcrProvider(base.preprocessOrOcrProvider.provider)
}
const provider = new PreprocessProvider(base.preprocessProvider.provider, userId)
// Check if file has already been preprocessed
const alreadyProcessed = await provider.checkIfAlreadyProcessed(file)
if (alreadyProcessed) {
@ -728,8 +722,8 @@ class KnowledgeService {
userId: string
): Promise<number> => {
try {
if (base.preprocessOrOcrProvider && base.preprocessOrOcrProvider.type === 'preprocess') {
const provider = new PreprocessProvider(base.preprocessOrOcrProvider.provider, userId)
if (base.preprocessProvider && base.preprocessProvider.type === 'preprocess') {
const provider = new PreprocessProvider(base.preprocessProvider.provider, userId)
return await provider.checkQuota()
}
throw new Error('No preprocess provider configured')

View File

@ -66,7 +66,7 @@ export class ProxyManager {
try {
if (config?.mode === this.config?.mode && config?.proxyRules === this.config?.proxyRules) {
logger.info('proxy config is the same, skip configure')
logger.debug('proxy config is the same, skip configure')
return
}

View File

@ -319,6 +319,13 @@ export class WindowService {
private setupWindowLifecycleEvents(mainWindow: BrowserWindow) {
mainWindow.on('close', (event) => {
// save data before when close window
try {
mainWindow.webContents.send(IpcChannel.App_SaveData)
} catch (error) {
logger.error('Failed to save data:', error as Error)
}
// 如果已经触发退出,直接退出
if (app.isQuitting) {
return app.quit()

Binary file not shown.

Before

Width:  |  Height:  |  Size: 182 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

View File

@ -35,6 +35,7 @@ interface DraggableVirtualListProps<T> {
ref?: React.Ref<HTMLDivElement>
className?: string
style?: React.CSSProperties
scrollerStyle?: React.CSSProperties
itemStyle?: React.CSSProperties
itemContainerStyle?: React.CSSProperties
droppableProps?: Partial<DroppableProps>
@ -43,6 +44,7 @@ interface DraggableVirtualListProps<T> {
onDragEnd?: OnDragEndResponder
list: T[]
itemKey?: (index: number) => Key
estimateSize?: (index: number) => number
overscan?: number
header?: React.ReactNode
children: (item: T, index: number) => React.ReactNode
@ -59,6 +61,7 @@ function DraggableVirtualList<T>({
ref,
className,
style,
scrollerStyle,
itemStyle,
itemContainerStyle,
droppableProps,
@ -67,6 +70,7 @@ function DraggableVirtualList<T>({
onDragEnd,
list,
itemKey,
estimateSize: _estimateSize,
overscan = 5,
header,
children
@ -88,12 +92,15 @@ function DraggableVirtualList<T>({
count: list?.length ?? 0,
getScrollElement: useCallback(() => parentRef.current, []),
getItemKey: itemKey,
estimateSize: useCallback(() => 50, []),
estimateSize: useCallback((index) => _estimateSize?.(index) ?? 50, [_estimateSize]),
overscan
})
return (
<div ref={ref} className={`${className} draggable-virtual-list`} style={{ height: '100%', ...style }}>
<div
ref={ref}
className={`${className} draggable-virtual-list`}
style={{ height: '100%', display: 'flex', flexDirection: 'column', ...style }}>
<DragDropContext onDragStart={onDragStart} onDragEnd={_onDragEnd}>
{header}
<Droppable
@ -128,6 +135,7 @@ function DraggableVirtualList<T>({
{...provided.droppableProps}
className="virtual-scroller"
style={{
...scrollerStyle,
height: '100%',
width: '100%',
overflowY: 'auto',

View File

@ -1,20 +1,41 @@
import { SVGProps } from 'react'
export function SvgSpinners180Ring(props: SVGProps<SVGSVGElement>) {
// 避免与全局样式冲突
const animationClassName = 'svg-spinner-anim-180-ring'
return (
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" {...props}>
{/* Icon from SVG Spinners by Utkarsh Verma - https://github.com/n3r4zzurr0/svg-spinners/blob/main/LICENSE */}
<path
fill="currentColor"
d="M12,4a8,8,0,0,1,7.89,6.7A1.53,1.53,0,0,0,21.38,12h0a1.5,1.5,0,0,0,1.48-1.75,11,11,0,0,0-21.72,0A1.5,1.5,0,0,0,2.62,12h0a1.53,1.53,0,0,0,1.49-1.3A8,8,0,0,1,12,4Z">
<animateTransform
attributeName="transform"
dur="0.75s"
repeatCount="indefinite"
type="rotate"
values="0 12 12;360 12 12"></animateTransform>
</path>
</svg>
<>
{/* CSS transform 硬件加速 */}
<style>
{`
@keyframes svg-spinner-rotate-180-ring {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
.${animationClassName} {
transform-origin: center;
animation: svg-spinner-rotate-180-ring 0.75s linear infinite;
}
`}
</style>
<svg
xmlns="http://www.w3.org/2000/svg"
width="1em"
height="1em"
viewBox="0 0 24 24"
{...props}
className={`${animationClassName} ${props.className || ''}`.trim()}>
{/* Icon from SVG Spinners by Utkarsh Verma - https://github.com/n3r4zzurr0/svg-spinners/blob/main/LICENSE */}
<path
fill="currentColor"
d="M12,4a8,8,0,0,1,7.89,6.7A1.53,1.53,0,0,0,21.38,12h0a1.5,1.5,0,0,0,1.48-1.75,11,11,0,0,0-21.72,0A1.5,1.5,0,0,0,2.62,12h0a1.53,1.53,0,0,0,1.49-1.3A8,8,0,0,1,12,4Z"></path>
</svg>
</>
)
}
export default SvgSpinners180Ring

View File

@ -143,7 +143,7 @@ const MinappPopupContainer: React.FC = () => {
const { pinned, updatePinnedMinapps } = useMinapps()
const { t } = useTranslation()
const backgroundColor = useNavBackgroundColor()
const { isLeftNavbar, isTopNavbar } = useNavbarPosition()
const { isTopNavbar } = useNavbarPosition()
const dispatch = useAppDispatch()
/** control the drawer open or close */
@ -165,6 +165,8 @@ const MinappPopupContainer: React.FC = () => {
/** whether the minapps open link external is enabled */
const { minappsOpenLinkExternal } = useSettings()
const { isLeftNavbar } = useNavbarPosition()
const isInDevelopment = process.env.NODE_ENV === 'development'
useBridge()
@ -403,7 +405,7 @@ const MinappPopupContainer: React.FC = () => {
</Tooltip>
)}
<Spacer />
<ButtonsGroup className={isWin || isLinux ? 'windows' : ''} isTopNavBar={isTopNavbar}>
<ButtonsGroup className={isWin || isLinux ? 'windows' : ''}>
<Tooltip title={t('minapp.popup.goBack')} mouseEnterDelay={0.8} placement="bottom">
<TitleButton onClick={() => handleGoBack(appInfo.id)}>
<ArrowLeftOutlined />
@ -505,7 +507,6 @@ const MinappPopupContainer: React.FC = () => {
closeIcon={null}
style={{
marginLeft: isLeftNavbar ? 'var(--sidebar-width)' : 0,
marginTop: isTopNavbar ? 'var(--navbar-height)' : 0,
backgroundColor: window.root.style.background
}}>
{/* 在所有小程序中显示GoogleLoginTip */}
@ -540,7 +541,7 @@ const TitleContainer = styled.div`
padding-left: ${isMac ? '20px' : '10px'};
}
[navbar-position='top'] & {
padding-left: ${isMac ? '20px' : '10px'};
padding-left: ${isMac ? '80px' : '10px'};
border-bottom: 0.5px solid var(--color-border);
}
`
@ -562,14 +563,14 @@ const TitleTextTooltip = styled.span`
}
`
const ButtonsGroup = styled.div<{ isTopNavBar: boolean }>`
const ButtonsGroup = styled.div`
display: flex;
flex-direction: row;
align-items: center;
gap: 5px;
-webkit-app-region: no-drag;
&.windows {
margin-right: ${(props) => (props.isTopNavBar ? 0 : isWin ? '130px' : isLinux ? '100px' : 0)};
margin-right: ${isWin ? '130px' : isLinux ? '100px' : 0};
background-color: var(--color-background-mute);
border-radius: 50px;
padding: 0 3px;

View File

@ -23,7 +23,7 @@ const WebviewContainer = memo(
}) => {
const webviewRef = useRef<WebviewTag | null>(null)
const { enableSpellCheck } = useSettings()
const { isLeftNavbar, isTopNavbar } = useNavbarPosition()
const { isLeftNavbar } = useNavbarPosition()
const setRef = (appid: string) => {
onSetRefCallback(appid, null)
@ -74,7 +74,7 @@ const WebviewContainer = memo(
const WebviewStyle: React.CSSProperties = {
width: isLeftNavbar ? 'calc(100vw - var(--sidebar-width))' : '100vw',
height: isTopNavbar ? 'calc(100vh - var(--navbar-height) - var(--navbar-height))' : '100vh',
height: 'calc(100vh - var(--navbar-height))',
backgroundColor: 'var(--color-background)',
display: 'inline-flex'
}

View File

@ -3,14 +3,14 @@ import CustomTag from '@renderer/components/CustomTag'
import ExpandableText from '@renderer/components/ExpandableText'
import ModelIdWithTags from '@renderer/components/ModelIdWithTags'
import NewApiBatchAddModelPopup from '@renderer/components/ModelList/NewApiBatchAddModelPopup'
import { DynamicVirtualList } from '@renderer/components/VirtualList'
import { getModelLogo } from '@renderer/config/models'
import FileItem from '@renderer/pages/files/FileItem'
import { Model, Provider } from '@renderer/types'
import { defaultRangeExtractor, useVirtualizer } from '@tanstack/react-virtual'
import { Button, Flex, Tooltip } from 'antd'
import { Avatar } from 'antd'
import { ChevronRight } from 'lucide-react'
import React, { memo, useCallback, useMemo, useRef, useState } from 'react'
import React, { memo, useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -39,8 +39,6 @@ interface ManageModelsListProps {
const ManageModelsList: React.FC<ManageModelsListProps> = ({ modelGroups, provider, onAddModel, onRemoveModel }) => {
const { t } = useTranslation()
const scrollerRef = useRef<HTMLDivElement>(null)
const activeStickyIndexRef = useRef(0)
const [collapsedGroups, setCollapsedGroups] = useState(new Set<string>())
const handleGroupToggle = useCallback((groupName: string) => {
@ -74,33 +72,6 @@ const ManageModelsList: React.FC<ManageModelsListProps> = ({ modelGroups, provid
return rows
}, [modelGroups, collapsedGroups])
// 找到所有组 header 的索引
const stickyIndexes = useMemo(() => {
return flatRows.map((row, index) => (row.type === 'group' ? index : -1)).filter((index) => index !== -1)
}, [flatRows])
const isSticky = useCallback((index: number) => stickyIndexes.includes(index), [stickyIndexes])
const isActiveSticky = useCallback((index: number) => activeStickyIndexRef.current === index, [])
// 自定义 range extractor 用于 sticky header
const rangeExtractor = useCallback(
(range: any) => {
activeStickyIndexRef.current = [...stickyIndexes].reverse().find((index) => range.startIndex >= index) ?? 0
const next = new Set([activeStickyIndexRef.current, ...defaultRangeExtractor(range)])
return [...next].sort((a, b) => a - b)
},
[stickyIndexes]
)
const virtualizer = useVirtualizer({
count: flatRows.length,
getScrollElement: () => scrollerRef.current,
estimateSize: () => 42,
rangeExtractor,
overscan: 5
})
const renderGroupTools = useCallback(
(models: Model[]) => {
const isAllInProvider = models.every((model) => isModelInProvider(provider, model.id))
@ -153,79 +124,47 @@ const ManageModelsList: React.FC<ManageModelsListProps> = ({ modelGroups, provid
[provider, onRemoveModel, onAddModel, t]
)
const virtualItems = virtualizer.getVirtualItems()
return (
<ListContainer ref={scrollerRef}>
<div
style={{
height: `${virtualizer.getTotalSize()}px`,
width: '100%',
position: 'relative'
}}>
{virtualItems.map((virtualItem) => {
const row = flatRows[virtualItem.index]
const isRowSticky = isSticky(virtualItem.index)
const isRowActiveSticky = isActiveSticky(virtualItem.index)
const isCollapsed = row.type === 'group' && collapsedGroups.has(row.groupName)
if (!row) return null
<DynamicVirtualList
list={flatRows}
estimateSize={useCallback(() => 60, [])}
isSticky={useCallback((index: number) => flatRows[index].type === 'group', [flatRows])}
overscan={5}
scrollerStyle={{
paddingRight: '10px'
}}
itemContainerStyle={{
paddingBottom: '8px'
}}>
{(row) => {
if (row.type === 'group') {
const isCollapsed = collapsedGroups.has(row.groupName)
return (
<div
key={virtualItem.index}
data-index={virtualItem.index}
ref={virtualizer.measureElement}
style={{
...(isRowSticky
? {
background: 'var(--color-background)',
zIndex: 1
}
: {}),
...(isRowActiveSticky
? {
position: 'sticky'
}
: {
position: 'absolute',
transform: `translateY(${virtualItem.start}px)`
}),
top: 0,
left: 0,
width: '100%'
}}>
{row.type === 'group' ? (
<GroupHeader onClick={() => handleGroupToggle(row.groupName)}>
<Flex align="center" gap={10} style={{ flex: 1 }}>
<ChevronRight
size={16}
color="var(--color-text-3)"
strokeWidth={1.5}
style={{ transform: isCollapsed ? 'rotate(0deg)' : 'rotate(90deg)' }}
/>
<span style={{ fontWeight: 'bold', fontSize: '14px' }}>{row.groupName}</span>
<CustomTag color="#02B96B" size={10}>
{row.models.length}
</CustomTag>
</Flex>
{renderGroupTools(row.models)}
</GroupHeader>
) : (
<div style={{ padding: '4px 0' }}>
<ModelListItem
model={row.model}
provider={provider}
onAddModel={onAddModel}
onRemoveModel={onRemoveModel}
/>
</div>
)}
</div>
<GroupHeader
style={{ background: 'var(--color-background)' }}
onClick={() => handleGroupToggle(row.groupName)}>
<Flex align="center" gap={10} style={{ flex: 1 }}>
<ChevronRight
size={16}
color="var(--color-text-3)"
strokeWidth={1.5}
style={{ transform: isCollapsed ? 'rotate(0deg)' : 'rotate(90deg)' }}
/>
<span style={{ fontWeight: 'bold', fontSize: '14px' }}>{row.groupName}</span>
<CustomTag color="#02B96B" size={10}>
{row.models.length}
</CustomTag>
</Flex>
{renderGroupTools(row.models)}
</GroupHeader>
)
})}
</div>
</ListContainer>
}
return (
<ModelListItem model={row.model} provider={provider} onAddModel={onAddModel} onRemoveModel={onRemoveModel} />
)
}}
</DynamicVirtualList>
)
}
@ -262,18 +201,12 @@ const ModelListItem: React.FC<ModelListItemProps> = memo(({ model, provider, onA
)
})
const ListContainer = styled.div`
height: calc(100vh - 300px);
overflow: auto;
padding-right: 10px;
`
const GroupHeader = styled.div`
display: flex;
align-items: center;
justify-content: space-between;
padding: 0 8px;
min-height: 48px;
min-height: 50px;
color: var(--color-text);
cursor: pointer;
`

View File

@ -16,8 +16,8 @@ import { useAppDispatch } from '@renderer/store'
import { setModel } from '@renderer/store/assistants'
import { Model } from '@renderer/types'
import { filterModelsByKeywords } from '@renderer/utils'
import { Button, Flex, Spin, Tooltip } from 'antd'
import { groupBy, sortBy, toPairs } from 'lodash'
import { Button, Empty, Flex, Spin, Tooltip } from 'antd'
import { groupBy, isEmpty, sortBy, toPairs } from 'lodash'
import { ListCheck, Plus } from 'lucide-react'
import React, { memo, startTransition, useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
@ -134,6 +134,8 @@ const ModelList: React.FC<ModelListProps> = ({ providerId }) => {
[provider, onUpdateModel]
)
const isLoading = useMemo(() => displayedModelGroups === null, [displayedModelGroups])
return (
<>
<SettingSubtitle style={{ marginBottom: 5 }}>
@ -158,54 +160,60 @@ const ModelList: React.FC<ModelListProps> = ({ providerId }) => {
</HStack>
</HStack>
</SettingSubtitle>
{displayedModelGroups === null ? (
<Flex align="center" justify="center" style={{ minHeight: '8rem' }}>
<Spin indicator={<SvgSpinners180Ring color="var(--color-text-2)" />} />
<Spin spinning={isLoading} indicator={<SvgSpinners180Ring color="var(--color-text-2)" />}>
{displayedModelGroups && !isEmpty(displayedModelGroups) ? (
<Flex gap={12} vertical>
{Object.keys(displayedModelGroups).map((group, i) => (
<ModelListGroup
key={group}
groupName={group}
models={displayedModelGroups[group]}
modelStatuses={modelStatuses}
defaultOpen={i <= 5}
disabled={isHealthChecking}
onEditModel={onEditModel}
onRemoveModel={removeModel}
onRemoveGroup={() => displayedModelGroups[group].forEach((model) => removeModel(model))}
/>
))}
</Flex>
) : (
<Empty
image={Empty.PRESENTED_IMAGE_SIMPLE}
description={t('settings.models.empty')}
style={{ visibility: isLoading ? 'hidden' : 'visible' }}
/>
)}
</Spin>
<Flex justify="space-between" align="center">
{docsWebsite || modelsWebsite ? (
<SettingHelpTextRow>
<SettingHelpText>{t('settings.provider.docs_check')} </SettingHelpText>
{docsWebsite && (
<SettingHelpLink target="_blank" href={docsWebsite}>
{getProviderLabel(provider.id) + ' '}
{t('common.docs')}
</SettingHelpLink>
)}
{docsWebsite && modelsWebsite && <SettingHelpText>{t('common.and')}</SettingHelpText>}
{modelsWebsite && (
<SettingHelpLink target="_blank" href={modelsWebsite}>
{t('common.models')}
</SettingHelpLink>
)}
<SettingHelpText>{t('settings.provider.docs_more_details')}</SettingHelpText>
</SettingHelpTextRow>
) : (
<div style={{ height: 5 }} />
)}
<Flex gap={10} style={{ marginTop: 12 }}>
<Button type="primary" onClick={onManageModel} icon={<ListCheck size={16} />} disabled={isHealthChecking}>
{t('button.manage')}
</Button>
<Button type="default" onClick={onAddModel} icon={<Plus size={16} />} disabled={isHealthChecking}>
{t('button.add')}
</Button>
</Flex>
) : (
<Flex gap={12} vertical>
{Object.keys(displayedModelGroups).map((group, i) => (
<ModelListGroup
key={group}
groupName={group}
models={displayedModelGroups[group]}
modelStatuses={modelStatuses}
defaultOpen={i <= 5}
disabled={isHealthChecking}
onEditModel={onEditModel}
onRemoveModel={removeModel}
onRemoveGroup={() => displayedModelGroups[group].forEach((model) => removeModel(model))}
/>
))}
{docsWebsite || modelsWebsite ? (
<SettingHelpTextRow>
<SettingHelpText>{t('settings.provider.docs_check')} </SettingHelpText>
{docsWebsite && (
<SettingHelpLink target="_blank" href={docsWebsite}>
{getProviderLabel(provider.id) + ' '}
{t('common.docs')}
</SettingHelpLink>
)}
{docsWebsite && modelsWebsite && <SettingHelpText>{t('common.and')}</SettingHelpText>}
{modelsWebsite && (
<SettingHelpLink target="_blank" href={modelsWebsite}>
{t('common.models')}
</SettingHelpLink>
)}
<SettingHelpText>{t('settings.provider.docs_more_details')}</SettingHelpText>
</SettingHelpTextRow>
) : (
<div style={{ height: 5 }} />
)}
</Flex>
)}
<Flex gap={10} style={{ marginTop: 10 }}>
<Button type="primary" onClick={onManageModel} icon={<ListCheck size={16} />} disabled={isHealthChecking}>
{t('button.manage')}
</Button>
<Button type="default" onClick={onAddModel} icon={<Plus size={16} />} disabled={isHealthChecking}>
{t('button.add')}
</Button>
</Flex>
</>
)

View File

@ -1,10 +1,10 @@
import { MinusOutlined } from '@ant-design/icons'
import CustomCollapse from '@renderer/components/CustomCollapse'
import { DynamicVirtualList, type DynamicVirtualListRef } from '@renderer/components/VirtualList'
import { Model } from '@renderer/types'
import { ModelWithStatus } from '@renderer/types/healthCheck'
import { useVirtualizer } from '@tanstack/react-virtual'
import { Button, Flex, Tooltip } from 'antd'
import React, { memo, useEffect, useRef, useState } from 'react'
import React, { memo, useCallback, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -32,29 +32,15 @@ const ModelListGroup: React.FC<ModelListGroupProps> = ({
onRemoveGroup
}) => {
const { t } = useTranslation()
const scrollerRef = useRef<HTMLDivElement>(null)
const [isExpanded, setIsExpanded] = useState(defaultOpen)
const listRef = useRef<DynamicVirtualListRef>(null)
const virtualizer = useVirtualizer({
count: models.length,
getScrollElement: () => scrollerRef.current,
estimateSize: () => 52,
overscan: 5
})
const virtualItems = virtualizer.getVirtualItems()
// 监听折叠面板状态变化,确保虚拟列表在展开时正确渲染
useEffect(() => {
if (isExpanded && scrollerRef.current) {
requestAnimationFrame(() => virtualizer.measure())
}
}, [isExpanded, virtualizer])
const handleCollapseChange = (activeKeys: string[] | string) => {
const handleCollapseChange = useCallback((activeKeys: string[] | string) => {
const isNowExpanded = Array.isArray(activeKeys) ? activeKeys.length > 0 : !!activeKeys
setIsExpanded(isNowExpanded)
}
if (isNowExpanded) {
// 延迟到 DOM 可见后测量
requestAnimationFrame(() => listRef.current?.measure())
}
}, [])
return (
<CustomCollapseWrapper>
@ -80,45 +66,28 @@ const ModelListGroup: React.FC<ModelListGroupProps> = ({
/>
</Tooltip>
}>
<ScrollContainer ref={scrollerRef}>
<div
style={{
height: `${virtualizer.getTotalSize()}px`,
width: '100%',
position: 'relative'
}}>
<div
style={{
position: 'absolute',
top: 0,
left: 0,
width: '100%',
transform: `translateY(${virtualItems[0]?.start ?? 0}px)`
}}>
{virtualItems.map((virtualItem) => {
const model = models[virtualItem.index]
return (
<div
key={virtualItem.key}
data-index={virtualItem.index}
ref={virtualizer.measureElement}
style={{
/* 在这里调整 item 间距 */
padding: '4px 0'
}}>
<ModelListItem
model={model}
modelStatus={modelStatuses.find((status) => status.model.id === model.id)}
onEdit={onEditModel}
onRemove={onRemoveModel}
disabled={disabled}
/>
</div>
)
})}
</div>
</div>
</ScrollContainer>
<DynamicVirtualList
ref={listRef}
list={models}
estimateSize={useCallback(() => 52, [])} // 44px item + 8px padding
overscan={5}
scrollerStyle={{
maxHeight: '390px',
padding: '4px 16px'
}}
itemContainerStyle={{
padding: '4px 0'
}}>
{(model) => (
<ModelListItem
model={model}
modelStatus={modelStatuses.find((status) => status.model.id === model.id)}
onEdit={onEditModel}
onRemove={onRemoveModel}
disabled={disabled}
/>
)}
</DynamicVirtualList>
</CustomCollapse>
</CustomCollapseWrapper>
)
@ -141,10 +110,4 @@ const CustomCollapseWrapper = styled.div`
}
`
const ScrollContainer = styled.div`
overflow-y: auto;
max-height: 390px;
padding: 4px 16px;
`
export default memo(ModelListGroup)

View File

@ -0,0 +1,372 @@
import { act, render, screen } from '@testing-library/react'
import React, { useRef } from 'react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { DynamicVirtualList, type DynamicVirtualListRef } from '..'
// Mock management
const mocks = vi.hoisted(() => ({
virtualizer: {
getVirtualItems: vi.fn(() => [
{ index: 0, key: 'item-0', start: 0, size: 50 },
{ index: 1, key: 'item-1', start: 50, size: 50 },
{ index: 2, key: 'item-2', start: 100, size: 50 }
]),
getTotalSize: vi.fn(() => 150),
getVirtualIndexes: vi.fn(() => [0, 1, 2]),
measure: vi.fn(),
scrollToOffset: vi.fn(),
scrollToIndex: vi.fn(),
resizeItem: vi.fn(),
measureElement: vi.fn(),
scrollElement: null as HTMLDivElement | null
},
useVirtualizer: vi.fn()
}))
// Set up the mock to return our mock virtualizer
mocks.useVirtualizer.mockImplementation(() => mocks.virtualizer)
vi.mock('@tanstack/react-virtual', () => ({
useVirtualizer: mocks.useVirtualizer,
defaultRangeExtractor: vi.fn((range) =>
Array.from({ length: range.endIndex - range.startIndex + 1 }, (_, i) => range.startIndex + i)
)
}))
// Test data factory
interface TestItem {
id: string
content: string
}
function createTestItems(count = 5): TestItem[] {
return Array.from({ length: count }, (_, i) => ({
id: `${i + 1}`,
content: `Item ${i + 1}`
}))
}
describe('DynamicVirtualList', () => {
const defaultItems = createTestItems()
const defaultProps = {
list: defaultItems,
estimateSize: () => 50,
children: (item: TestItem, index: number) => <div data-testid={`item-${index}`}>{item.content}</div>
}
// Test component for ref testing
const TestComponentWithRef: React.FC<{
onRefReady?: (ref: DynamicVirtualListRef | null) => void
listProps?: any
}> = ({ onRefReady, listProps = {} }) => {
const ref = useRef<DynamicVirtualListRef>(null)
React.useEffect(() => {
onRefReady?.(ref.current)
}, [onRefReady])
return <DynamicVirtualList ref={ref} {...defaultProps} {...listProps} />
}
beforeEach(() => {
vi.clearAllMocks()
})
afterEach(() => {
vi.clearAllMocks()
})
describe('basic rendering', () => {
it('snapshot test', () => {
const { container } = render(<DynamicVirtualList {...defaultProps} />)
expect(container).toMatchSnapshot()
})
it('should apply custom scroller styles', () => {
const customStyle = { backgroundColor: 'red', height: '400px' }
render(<DynamicVirtualList {...defaultProps} scrollerStyle={customStyle} />)
const scrollContainer = document.querySelector('.dynamic-virtual-list')
expect(scrollContainer).toBeInTheDocument()
expect(scrollContainer).toHaveStyle('background-color: rgb(255, 0, 0)')
expect(scrollContainer).toHaveStyle('height: 400px')
})
it('should apply custom item container styles', () => {
const itemStyle = { padding: '10px', margin: '5px' }
render(<DynamicVirtualList {...defaultProps} itemContainerStyle={itemStyle} />)
const items = document.querySelectorAll('[data-index]')
expect(items.length).toBeGreaterThan(0)
// Check first item styles
const firstItem = items[0] as HTMLElement
expect(firstItem).toHaveStyle('padding: 10px')
expect(firstItem).toHaveStyle('margin: 5px')
})
})
describe('props integration', () => {
it('should render correctly with different item counts', () => {
const { rerender } = render(<DynamicVirtualList {...defaultProps} list={createTestItems(3)} />)
// Should render without errors
expect(screen.getByTestId('item-0')).toBeInTheDocument()
// Should handle dynamic item count changes
rerender(<DynamicVirtualList {...defaultProps} list={createTestItems(10)} />)
expect(document.querySelector('.dynamic-virtual-list')).toBeInTheDocument()
})
it('should work with custom estimateSize function', () => {
const customEstimateSize = vi.fn(() => 80)
// Should render without errors when using custom estimateSize
expect(() => {
render(<DynamicVirtualList {...defaultProps} estimateSize={customEstimateSize} />)
}).not.toThrow()
expect(screen.getByTestId('item-0')).toBeInTheDocument()
})
})
describe('sticky feature', () => {
it('should apply sticky positioning to specified items', () => {
const isSticky = vi.fn((index: number) => index === 0) // First item is sticky
render(<DynamicVirtualList {...defaultProps} isSticky={isSticky} />)
// Should call isSticky function during rendering
expect(isSticky).toHaveBeenCalled()
// Should apply sticky styles to sticky items
const stickyItem = document.querySelector('[data-index="0"]') as HTMLElement
expect(stickyItem).toBeInTheDocument()
expect(stickyItem).toHaveStyle('position: sticky')
expect(stickyItem).toHaveStyle('z-index: 1')
})
it('should apply absolute positioning to non-sticky items', () => {
const isSticky = vi.fn((index: number) => index === 0)
render(<DynamicVirtualList {...defaultProps} isSticky={isSticky} />)
// Non-sticky items should have absolute positioning
const regularItem = document.querySelector('[data-index="1"]') as HTMLElement
expect(regularItem).toBeInTheDocument()
expect(regularItem).toHaveStyle('position: absolute')
})
it('should apply absolute positioning to all items when no sticky function provided', () => {
render(<DynamicVirtualList {...defaultProps} />)
// All items should have absolute positioning
const items = document.querySelectorAll('[data-index]')
items.forEach((item) => {
const htmlItem = item as HTMLElement
expect(htmlItem).toHaveStyle('position: absolute')
})
})
})
describe('custom range extractor', () => {
it('should work with custom rangeExtractor', () => {
const customRangeExtractor = vi.fn(() => [0, 1, 2])
// Should render without errors when using custom rangeExtractor
expect(() => {
render(<DynamicVirtualList {...defaultProps} rangeExtractor={customRangeExtractor} />)
}).not.toThrow()
expect(screen.getByTestId('item-0')).toBeInTheDocument()
})
it('should handle both rangeExtractor and sticky props gracefully', () => {
const customRangeExtractor = vi.fn(() => [0, 1, 2])
const isSticky = vi.fn((index: number) => index === 0)
// Should render without conflicts when both props are provided
expect(() => {
render(<DynamicVirtualList {...defaultProps} rangeExtractor={customRangeExtractor} isSticky={isSticky} />)
}).not.toThrow()
expect(screen.getByTestId('item-0')).toBeInTheDocument()
})
})
describe('ref api', () => {
let refInstance: DynamicVirtualListRef | null = null
beforeEach(async () => {
render(
<TestComponentWithRef
onRefReady={(ref) => {
refInstance = ref
}}
/>
)
// Wait for ref to be ready
await new Promise((resolve) => setTimeout(resolve, 0))
})
it('should expose all required ref methods', () => {
expect(refInstance).toBeTruthy()
expect(refInstance).not.toBeNull()
// Type assertion to help TypeScript understand the type
const ref = refInstance as unknown as DynamicVirtualListRef
expect(typeof ref.measure).toBe('function')
expect(typeof ref.scrollElement).toBe('function')
expect(typeof ref.scrollToOffset).toBe('function')
expect(typeof ref.scrollToIndex).toBe('function')
expect(typeof ref.resizeItem).toBe('function')
expect(typeof ref.getTotalSize).toBe('function')
expect(typeof ref.getVirtualItems).toBe('function')
expect(typeof ref.getVirtualIndexes).toBe('function')
})
it('should allow calling all ref methods without throwing', () => {
const ref = refInstance as unknown as DynamicVirtualListRef
// Test that all methods can be called without errors
expect(() => ref.measure()).not.toThrow()
expect(() => ref.scrollToOffset(100, { align: 'start' })).not.toThrow()
expect(() => ref.scrollToIndex(2, { align: 'center' })).not.toThrow()
expect(() => ref.resizeItem(1, 80)).not.toThrow()
// Test that data methods return expected types
expect(typeof ref.getTotalSize()).toBe('number')
expect(Array.isArray(ref.getVirtualItems())).toBe(true)
expect(Array.isArray(ref.getVirtualIndexes())).toBe(true)
})
})
describe('orientation support', () => {
beforeEach(() => {
// Reset mocks for orientation tests
mocks.virtualizer.getVirtualItems.mockReturnValue([
{ index: 0, key: 'item-0', start: 0, size: 100 },
{ index: 1, key: 'item-1', start: 100, size: 100 }
])
mocks.virtualizer.getTotalSize.mockReturnValue(200)
})
it('should apply horizontal layout styles correctly', () => {
render(<DynamicVirtualList {...defaultProps} horizontal={true} />)
// Verify container styles for horizontal layout
const container = document.querySelector('div[style*="position: relative"]') as HTMLElement
expect(container).toHaveStyle('width: 200px') // totalSize
expect(container).toHaveStyle('height: 100%')
// Verify item transform for horizontal layout
const items = document.querySelectorAll('[data-index]')
const firstItem = items[0] as HTMLElement
expect(firstItem.style.transform).toContain('translateX(0px)')
expect(firstItem).toHaveStyle('height: 100%')
})
it('should apply vertical layout styles correctly', () => {
// Reset to default vertical mock values
mocks.virtualizer.getTotalSize.mockReturnValue(150)
render(<DynamicVirtualList {...defaultProps} horizontal={false} />)
// Verify container styles for vertical layout
const container = document.querySelector('div[style*="position: relative"]') as HTMLElement
expect(container).toHaveStyle('width: 100%')
expect(container).toHaveStyle('height: 150px') // totalSize from mock
// Verify item transform for vertical layout
const items = document.querySelectorAll('[data-index]')
const firstItem = items[0] as HTMLElement
expect(firstItem.style.transform).toContain('translateY(0px)')
expect(firstItem).toHaveStyle('width: 100%')
})
})
describe('edge cases', () => {
it('should handle edge cases gracefully', () => {
// Empty items list
mocks.virtualizer.getVirtualItems.mockReturnValueOnce([])
expect(() => {
render(<DynamicVirtualList {...defaultProps} list={[]} />)
}).not.toThrow()
// Null ref
expect(() => {
render(<DynamicVirtualList {...defaultProps} ref={null} />)
}).not.toThrow()
// Zero estimate size
expect(() => {
render(<DynamicVirtualList {...defaultProps} estimateSize={() => 0} />)
}).not.toThrow()
// Items without expected properties
const itemsWithoutContent = [{ id: '1' }, { id: '2' }] as any[]
expect(() => {
render(
<DynamicVirtualList
{...defaultProps}
list={itemsWithoutContent}
children={(_item, index) => <div data-testid={`item-${index}`}>No content</div>}
/>
)
}).not.toThrow()
})
})
describe('auto hide scrollbar', () => {
it('should always show scrollbar when autoHideScrollbar is false', () => {
render(<DynamicVirtualList {...defaultProps} autoHideScrollbar={false} />)
const scrollContainer = document.querySelector('.dynamic-virtual-list') as HTMLElement
expect(scrollContainer).toBeInTheDocument()
// When autoHideScrollbar is false, scrollbar should always be visible
expect(scrollContainer).not.toHaveAttribute('aria-hidden', 'true')
})
it('should hide scrollbar initially and show during scrolling when autoHideScrollbar is true', async () => {
vi.useFakeTimers()
render(<DynamicVirtualList {...defaultProps} autoHideScrollbar={true} />)
const scrollContainer = document.querySelector('.dynamic-virtual-list') as HTMLElement
expect(scrollContainer).toBeInTheDocument()
// Initially hidden
expect(scrollContainer).toHaveAttribute('aria-hidden', 'true')
// We can't easily simulate real scroll events in JSDOM, so we'll test the internal logic directly
// by calling the onChange handler which should update the state
const onChangeCallback = mocks.useVirtualizer.mock.calls[0][0].onChange
// Simulate scroll start
act(() => {
onChangeCallback({ isScrolling: true }, true)
})
// After scrolling starts, scrollbar should be visible
expect(scrollContainer).toHaveAttribute('aria-hidden', 'false')
// Simulate scroll end
act(() => {
onChangeCallback({ isScrolling: false }, true)
})
// Advance timers to trigger the hide timeout
act(() => {
vi.advanceTimersByTime(10000)
})
// After timeout, scrollbar should be hidden again
expect(scrollContainer).toHaveAttribute('aria-hidden', 'true')
vi.useRealTimers()
})
})
})

View File

@ -0,0 +1,58 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
exports[`DynamicVirtualList > basic rendering > snapshot test 1`] = `
.c0::-webkit-scrollbar-thumb {
transition: background 0.3s ease-in-out;
will-change: background;
background: var(--color-scrollbar-thumb);
}
.c0::-webkit-scrollbar-thumb:hover {
background: var(--color-scrollbar-thumb-hover);
}
<div>
<div
aria-hidden="false"
aria-label="Dynamic Virtual List"
class="c0 dynamic-virtual-list"
role="region"
style="overflow: auto; height: 100%;"
>
<div
style="position: relative; width: 100%; height: 150px;"
>
<div
data-index="0"
style="position: absolute; top: 0px; left: 0px; transform: translateY(0px); width: 100%;"
>
<div
data-testid="item-0"
>
Item 1
</div>
</div>
<div
data-index="1"
style="position: absolute; top: 0px; left: 0px; transform: translateY(50px); width: 100%;"
>
<div
data-testid="item-1"
>
Item 2
</div>
</div>
<div
data-index="2"
style="position: absolute; top: 0px; left: 0px; transform: translateY(100px); width: 100%;"
>
<div
data-testid="item-2"
>
Item 3
</div>
</div>
</div>
</div>
</div>
`;

View File

@ -0,0 +1,257 @@
import type { Range, ScrollToOptions, VirtualItem, VirtualizerOptions } from '@tanstack/react-virtual'
import { defaultRangeExtractor, useVirtualizer } from '@tanstack/react-virtual'
import React, { memo, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'
import styled from 'styled-components'
const SCROLLBAR_AUTO_HIDE_DELAY = 2000
type InheritedVirtualizerOptions = Partial<
Omit<
VirtualizerOptions<HTMLDivElement, Element>,
| 'count' // determined by items.length
| 'getScrollElement' // determined by internal scrollerRef
| 'estimateSize' // promoted to a required prop
| 'rangeExtractor' // isSticky provides a simpler abstraction
>
>
export interface DynamicVirtualListRef {
/** Resets any prev item measurements. */
measure: () => void
/** Returns the scroll element for the virtualizer. */
scrollElement: () => HTMLDivElement | null
/** Scrolls the virtualizer to the pixel offset provided. */
scrollToOffset: (offset: number, options?: ScrollToOptions) => void
/** Scrolls the virtualizer to the items of the index provided. */
scrollToIndex: (index: number, options?: ScrollToOptions) => void
/** Resizes an item. */
resizeItem: (index: number, size: number) => void
/** Returns the total size in pixels for the virtualized items. */
getTotalSize: () => number
/** Returns the virtual items for the current state of the virtualizer. */
getVirtualItems: () => VirtualItem[]
/** Returns the virtual row indexes for the current state of the virtualizer. */
getVirtualIndexes: () => number[]
}
export interface DynamicVirtualListProps<T> extends InheritedVirtualizerOptions {
ref?: React.Ref<DynamicVirtualListRef>
/**
* List data
*/
list: T[]
/**
* List item renderer function
*/
children: (item: T, index: number) => React.ReactNode
/**
* List size (height or width, default is 100%)
*/
size?: string | number
/**
* List item size estimator function (initial estimation)
*/
estimateSize: (index: number) => number
/**
* Sticky item predicate, cannot be used with rangeExtractor
*/
isSticky?: (index: number) => boolean
/**
* Range extractor function, cannot be used with isSticky
*/
rangeExtractor?: (range: Range) => number[]
/**
* List item container style
*/
itemContainerStyle?: React.CSSProperties
/**
* Scroll container style
*/
scrollerStyle?: React.CSSProperties
/**
* Hide the scrollbar automatically when scrolling is stopped
*/
autoHideScrollbar?: boolean
}
function DynamicVirtualList<T>(props: DynamicVirtualListProps<T>) {
const {
ref,
list,
children,
size,
estimateSize,
isSticky,
rangeExtractor: customRangeExtractor,
itemContainerStyle,
scrollerStyle,
autoHideScrollbar = false,
...restOptions
} = props
const [showScrollbar, setShowScrollbar] = useState(!autoHideScrollbar)
const timeoutRef = useRef<NodeJS.Timeout | null>(null)
const internalScrollerRef = useRef<HTMLDivElement>(null)
const scrollerRef = internalScrollerRef
const activeStickyIndexRef = useRef(0)
const stickyIndexes = useMemo(() => {
if (!isSticky) return []
return list.map((_, index) => (isSticky(index) ? index : -1)).filter((index) => index !== -1)
}, [list, isSticky])
const internalStickyRangeExtractor = useCallback(
(range: Range) => {
// The active sticky index is the last one that is before or at the start of the visible range
const newActiveStickyIndex =
[...stickyIndexes].reverse().find((index) => range.startIndex >= index) ?? stickyIndexes[0] ?? 0
if (newActiveStickyIndex !== activeStickyIndexRef.current) {
activeStickyIndexRef.current = newActiveStickyIndex
}
// Merge the active sticky index and the default range extractor
const next = new Set([activeStickyIndexRef.current, ...defaultRangeExtractor(range)])
// Sort the set to maintain proper order
return [...next].sort((a, b) => a - b)
},
[stickyIndexes]
)
const rangeExtractor = customRangeExtractor ?? (isSticky ? internalStickyRangeExtractor : undefined)
const handleScrollbarHide = useCallback(
(isScrolling: boolean) => {
if (!autoHideScrollbar) return
if (timeoutRef.current) clearTimeout(timeoutRef.current)
if (isScrolling) {
setShowScrollbar(true)
} else {
timeoutRef.current = setTimeout(() => {
setShowScrollbar(false)
}, SCROLLBAR_AUTO_HIDE_DELAY)
}
},
[autoHideScrollbar]
)
const virtualizer = useVirtualizer({
...restOptions,
count: list.length,
getScrollElement: () => scrollerRef.current,
estimateSize,
rangeExtractor,
onChange: (instance, sync) => {
restOptions.onChange?.(instance, sync)
handleScrollbarHide(instance.isScrolling)
}
})
useEffect(() => {
return () => {
if (timeoutRef.current) {
clearTimeout(timeoutRef.current)
}
}
}, [autoHideScrollbar])
useImperativeHandle(
ref,
() => ({
measure: () => virtualizer.measure(),
scrollElement: () => virtualizer.scrollElement,
scrollToOffset: (offset, options) => virtualizer.scrollToOffset(offset, options),
scrollToIndex: (index, options) => virtualizer.scrollToIndex(index, options),
resizeItem: (index, size) => virtualizer.resizeItem(index, size),
getTotalSize: () => virtualizer.getTotalSize(),
getVirtualItems: () => virtualizer.getVirtualItems(),
getVirtualIndexes: () => virtualizer.getVirtualIndexes()
}),
[virtualizer]
)
const virtualItems = virtualizer.getVirtualItems()
const totalSize = virtualizer.getTotalSize()
const { horizontal } = restOptions
return (
<ScrollContainer
ref={scrollerRef}
className="dynamic-virtual-list"
role="region"
aria-label="Dynamic Virtual List"
aria-hidden={!showScrollbar}
$autoHide={autoHideScrollbar}
$show={showScrollbar}
style={{
overflow: 'auto',
...(horizontal ? { width: size ?? '100%' } : { height: size ?? '100%' }),
...scrollerStyle
}}>
<div
style={{
position: 'relative',
width: horizontal ? `${totalSize}px` : '100%',
height: !horizontal ? `${totalSize}px` : '100%'
}}>
{virtualItems.map((virtualItem) => {
const isItemSticky = stickyIndexes.includes(virtualItem.index)
const isItemActiveSticky = isItemSticky && activeStickyIndexRef.current === virtualItem.index
const style: React.CSSProperties = {
...itemContainerStyle,
position: isItemActiveSticky ? 'sticky' : 'absolute',
top: 0,
left: 0,
zIndex: isItemSticky ? 1 : undefined,
...(horizontal
? {
transform: isItemActiveSticky ? undefined : `translateX(${virtualItem.start}px)`,
height: '100%'
}
: {
transform: isItemActiveSticky ? undefined : `translateY(${virtualItem.start}px)`,
width: '100%'
})
}
return (
<div key={virtualItem.key} data-index={virtualItem.index} ref={virtualizer.measureElement} style={style}>
{children(list[virtualItem.index], virtualItem.index)}
</div>
)
})}
</div>
</ScrollContainer>
)
}
const ScrollContainer = styled.div<{ $autoHide: boolean; $show: boolean }>`
&::-webkit-scrollbar-thumb {
transition: background 0.3s ease-in-out;
will-change: background;
background: ${(props) => (props.$autoHide && !props.$show ? 'transparent' : 'var(--color-scrollbar-thumb)')};
&:hover {
background: var(--color-scrollbar-thumb-hover);
}
}
`
const MemoizedDynamicVirtualList = memo(DynamicVirtualList) as <T>(
props: DynamicVirtualListProps<T>
) => React.ReactElement
export default MemoizedDynamicVirtualList

View File

@ -0,0 +1 @@
export { default as DynamicVirtualList, type DynamicVirtualListProps, type DynamicVirtualListRef } from './dynamic'

View File

@ -4,7 +4,7 @@ exports[`DraggableVirtualList > snapshot > should match snapshot with custom sty
<div>
<div
class="custom-class draggable-virtual-list"
style="height: 100%; border: 1px solid red;"
style="height: 100%; display: flex; flex-direction: column; border: 1px solid red;"
>
<div
data-testid="drag-drop-context"

View File

@ -2663,6 +2663,17 @@ export function isQwenReasoningModel(model?: Model): boolean {
return false
}
const baseName = getLowerBaseModelName(model.id, '/')
if (baseName.startsWith('qwen3')) {
if (baseName.includes('thinking')) {
return true
} else if (baseName.includes('instruct')) {
return false
}
return true
}
if (isSupportedThinkingTokenQwenModel(model)) {
return true
}
@ -2681,27 +2692,34 @@ export function isSupportedThinkingTokenQwenModel(model?: Model): boolean {
const baseName = getLowerBaseModelName(model.id, '/')
if (baseName.includes('coder') || baseName.includes('qwen3-235b-a22b-instruct')) {
if (baseName.includes('coder')) {
return false
}
return (
baseName.startsWith('qwen3') ||
[
'qwen-plus',
'qwen-plus-latest',
'qwen-plus-0428',
'qwen-plus-2025-04-28',
'qwen-plus-0714',
'qwen-plus-2025-07-14',
'qwen-turbo',
'qwen-turbo-latest',
'qwen-turbo-0428',
'qwen-turbo-2025-04-28',
'qwen-turbo-0715',
'qwen-turbo-2025-07-15'
].includes(baseName)
)
if (baseName.startsWith('qwen3')) {
if (baseName.includes('instruct')) {
return false
}
if (baseName.includes('thinking')) {
return true
}
return true
}
return [
'qwen-plus',
'qwen-plus-latest',
'qwen-plus-0428',
'qwen-plus-2025-04-28',
'qwen-plus-0714',
'qwen-plus-2025-07-14',
'qwen-turbo',
'qwen-turbo-latest',
'qwen-turbo-0428',
'qwen-turbo-2025-04-28',
'qwen-turbo-0715',
'qwen-turbo-2025-07-15'
].includes(baseName)
}
export function isQwen3235BA22BThinkingModel(model?: Model): boolean {
@ -3070,11 +3088,13 @@ export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> =
'gemini-.*-pro.*$': { min: 128, max: 32768 },
// Qwen models
'qwen3-235b-a22b-thinking(?:-[\\w-]+)$': { min: 0, max: 81_920 },
'qwen3-235b-a22b-thinking-2507$': { min: 0, max: 81_920 },
'qwen3-30b-a3b-thinking-2507$': { min: 0, max: 81_920 },
'qwen-plus-2025-07-28$': { min: 0, max: 81_920 },
'qwen3-1\\.7b$': { min: 0, max: 30_720 },
'qwen3-0\\.6b$': { min: 0, max: 30_720 },
'qwen-plus-.*$': { min: 0, max: 38912 },
'qwen-turbo-.*$': { min: 0, max: 38912 },
'qwen3-0\\.6b$': { min: 0, max: 30720 },
'qwen3-1\\.7b$': { min: 0, max: 30720 },
'qwen3-.*$': { min: 1024, max: 38912 },
// Claude models

View File

@ -5,7 +5,7 @@ import Ai302ProviderLogo from '@renderer/assets/images/providers/302ai.webp'
import AiHubMixProviderLogo from '@renderer/assets/images/providers/aihubmix.webp'
import AlayaNewProviderLogo from '@renderer/assets/images/providers/alayanew.webp'
import AnthropicProviderLogo from '@renderer/assets/images/providers/anthropic.png'
import AwsProviderLogo from '@renderer/assets/images/providers/aws-bedrock.png'
import AwsProviderLogo from '@renderer/assets/images/providers/aws-bedrock.webp'
import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png'
import BaiduCloudProviderLogo from '@renderer/assets/images/providers/baidu-cloud.svg'
import BailianProviderLogo from '@renderer/assets/images/providers/bailian.png'
@ -320,7 +320,7 @@ export const PROVIDER_CONFIG = {
websites: {
official: 'https://open.bigmodel.cn/',
apiKey: 'https://open.bigmodel.cn/usercenter/apikeys',
docs: 'https://open.bigmodel.cn/dev/howuse/introduction',
docs: 'https://docs.bigmodel.cn/',
models: 'https://open.bigmodel.cn/modelcenter/square'
}
},

View File

@ -6,7 +6,9 @@ import { Dexie, type EntityTable } from 'dexie'
import { upgradeToV5, upgradeToV7, upgradeToV8 } from './upgrades'
// Database declaration (move this to its own module also)
export const db = new Dexie('CherryStudio') as Dexie & {
export const db = new Dexie('CherryStudio', {
chromeTransactionDurability: 'strict'
}) as Dexie & {
files: EntityTable<FileMetadata, 'id'>
topics: EntityTable<{ id: string; messages: NewMessage[] }, 'id'> // Correct type for topics
settings: EntityTable<{ id: string; value: any }, 'id'>

View File

@ -8,17 +8,19 @@ import KnowledgeQueue from '@renderer/queue/KnowledgeQueue'
import MemoryService from '@renderer/services/MemoryService'
import { useAppDispatch } from '@renderer/store'
import { useAppSelector } from '@renderer/store'
import { handleSaveData } from '@renderer/store'
import { selectMemoryConfig } from '@renderer/store/memory'
import { setAvatar, setFilesPath, setResourcesPath, setUpdateState } from '@renderer/store/runtime'
import { delay, runAsyncFunction } from '@renderer/utils'
import { defaultLanguage } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import { useLiveQuery } from 'dexie-react-hooks'
import { useEffect } from 'react'
import { useDefaultModel } from './useAssistant'
import useFullScreenNotice from './useFullScreenNotice'
import { useRuntime } from './useRuntime'
import { useNavbarPosition, useSettings } from './useSettings'
import { useSettings } from './useSettings'
import useUpdateHandler from './useUpdateHandler'
const logger = loggerService.withContext('useAppInit')
@ -31,7 +33,6 @@ export function useAppInit() {
const avatar = useLiveQuery(() => db.settings.get('image://avatar'))
const { theme } = useTheme()
const memoryConfig = useAppSelector(selectMemoryConfig)
const { isTopNavbar } = useNavbarPosition()
useEffect(() => {
document.getElementById('spinner')?.remove()
@ -50,6 +51,12 @@ export function useAppInit() {
})
}, [])
useEffect(() => {
window.electron.ipcRenderer.on(IpcChannel.App_SaveData, async () => {
await handleSaveData()
})
}, [])
useUpdateHandler()
useFullScreenNotice()
@ -86,17 +93,13 @@ export function useAppInit() {
const transparentWindow = windowStyle === 'transparent' && isMac && !minappShow
if (minappShow) {
if (isTopNavbar) {
window.root.style.background = 'var(--navbar-background)'
} else {
window.root.style.background =
windowStyle === 'transparent' && isMac ? 'var(--color-background)' : 'var(--navbar-background)'
}
window.root.style.background =
windowStyle === 'transparent' && isMac ? 'var(--color-background)' : 'var(--navbar-background)'
return
}
window.root.style.background = transparentWindow ? 'var(--navbar-background-mac)' : 'var(--navbar-background)'
}, [windowStyle, minappShow, theme, isTopNavbar])
}, [windowStyle, minappShow, theme])
useEffect(() => {
if (isLocalAi) {

View File

@ -1,6 +1,4 @@
import { isMac } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import { useOcrProviders } from '@renderer/hooks/useOcr'
import { usePreprocessProviders } from '@renderer/hooks/usePreprocess'
import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
@ -42,11 +40,10 @@ export const useKnowledgeBaseForm = (base?: KnowledgeBase) => {
const [newBase, setNewBase] = useState<KnowledgeBase>(base || createInitialKnowledgeBase())
const { providers } = useProviders()
const { preprocessProviders } = usePreprocessProviders()
const { ocrProviders } = useOcrProviders()
const selectedDocPreprocessProvider = useMemo(
() => newBase.preprocessOrOcrProvider?.provider,
[newBase.preprocessOrOcrProvider]
() => newBase.preprocessProvider?.provider,
[newBase.preprocessProvider]
)
const docPreprocessSelectOptions = useMemo(() => {
@ -57,14 +54,8 @@ export const useKnowledgeBaseForm = (base?: KnowledgeBase) => {
.filter((p) => p.apiKey !== '' || p.id === 'mineru')
.map((p) => ({ value: p.id, label: p.name }))
}
const ocrOptions = {
label: t('settings.tool.ocr.provider'),
title: t('settings.tool.ocr.provider'),
options: ocrProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name }))
}
return isMac ? [preprocessOptions, ocrOptions] : [preprocessOptions]
}, [ocrProviders, preprocessProviders, t])
return [preprocessOptions]
}, [preprocessProviders, t])
const handleEmbeddingModelChange = useCallback(
(value: string) => {
@ -92,21 +83,20 @@ export const useKnowledgeBaseForm = (base?: KnowledgeBase) => {
const handleDocPreprocessChange = useCallback(
(value: string) => {
const type = preprocessProviders.find((p) => p.id === value) ? 'preprocess' : 'ocr'
const provider = (type === 'preprocess' ? preprocessProviders : ocrProviders).find((p) => p.id === value)
const provider = preprocessProviders.find((p) => p.id === value)
if (!provider) {
setNewBase((prev) => ({ ...prev, preprocessOrOcrProvider: undefined }))
setNewBase((prev) => ({ ...prev, preprocessProvider: undefined }))
return
}
setNewBase((prev) => ({
...prev,
preprocessOrOcrProvider: {
type,
preprocessProvider: {
type: 'preprocess',
provider
}
}))
},
[preprocessProviders, ocrProviders]
[preprocessProviders]
)
const handleChunkSizeChange = useCallback(
@ -152,7 +142,6 @@ export const useKnowledgeBaseForm = (base?: KnowledgeBase) => {
const providerData = {
providers,
preprocessProviders,
ocrProviders,
selectedDocPreprocessProvider,
docPreprocessSelectOptions
}

View File

@ -1,45 +0,0 @@
import { RootState } from '@renderer/store'
import {
setDefaultOcrProvider as _setDefaultOcrProvider,
updateOcrProvider as _updateOcrProvider,
updateOcrProviders as _updateOcrProviders
} from '@renderer/store/ocr'
import { OcrProvider } from '@renderer/types'
import { useDispatch, useSelector } from 'react-redux'
export const useOcrProvider = (id: string) => {
const dispatch = useDispatch()
const ocrProviders = useSelector((state: RootState) => state.ocr.providers)
const provider = ocrProviders.find((provider) => provider.id === id)
if (!provider) {
throw new Error(`OCR provider with id ${id} not found`)
}
const updateOcrProvider = (ocrProvider: OcrProvider) => {
dispatch(_updateOcrProvider(ocrProvider))
}
return { provider, updateOcrProvider }
}
export const useOcrProviders = () => {
const dispatch = useDispatch()
const ocrProviders = useSelector((state: RootState) => state.ocr.providers)
return {
ocrProviders: ocrProviders,
updateOcrProviders: (ocrProviders: OcrProvider[]) => dispatch(_updateOcrProviders(ocrProviders))
}
}
export const useDefaultOcrProvider = () => {
const defaultProviderId = useSelector((state: RootState) => state.ocr.defaultProvider)
const { ocrProviders } = useOcrProviders()
const dispatch = useDispatch()
const provider = defaultProviderId ? ocrProviders.find((provider) => provider.id === defaultProviderId) : undefined
const setDefaultOcrProvider = (ocrProvider: OcrProvider) => {
dispatch(_setDefaultOcrProvider(ocrProvider.id))
}
const updateDefaultOcrProvider = (ocrProvider: OcrProvider) => {
dispatch(_updateOcrProvider(ocrProvider))
}
return { provider, setDefaultOcrProvider, updateDefaultOcrProvider }
}

View File

@ -0,0 +1,34 @@
import { loggerService } from '@logger'
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
import { useEffect, useState } from 'react'
const logger = loggerService.withContext('usePromptProcessor')
interface PromptProcessor {
prompt: string
modelName?: string
}
export function usePromptProcessor({ prompt, modelName }: PromptProcessor): string {
const [processedPrompt, setProcessedPrompt] = useState(prompt)
useEffect(() => {
const processPrompt = async () => {
try {
if (containsSupportedVariables(prompt)) {
const result = await replacePromptVariables(prompt, modelName)
setProcessedPrompt(result)
} else {
setProcessedPrompt(prompt)
}
} catch (error) {
logger.error('Failed to process prompt variables, falling back:', error as Error)
setProcessedPrompt(prompt)
}
}
processPrompt()
}, [prompt, modelName])
return processedPrompt
}

View File

@ -703,6 +703,7 @@
"no_results": "No results",
"open": "Open",
"paste": "Paste",
"preview": "Preview",
"prompt": "Prompt",
"provider": "Provider",
"reasoning_content": "Deep reasoning",
@ -711,6 +712,7 @@
"rename": "Rename",
"reset": "Reset",
"save": "Save",
"saved": "Saved",
"search": "Search",
"select": "Select",
"selectedItems": "Selected {{count}} items",
@ -886,7 +888,7 @@
"error": {
"failed_to_create": "Knowledge base creation failed",
"failed_to_edit": "Knowledge base editing failed",
"model_invalid": "No model selected or deleted"
"model_invalid": "No model selected"
},
"file_hint": "Support {{file_types}}",
"index_all": "Index All",
@ -927,7 +929,7 @@
"search_placeholder": "Enter text to search",
"settings": {
"preprocessing": "Preprocessing",
"preprocessing_tooltip": "Preprocess uploaded files with OCR",
"preprocessing_tooltip": "Preprocess uploaded files",
"title": "Knowledge Base Settings"
},
"sitemap_added": "Added successfully",
@ -3307,26 +3309,11 @@
},
"title": "Settings",
"tool": {
"ocr": {
"mac_system_ocr_options": {
"min_confidence": "Minimum Confidence",
"mode": {
"accurate": "Accurate",
"fast": "Fast",
"title": "Recognition Mode"
}
},
"provider": "OCR Provider",
"provider_placeholder": "Choose an OCR provider",
"title": "OCR Settings"
},
"preprocess": {
"provider": "Pre Process Provider",
"provider_placeholder": "Choose a Pre Process provider",
"title": "Pre Process"
},
"preprocessOrOcr": {
"tooltip": "In Settings -> Tools, set a document preprocessing service provider or OCR. Document preprocessing can effectively improve the retrieval performance of complex format documents and scanned documents. OCR can only recognize text within images in documents or scanned PDF text."
"title": "Pre Process",
"tooltip": "In Settings -> Tools, set a document preprocessing service provider. Document preprocessing can effectively improve the retrieval performance of complex format documents and scanned documents."
},
"title": "Tools Settings",
"websearch": {

View File

@ -703,6 +703,7 @@
"no_results": "検索結果なし",
"open": "開く",
"paste": "貼り付け",
"preview": "プレビュー",
"prompt": "プロンプト",
"provider": "プロバイダー",
"reasoning_content": "深く考察済み",
@ -711,6 +712,7 @@
"rename": "名前を変更",
"reset": "リセット",
"save": "保存",
"saved": "保存されました",
"search": "検索",
"select": "選択",
"selectedItems": "{{count}}件の項目を選択しました",
@ -886,7 +888,7 @@
"error": {
"failed_to_create": "ナレッジベースの作成に失敗しました",
"failed_to_edit": "ナレッジベースの編集に失敗しました",
"model_invalid": "モデルが選択されていないか、削除されています"
"model_invalid": "モデルが選択されていません"
},
"file_hint": "{{file_types}} 形式をサポート",
"index_all": "すべてをインデックス",
@ -927,7 +929,7 @@
"search_placeholder": "検索するテキストを入力",
"settings": {
"preprocessing": "預処理",
"preprocessing_tooltip": "アップロードされたファイルのOCR預処理",
"preprocessing_tooltip": "アップロードされたファイルの預処理",
"title": "ナレッジベース設定"
},
"sitemap_added": "追加成功",
@ -3307,26 +3309,11 @@
},
"title": "設定",
"tool": {
"ocr": {
"mac_system_ocr_options": {
"min_confidence": "最小信頼度",
"mode": {
"accurate": "正確",
"fast": "速い",
"title": "認識モード"
}
},
"provider": "OCRプロバイダー",
"provider_placeholder": "OCRプロバイダーを選択",
"title": "OCRオーシーアール"
},
"preprocess": {
"provider": "プレプロセスプロバイダー",
"provider_placeholder": "前処理プロバイダーを選択してください",
"title": "前処理"
},
"preprocessOrOcr": {
"tooltip": "設定 → ツールで、ドキュメント前処理サービスプロバイダーまたはOCRを設定します。ドキュメント前処理は、複雑な形式のドキュメントやスキャンされたドキュメントの検索性能を効果的に向上させます。OCRは、ドキュメント内の画像内のテキストまたはスキャンされたPDFテキストのみを認識できます。"
"title": "前処理",
"tooltip": "設定 → ツールで、ドキュメント前処理サービスプロバイダーを設定します。ドキュメント前処理は、複雑な形式のドキュメントやスキャンされたドキュメントの検索性能を効果的に向上させます。"
},
"title": "ツール設定",
"websearch": {

View File

@ -703,6 +703,7 @@
"no_results": "Результатов не найдено",
"open": "Открыть",
"paste": "Вставить",
"preview": "Предварительный просмотр",
"prompt": "Промпт",
"provider": "Провайдер",
"reasoning_content": "Глубокий анализ",
@ -711,6 +712,7 @@
"rename": "Переименовать",
"reset": "Сбросить",
"save": "Сохранить",
"saved": "Сохранено",
"search": "Поиск",
"select": "Выбрать",
"selectedItems": "Выбрано {{count}} элементов",
@ -886,7 +888,7 @@
"error": {
"failed_to_create": "Создание базы знаний завершено с ошибками",
"failed_to_edit": "Редактирование базы знаний завершено с ошибками",
"model_invalid": "Модель не выбрана или удалена"
"model_invalid": "Модель не выбрана"
},
"file_hint": "Поддерживаются {{file_types}}",
"index_all": "Индексировать все",
@ -927,7 +929,7 @@
"search_placeholder": "Введите текст для поиска",
"settings": {
"preprocessing": "Предварительная обработка",
"preprocessing_tooltip": "Предварительная обработка изображений с помощью OCR",
"preprocessing_tooltip": "Предварительная обработка документов",
"title": "Настройки базы знаний"
},
"sitemap_added": "添加成功",
@ -3307,26 +3309,11 @@
},
"title": "Настройки",
"tool": {
"ocr": {
"mac_system_ocr_options": {
"min_confidence": "Минимальная достоверность",
"mode": {
"accurate": "Точный",
"fast": "Быстро",
"title": "Режим распознавания"
}
},
"provider": "Поставщик OCR",
"provider_placeholder": "Выберите провайдера OCR",
"title": "OCR (оптическое распознавание символов)"
},
"preprocess": {
"provider": "Предварительная обработка Поставщик",
"provider_placeholder": "Выберите поставщика услуг предварительной обработки",
"title": "Предварительная обработка"
},
"preprocessOrOcr": {
"tooltip": "В настройках (Настройки -> Инструменты) укажите поставщика услуги предварительной обработки документов или OCR. Предварительная обработка документов может значительно повысить эффективность поиска для документов сложных форматов и отсканированных документов. OCR способен распознавать только текст внутри изображений в документах или текст в отсканированных PDF."
"title": "Предварительная обработка",
"tooltip": "В настройках (Настройки -> Инструменты) укажите поставщика услуги предварительной обработки документов. Предварительная обработка документов может значительно повысить эффективность поиска для документов сложных форматов и отсканированных документов."
},
"title": "Настройки инструментов",
"websearch": {

View File

@ -703,6 +703,7 @@
"no_results": "无结果",
"open": "打开",
"paste": "粘贴",
"preview": "预览",
"prompt": "提示词",
"provider": "提供商",
"reasoning_content": "已深度思考",
@ -711,6 +712,7 @@
"rename": "重命名",
"reset": "重置",
"save": "保存",
"saved": "已保存",
"search": "搜索",
"select": "选择",
"selectedItems": "已选择 {{count}} 项",
@ -886,7 +888,7 @@
"error": {
"failed_to_create": "知识库创建失败",
"failed_to_edit": "知识库编辑失败",
"model_invalid": "未选择模型或已删除"
"model_invalid": "未选择模型"
},
"file_hint": "支持 {{file_types}} 格式",
"index_all": "索引全部",
@ -3307,26 +3309,11 @@
},
"title": "设置",
"tool": {
"ocr": {
"mac_system_ocr_options": {
"min_confidence": "最低置信度",
"mode": {
"accurate": "准确",
"fast": "快速",
"title": "识别模式"
}
},
"provider": "OCR 服务商",
"provider_placeholder": "选择一个 OCR 服务商",
"title": "OCR 文字识别"
},
"preprocess": {
"provider": "文档预处理服务商",
"provider_placeholder": "选择一个文档预处理服务商",
"title": "文档预处理"
},
"preprocessOrOcr": {
"tooltip": "在设置 -> 工具中设置文档预处理服务商或OCR文档预处理可以有效提升复杂格式文档与扫描版文档的检索效果OCR仅可识别文档内图片或扫描版PDF的文本"
"title": "文档预处理",
"tooltip": "在设置 -> 工具中设置文档预处理服务商,文档预处理可以有效提升复杂格式文档与扫描版文档的检索效果"
},
"title": "工具设置",
"websearch": {

View File

@ -703,6 +703,7 @@
"no_results": "沒有結果",
"open": "開啟",
"paste": "貼上",
"preview": "預覽",
"prompt": "提示詞",
"provider": "供應商",
"reasoning_content": "已深度思考",
@ -711,6 +712,7 @@
"rename": "重新命名",
"reset": "重設",
"save": "儲存",
"saved": "已儲存",
"search": "搜尋",
"select": "選擇",
"selectedItems": "已選擇 {{count}} 項",
@ -886,7 +888,7 @@
"error": {
"failed_to_create": "知識庫創建失敗",
"failed_to_edit": "知識庫編輯失敗",
"model_invalid": "未選擇模型或已刪除"
"model_invalid": "未選擇模型"
},
"file_hint": "支援 {{file_types}} 格式",
"index_all": "索引全部",
@ -3307,26 +3309,11 @@
},
"title": "設定",
"tool": {
"ocr": {
"mac_system_ocr_options": {
"min_confidence": "最小置信度",
"mode": {
"accurate": "準確",
"fast": "快速",
"title": "識別模式"
}
},
"provider": "OCR 供應商",
"provider_placeholder": "選擇一個OCR服務提供商",
"title": "OCR 文字識別"
},
"preprocess": {
"provider": "前置處理供應商",
"provider_placeholder": "選擇一個預處理供應商",
"title": "前置處理"
},
"preprocessOrOcr": {
"tooltip": "在「設定」->「工具」中設定文件預處理服務供應商或OCR。文件預處理可有效提升複雜格式文件及掃描文件的檢索效能而OCR僅能辨識文件內圖片文字或掃描PDF文字。"
"title": "前置處理",
"tooltip": "在「設定」->「工具」中設定文件預處理服務供應商。文件預處理可有效提升複雜格式文件及掃描文件的檢索效能"
},
"title": "工具設定",
"websearch": {

View File

@ -3325,7 +3325,7 @@
"provider_placeholder": "Επιλέξτε έναν πάροχο προεπεξεργασίας εγγράφων",
"title": "Προεπεξεργασία Εγγράφων"
},
"preprocessOrOcr": {
"preprocess": {
"tooltip": "Ορίστε πάροχο προεπεξεργασίας εγγράφων ή OCR στις Ρυθμίσεις -> Εργαλεία. Η προεπεξεργασία εγγράφων μπορεί να βελτιώσει σημαντικά την απόδοση αναζήτησης για έγγραφα πολύπλοκης μορφής ή εγγράφων σε μορφή σάρωσης. Το OCR μπορεί να αναγνωρίσει μόνο κείμενο μέσα σε εικόνες εγγράφων ή σε PDF σε μορφή σάρωσης."
},
"title": "Ρυθμίσεις Εργαλείων",

View File

@ -3325,7 +3325,7 @@
"provider_placeholder": "Selecciona un proveedor de preprocesamiento de documentos",
"title": "Preprocesamiento de Documentos"
},
"preprocessOrOcr": {
"preprocess": {
"tooltip": "Configure un proveedor de preprocesamiento de documentos o OCR en Configuración -> Herramientas. El preprocesamiento de documentos puede mejorar significativamente la eficacia de búsqueda en documentos con formatos complejos o versiones escaneadas. El OCR solo puede reconocer texto en imágenes o en archivos PDF escaneados."
},
"title": "Configuración de Herramientas",

View File

@ -3325,7 +3325,7 @@
"provider_placeholder": "Sélectionnez un fournisseur de traitement préalable de documents",
"title": "Traitement Préliminaire de Documents"
},
"preprocessOrOcr": {
"preprocess": {
"tooltip": "Configurer un fournisseur de prétraitement de documents ou OCR dans Paramètres -> Outils. Le prétraitement des documents améliore efficacement la précision de recherche pour les documents à format complexe ou les versions scannées, tandis que l'OCR permet uniquement d'extraire le texte contenu dans les images ou les PDF scannés."
},
"title": "Paramètres des outils",

View File

@ -3325,7 +3325,7 @@
"provider_placeholder": "Selecione um prestador de serviços de pré-processamento de documentos",
"title": "Pré-processamento de Documentos"
},
"preprocessOrOcr": {
"preprocess": {
"tooltip": "Configure o provedor de pré-processamento de documentos ou OCR em Configurações -> Ferramentas. O pré-processamento de documentos pode melhorar significativamente a eficácia da busca em documentos com formatos complexos ou versões escaneadas. O OCR só consegue reconhecer texto em imagens ou PDFs escaneados."
},
"title": "Configurações de Ferramentas",

View File

@ -52,18 +52,11 @@ export function useSystemAgents() {
// 如果没有远程配置或获取失败,加载本地代理
if (resourcesPath) {
try {
let fileName = 'agents.json'
if (currentLanguage === 'zh-CN') {
fileName = 'agents-zh.json'
} else {
fileName = 'agents-en.json'
}
const fileName = currentLanguage === 'zh-CN' ? 'agents-zh.json' : 'agents-en.json'
const localAgentsData = await window.api.fs.read(`${resourcesPath}/data/${fileName}`, 'utf-8')
_agents = JSON.parse(localAgentsData) as Agent[]
} catch (error) {
const localAgentsData = await window.api.fs.read(resourcesPath + '/data/agents.json', 'utf-8')
_agents = JSON.parse(localAgentsData) as Agent[]
logger.error('Failed to load local agents:', error as Error)
}
}

View File

@ -1,12 +1,12 @@
import { DeleteOutlined, ExclamationCircleOutlined } from '@ant-design/icons'
import { DynamicVirtualList } from '@renderer/components/VirtualList'
import { handleDelete } from '@renderer/services/FileAction'
import FileManager from '@renderer/services/FileManager'
import { FileMetadata, FileTypes } from '@renderer/types'
import { formatFileSize } from '@renderer/utils'
import { Col, Image, Row, Spin } from 'antd'
import { t } from 'i18next'
import VirtualList from 'rc-virtual-list'
import React, { memo } from 'react'
import React, { memo, useCallback } from 'react'
import styled from 'styled-components'
import FileItem from './FileItem'
@ -27,6 +27,8 @@ interface FileItemProps {
}
const FileList: React.FC<FileItemProps> = ({ id, list, files }) => {
const estimateSize = useCallback(() => 75, [])
if (id === FileTypes.IMAGE && files?.length && files?.length > 0) {
return (
<div style={{ padding: 16, overflowY: 'auto' }}>
@ -78,38 +80,29 @@ const FileList: React.FC<FileItemProps> = ({ id, list, files }) => {
}
return (
<VirtualList
data={list}
height={window.innerHeight - 100}
itemHeight={75}
itemKey="key"
style={{ padding: '0 16px 16px 16px' }}
styles={{
verticalScrollBar: {
width: 6
},
verticalScrollBarThumb: {
background: 'var(--color-scrollbar-thumb)'
}
<DynamicVirtualList
list={list}
estimateSize={estimateSize}
overscan={2}
scrollerStyle={{
padding: '0 16px 16px 16px'
}}
itemContainerStyle={{
height: '75px',
paddingTop: '12px'
}}>
{(item) => (
<div
style={{
height: '75px',
paddingTop: '12px'
}}>
<FileItem
key={item.key}
fileInfo={{
name: item.file,
ext: item.ext,
extra: `${item.created_at} · ${item.count}${t('files.count')} · ${item.size}`,
actions: item.actions
}}
/>
</div>
<FileItem
key={item.key}
fileInfo={{
name: item.file,
ext: item.ext,
extra: `${item.created_at} · ${item.count}${t('files.count')} · ${item.size}`,
actions: item.actions
}}
/>
)}
</VirtualList>
</DynamicVirtualList>
)
}

View File

@ -704,7 +704,10 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
useEffect(() => {
if (!document.querySelector('.topview-fullscreen-container')) {
textareaRef.current?.focus()
const lastFocusedComponent = PasteService.getLastFocusedComponent()
if (lastFocusedComponent === 'inputbar') {
textareaRef.current?.focus()
}
}
}, [assistant, topic])

View File

@ -2,6 +2,7 @@ import { loggerService } from '@logger'
import Scrollbar from '@renderer/components/Scrollbar'
import { useMessageEditing } from '@renderer/context/MessageEditingContext'
import { useAssistant } from '@renderer/hooks/useAssistant'
import { useChatContext } from '@renderer/hooks/useChatContext'
import { useMessageOperations } from '@renderer/hooks/useMessageOperations'
import { useModel } from '@renderer/hooks/useModel'
import { useSettings } from '@renderer/hooks/useSettings'
@ -38,6 +39,16 @@ interface Props {
const logger = loggerService.withContext('MessageItem')
const WrapperContainer = ({
isMultiSelectMode,
children
}: {
isMultiSelectMode: boolean
children: React.ReactNode
}) => {
return isMultiSelectMode ? <label style={{ cursor: 'pointer' }}>{children}</label> : children
}
const MessageItem: FC<Props> = ({
message,
topic,
@ -49,6 +60,7 @@ const MessageItem: FC<Props> = ({
}) => {
const { t } = useTranslation()
const { assistant, setModel } = useAssistant(message.assistantId)
const { isMultiSelectMode } = useChatContext(topic)
const model = useModel(getMessageModelId(message), message.model?.provider) || message.model
const { messageFont, fontSize, messageStyle } = useSettings()
const { editMessageBlocks, resendUserMessageWithEdit, editMessage } = useMessageOperations(topic)
@ -122,7 +134,15 @@ const MessageItem: FC<Props> = ({
if (message.type === 'clear') {
return (
<NewContextMessage className="clear-context-divider" onClick={() => EventEmitter.emit(EVENT_NAMES.NEW_CONTEXT)}>
<NewContextMessage
isMultiSelectMode={isMultiSelectMode}
className="clear-context-divider"
onClick={() => {
if (isMultiSelectMode) {
return
}
EventEmitter.emit(EVENT_NAMES.NEW_CONTEXT)
}}>
<Divider dashed style={{ padding: '0 20px' }} plain>
{t('chat.message.new.context')}
</Divider>
@ -131,56 +151,64 @@ const MessageItem: FC<Props> = ({
}
return (
<MessageContainer
key={message.id}
className={classNames({
message: true,
'message-assistant': isAssistantMessage,
'message-user': !isAssistantMessage
})}
ref={messageContainerRef}>
<MessageHeader message={message} assistant={assistant} model={model} key={getModelUniqId(model)} topic={topic} />
{isEditing && (
<MessageEditor
<WrapperContainer isMultiSelectMode={isMultiSelectMode}>
<MessageContainer
key={message.id}
className={classNames({
message: true,
'message-assistant': isAssistantMessage,
'message-user': !isAssistantMessage
})}
ref={messageContainerRef}>
<MessageHeader
message={message}
topicId={topic.id}
onSave={handleEditSave}
onResend={handleEditResend}
onCancel={handleEditCancel}
assistant={assistant}
model={model}
key={getModelUniqId(model)}
topic={topic}
/>
)}
{!isEditing && (
<>
<MessageContentContainer
className="message-content-container"
style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
fontSize,
overflowY: 'visible'
}}>
<MessageErrorBoundary>
<MessageContent message={message} />
</MessageErrorBoundary>
</MessageContentContainer>
{showMenubar && (
<MessageFooter className="MessageFooter" $isLastMessage={isLastMessage} $messageStyle={messageStyle}>
<MessageMenubar
message={message}
assistant={assistant}
model={model}
index={index}
topic={topic}
isLastMessage={isLastMessage}
isAssistantMessage={isAssistantMessage}
isGrouped={isGrouped}
messageContainerRef={messageContainerRef as React.RefObject<HTMLDivElement>}
setModel={setModel}
/>
</MessageFooter>
)}
</>
)}
</MessageContainer>
{isEditing && (
<MessageEditor
message={message}
topicId={topic.id}
onSave={handleEditSave}
onResend={handleEditResend}
onCancel={handleEditCancel}
/>
)}
{!isEditing && (
<>
<MessageContentContainer
className="message-content-container"
style={{
fontFamily: messageFont === 'serif' ? 'var(--font-family-serif)' : 'var(--font-family)',
fontSize,
overflowY: 'visible'
}}>
<MessageErrorBoundary>
<MessageContent message={message} />
</MessageErrorBoundary>
</MessageContentContainer>
{showMenubar && (
<MessageFooter className="MessageFooter" $isLastMessage={isLastMessage} $messageStyle={messageStyle}>
<MessageMenubar
message={message}
assistant={assistant}
model={model}
index={index}
topic={topic}
isLastMessage={isLastMessage}
isAssistantMessage={isAssistantMessage}
isGrouped={isGrouped}
messageContainerRef={messageContainerRef as React.RefObject<HTMLDivElement>}
setModel={setModel}
/>
</MessageFooter>
)}
</>
)}
</MessageContainer>
</WrapperContainer>
)
}
@ -232,9 +260,11 @@ const MessageFooter = styled.div<{ $isLastMessage: boolean; $messageStyle: 'plai
margin-top: 8px;
`
const NewContextMessage = styled.div`
const NewContextMessage = styled.div<{ isMultiSelectMode: boolean }>`
cursor: pointer;
flex: 1;
${({ isMultiSelectMode }) => isMultiSelectMode && 'cursor: default;'}
`
export default memo(MessageItem)

View File

@ -9,22 +9,23 @@ interface Props extends HTMLAttributes<HTMLDivElement> {
const NarrowLayout: FC<Props> = ({ children, ...props }) => {
const { narrowMode } = useSettings()
if (narrowMode) {
return (
<Container className="narrow-mode" {...props}>
{children}
</Container>
)
}
return children
return (
<Container className={`narrow-mode ${narrowMode ? 'active' : ''}`} {...props}>
{children}
</Container>
)
}
const Container = styled.div`
max-width: 800px;
max-width: 100%;
width: 100%;
margin: 0 auto;
position: relative;
transition: max-width 0.3s ease-in-out;
&.active {
max-width: 800px;
}
`
export default NarrowLayout

View File

@ -1,7 +1,9 @@
import { useTheme } from '@renderer/context/ThemeProvider'
import { usePromptProcessor } from '@renderer/hooks/usePromptProcessor'
import AssistantSettingsPopup from '@renderer/pages/settings/AssistantSettings'
import { Assistant, Topic } from '@renderer/types'
import { FC } from 'react'
import { containsSupportedVariables } from '@renderer/utils/prompt'
import { FC, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -18,13 +20,50 @@ const Prompt: FC<Props> = ({ assistant, topic }) => {
const topicPrompt = topic?.prompt || ''
const isDark = theme === 'dark'
const processedPrompt = usePromptProcessor({ prompt, modelName: assistant.model?.name })
// 用于控制显示的状态
const [displayText, setDisplayText] = useState(prompt)
const [isVisible, setIsVisible] = useState(true)
useEffect(() => {
// 如果没有变量需要替换,直接显示处理后的内容
if (!containsSupportedVariables(prompt)) {
setDisplayText(processedPrompt)
setIsVisible(true)
return
}
// 如果有变量需要替换先显示原始prompt
setDisplayText(prompt)
setIsVisible(true)
// 延迟过渡
let innerTimer: NodeJS.Timeout
const outerTimer = setTimeout(() => {
// 先淡出
setIsVisible(false)
// 切换内容并淡入
innerTimer = setTimeout(() => {
setDisplayText(processedPrompt)
setIsVisible(true)
}, 300)
}, 300)
return () => {
clearTimeout(outerTimer)
clearTimeout(innerTimer)
}
}, [prompt, processedPrompt])
if (!prompt && !topicPrompt) {
return null
}
return (
<Container className="system-prompt" onClick={() => AssistantSettingsPopup.show({ assistant })} $isDark={isDark}>
<Text>{prompt}</Text>
<Text $isVisible={isVisible}>{displayText}</Text>
</Container>
)
}
@ -38,13 +77,17 @@ const Container = styled.div<{ $isDark: boolean }>`
margin-bottom: 0;
`
const Text = styled.div`
const Text = styled.div<{ $isVisible: boolean }>`
color: var(--color-text-2);
font-size: 12px;
display: -webkit-box;
-webkit-line-clamp: 2;
-webkit-box-orient: vertical;
overflow: hidden;
user-select: none;
opacity: ${(props) => (props.$isVisible ? 1 : 0)};
transition: opacity 0.3s ease-in-out;
`
export default Prompt

View File

@ -17,9 +17,14 @@ const SelectionBox: React.FC<SelectionBoxProps> = ({
const [isDragging, setIsDragging] = useState(false)
const [dragStart, setDragStart] = useState({ x: 0, y: 0 })
const [dragCurrent, setDragCurrent] = useState({ x: 0, y: 0 })
const [isMouseDown, setIsMouseDown] = useState(false)
const dragSelectedIds = useRef<Set<string>>(new Set())
// 拖拽阈值,只有移动距离超过这个值才开始框选
// 避免触控板点击触发拖拽
const DRAG_THRESHOLD = 5
useEffect(() => {
if (!isMultiSelectMode) return
@ -39,20 +44,30 @@ const SelectionBox: React.FC<SelectionBoxProps> = ({
e.preventDefault()
setIsDragging(true)
setIsMouseDown(true)
const pos = updateDragPos(e)
setDragStart(pos)
setDragCurrent(pos)
dragSelectedIds.current.clear()
document.body.classList.add('no-select')
}
const handleMouseMove = (e: MouseEvent) => {
if (!isMouseDown) return
const pos = updateDragPos(e)
const deltaX = Math.abs(pos.x - dragStart.x)
const deltaY = Math.abs(pos.y - dragStart.y)
const distance = Math.sqrt(deltaX * deltaX + deltaY * deltaY)
if (!isDragging && distance > DRAG_THRESHOLD) {
setIsDragging(true)
document.body.classList.add('no-select')
}
if (!isDragging) return
e.preventDefault()
const pos = updateDragPos(e)
setDragCurrent(pos)
// 计算当前框选矩形
@ -69,6 +84,9 @@ const SelectionBox: React.FC<SelectionBoxProps> = ({
const checkbox = el.querySelector('input[type="checkbox"]') as HTMLInputElement | null
const isAlreadySelected = checkbox?.checked || false
// 清除上下文这类消息也会被选中,所以需要跳过
if (!checkbox) return
// 如果已经被记录为拖动选中,跳过
if (dragSelectedIds.current.has(id)) return
@ -94,9 +112,11 @@ const SelectionBox: React.FC<SelectionBoxProps> = ({
}
const handleMouseUp = () => {
if (!isDragging) return
setIsDragging(false)
document.body.classList.remove('no-select')
setIsMouseDown(false)
if (isDragging) {
setIsDragging(false)
document.body.classList.remove('no-select')
}
}
const container = scrollContainerRef.current!
@ -110,7 +130,7 @@ const SelectionBox: React.FC<SelectionBoxProps> = ({
window.removeEventListener('mouseup', handleMouseUp)
document.body.classList.remove('no-select')
}
}, [isMultiSelectMode, isDragging, dragStart, scrollContainerRef, messageElements, handleSelectMessage])
}, [isMultiSelectMode, isDragging, isMouseDown, dragStart, scrollContainerRef, messageElements, handleSelectMessage])
if (!isDragging || !isMultiSelectMode) return null

View File

@ -10,7 +10,7 @@ import {
QuestionCircleOutlined,
UploadOutlined
} from '@ant-design/icons'
import { DraggableVirtualList as DraggableList } from '@renderer/components/DraggableList'
import { DraggableVirtualList } from '@renderer/components/DraggableList'
import CopyIcon from '@renderer/components/Icons/CopyIcon'
import ObsidianExportPopup from '@renderer/components/Popups/ObsidianExportPopup'
import PromptPopup from '@renderer/components/Popups/PromptPopup'
@ -438,11 +438,11 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic,
const singlealone = topicPosition === 'right' && position === 'right'
return (
<DraggableList
<DraggableVirtualList
className="topics-tab"
list={sortedTopics}
onUpdate={updateTopics}
style={{ height: '100%', padding: '13px 0 10px 10px', display: 'flex', flexDirection: 'column' }}
style={{ height: '100%', padding: '13px 0 10px 10px' }}
itemContainerStyle={{ paddingBottom: '8px' }}
header={
<AddTopicButton onClick={() => EventEmitter.emit(EVENT_NAMES.ADD_NEW_TOPIC)}>
@ -521,7 +521,7 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic,
</Dropdown>
)
}}
</DraggableList>
</DraggableVirtualList>
)
}

View File

@ -139,8 +139,8 @@ const KnowledgeContent: FC<KnowledgeContentProps> = ({ selectedBase }) => {
</div>
</Tooltip>
{base.rerankModel && <Tag style={{ borderRadius: 20, margin: 0 }}>{base.rerankModel.name}</Tag>}
{base.preprocessOrOcrProvider && base.preprocessOrOcrProvider.type === 'preprocess' && (
<QuotaTag base={base} providerId={base.preprocessOrOcrProvider?.provider.id} quota={quota} />
{base.preprocessProvider && base.preprocessProvider.type === 'preprocess' && (
<QuotaTag base={base} providerId={base.preprocessProvider?.provider.id} quota={quota} />
)}
</div>
</ModelInfo>

View File

@ -104,36 +104,34 @@ const KnowledgePage: FC = () => {
</Navbar>
<ContentContainer id="content-container">
<KnowledgeSideNav>
<ScrollContainer>
<DraggableList
list={bases}
onUpdate={updateKnowledgeBases}
style={{ marginBottom: 0, paddingBottom: isDragging ? 50 : 0 }}
onDragStart={() => setIsDragging(true)}
onDragEnd={() => setIsDragging(false)}>
{(base: KnowledgeBase) => (
<Dropdown menu={{ items: getMenuItems(base) }} trigger={['contextMenu']} key={base.id}>
<div>
<ListItem
active={selectedBase?.id === base.id}
icon={<Book size={16} />}
title={base.name}
onClick={() => setSelectedBase(base)}
/>
</div>
</Dropdown>
)}
</DraggableList>
{!isDragging && (
<AddKnowledgeItem onClick={handleAddKnowledge}>
<AddKnowledgeName>
<Plus size={18} />
{t('button.add')}
</AddKnowledgeName>
</AddKnowledgeItem>
<DraggableList
list={bases}
onUpdate={updateKnowledgeBases}
style={{ marginBottom: 0, paddingBottom: isDragging ? 50 : 0 }}
onDragStart={() => setIsDragging(true)}
onDragEnd={() => setIsDragging(false)}>
{(base: KnowledgeBase) => (
<Dropdown menu={{ items: getMenuItems(base) }} trigger={['contextMenu']} key={base.id}>
<div>
<ListItem
active={selectedBase?.id === base.id}
icon={<Book size={16} />}
title={base.name}
onClick={() => setSelectedBase(base)}
/>
</div>
</Dropdown>
)}
<div style={{ minHeight: '10px' }}></div>
</ScrollContainer>
</DraggableList>
{!isDragging && (
<AddKnowledgeItem onClick={handleAddKnowledge}>
<AddKnowledgeName>
<Plus size={18} />
{t('button.add')}
</AddKnowledgeName>
</AddKnowledgeItem>
)}
<div style={{ minHeight: '10px' }}></div>
</KnowledgeSideNav>
{bases.length === 0 ? (
<MainContent>
@ -169,13 +167,14 @@ const MainContent = styled(Scrollbar)`
padding-bottom: 50px;
`
export const KnowledgeSideNav = styled.div`
min-width: var(--settings-width);
border-right: 0.5px solid var(--color-border);
padding: 12px 10px;
const KnowledgeSideNav = styled(Scrollbar)`
display: flex;
flex-direction: column;
width: calc(var(--settings-width) + 100px);
border-right: 0.5px solid var(--color-border);
padding: 12px 10px;
.ant-menu {
border-inline-end: none !important;
background: transparent;
@ -197,12 +196,6 @@ export const KnowledgeSideNav = styled.div`
color: var(--color-primary);
}
}
`
const ScrollContainer = styled(Scrollbar)`
display: flex;
flex-direction: column;
flex: 1;
> div {
margin-bottom: 8px;

View File

@ -34,45 +34,6 @@ exports[`GeneralSettingsPanel > basic rendering > should match snapshot 1`] = `
value="Test Knowledge Base"
/>
</div>
<div
class="c1"
>
<div
class="settings-label"
>
settings.tool.preprocess.title
/
settings.tool.ocr.title
<span
data-placement="right"
data-testid="info-tooltip"
title="settings.tool.preprocessOrOcr.tooltip"
>
</span>
</div>
<select
data-allow-clear="true"
data-placeholder="settings.tool.preprocess.provider_placeholder"
data-testid="preprocess-select"
>
<option
value=""
>
Select option
</option>
<option
value="doc2x"
>
Doc2X
</option>
<option
value="mistral"
>
Mistral
</option>
</select>
</div>
<div
class="c1"
>
@ -172,6 +133,43 @@ exports[`GeneralSettingsPanel > basic rendering > should match snapshot 1`] = `
</option>
</select>
</div>
<div
class="c1"
>
<div
class="settings-label"
>
settings.tool.preprocess.title
<span
data-placement="right"
data-testid="info-tooltip"
title="settings.tool.preprocess.tooltip"
>
</span>
</div>
<select
data-allow-clear="true"
data-placeholder="settings.tool.preprocess.provider_placeholder"
data-testid="preprocess-select"
>
<option
value=""
>
Select option
</option>
<option
value="doc2x"
>
Doc2X
</option>
<option
value="mistral"
>
Mistral
</option>
</select>
</div>
<div
class="c1"
>

View File

@ -6,7 +6,7 @@ import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
import { KnowledgeBase, PreprocessProvider } from '@renderer/types'
import { Input, Select, Slider } from 'antd'
import { Input, Select, SelectProps, Slider } from 'antd'
import { useTranslation } from 'react-i18next'
import { SettingsItem, SettingsPanel } from './styles'
@ -15,7 +15,7 @@ interface GeneralSettingsPanelProps {
newBase: KnowledgeBase
setNewBase: React.Dispatch<React.SetStateAction<KnowledgeBase>>
selectedDocPreprocessProvider?: PreprocessProvider
docPreprocessSelectOptions: any[]
docPreprocessSelectOptions: SelectProps['options']
handlers: {
handleEmbeddingModelChange: (value: string) => void
handleDimensionChange: (value: number | null) => void
@ -47,21 +47,6 @@ const GeneralSettingsPanel: React.FC<GeneralSettingsPanelProps> = ({
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('settings.tool.preprocess.title')} / {t('settings.tool.ocr.title')}
<InfoTooltip title={t('settings.tool.preprocessOrOcr.tooltip')} placement="right" />
</div>
<Select
value={selectedDocPreprocessProvider?.id}
style={{ width: '100%' }}
onChange={handleDocPreprocessChange}
placeholder={t('settings.tool.preprocess.provider_placeholder')}
options={docPreprocessSelectOptions}
allowClear
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('models.embedding_model')}
@ -106,6 +91,21 @@ const GeneralSettingsPanel: React.FC<GeneralSettingsPanelProps> = ({
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('settings.tool.preprocess.title')}
<InfoTooltip title={t('settings.tool.preprocess.tooltip')} placement="right" />
</div>
<Select
value={selectedDocPreprocessProvider?.id}
style={{ width: '100%' }}
onChange={handleDocPreprocessChange}
placeholder={t('settings.tool.preprocess.provider_placeholder')}
options={docPreprocessSelectOptions}
allowClear
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('knowledge.document_count')}

View File

@ -1,7 +1,7 @@
import { DeleteOutlined } from '@ant-design/icons'
import { loggerService } from '@logger'
import Ellipsis from '@renderer/components/Ellipsis'
import Scrollbar from '@renderer/components/Scrollbar'
import { DynamicVirtualList } from '@renderer/components/VirtualList'
import { useKnowledge } from '@renderer/hooks/useKnowledge'
import FileItem from '@renderer/pages/files/FileItem'
import { getProviderName } from '@renderer/services/ProviderService'
@ -9,7 +9,7 @@ import { KnowledgeBase, KnowledgeItem } from '@renderer/types'
import { Button, Tooltip } from 'antd'
import dayjs from 'dayjs'
import { Plus } from 'lucide-react'
import { FC } from 'react'
import { FC, useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -46,6 +46,9 @@ const KnowledgeDirectories: FC<KnowledgeContentProps> = ({ selectedBase, progres
const providerName = getProviderName(base?.model.provider || '')
const disabled = !base?.version || !providerName
const reversedItems = useMemo(() => [...directoryItems].reverse(), [directoryItems])
const estimateSize = useCallback(() => 75, [])
if (!base) {
return null
}
@ -76,46 +79,51 @@ const KnowledgeDirectories: FC<KnowledgeContentProps> = ({ selectedBase, progres
</ItemHeader>
<ItemFlexColumn>
{directoryItems.length === 0 && <KnowledgeEmptyView />}
{directoryItems.reverse().map((item) => (
<FileItem
key={item.id}
fileInfo={{
name: (
<ClickableSpan onClick={() => window.api.file.openPath(item.content as string)}>
<Ellipsis>
<Tooltip title={item.content as string}>{item.content as string}</Tooltip>
</Ellipsis>
</ClickableSpan>
),
ext: '.folder',
extra: getDisplayTime(item),
actions: (
<FlexAlignCenter>
{item.uniqueId && <Button type="text" icon={<RefreshIcon />} onClick={() => refreshItem(item)} />}
<StatusIconWrapper>
<StatusIcon
sourceId={item.id}
base={base}
getProcessingStatus={getProcessingStatus}
progress={progressMap.get(item.id)}
type="directory"
/>
</StatusIconWrapper>
<Button type="text" danger onClick={() => removeItem(item)} icon={<DeleteOutlined />} />
</FlexAlignCenter>
)
}}
/>
))}
<DynamicVirtualList
list={reversedItems}
estimateSize={estimateSize}
overscan={2}
scrollerStyle={{ paddingRight: 2 }}
itemContainerStyle={{ paddingBottom: 10 }}
autoHideScrollbar>
{(item) => (
<FileItem
key={item.id}
fileInfo={{
name: (
<ClickableSpan onClick={() => window.api.file.openPath(item.content as string)}>
<Ellipsis>
<Tooltip title={item.content as string}>{item.content as string}</Tooltip>
</Ellipsis>
</ClickableSpan>
),
ext: '.folder',
extra: getDisplayTime(item),
actions: (
<FlexAlignCenter>
{item.uniqueId && <Button type="text" icon={<RefreshIcon />} onClick={() => refreshItem(item)} />}
<StatusIconWrapper>
<StatusIcon
sourceId={item.id}
base={base}
getProcessingStatus={getProcessingStatus}
progress={progressMap.get(item.id)}
type="directory"
/>
</StatusIconWrapper>
<Button type="text" danger onClick={() => removeItem(item)} icon={<DeleteOutlined />} />
</FlexAlignCenter>
)
}}
/>
)}
</DynamicVirtualList>
</ItemFlexColumn>
</ItemContainer>
)
}
const ItemFlexColumn = styled(Scrollbar)`
display: flex;
flex-direction: column;
gap: 10px;
const ItemFlexColumn = styled.div`
padding: 20px 16px;
height: calc(100vh - 135px);
`

View File

@ -12,13 +12,14 @@ import { bookExts, documentExts, textExts, thirdPartyApplicationExts } from '@sh
import { Button, Tooltip, Upload } from 'antd'
import dayjs from 'dayjs'
import { Plus } from 'lucide-react'
import VirtualList from 'rc-virtual-list'
import { FC, useEffect, useState } from 'react'
import { FC, useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
const logger = loggerService.withContext('KnowledgeFiles')
import { DynamicVirtualList } from '@renderer/components/VirtualList'
import {
ClickableSpan,
FlexAlignCenter,
@ -64,6 +65,8 @@ const KnowledgeFiles: FC<KnowledgeContentProps> = ({ selectedBase, progressMap,
const providerName = getProviderName(base?.model.provider || '')
const disabled = !base?.version || !providerName
const estimateSize = useCallback(() => 75, [])
if (!base) {
return null
}
@ -122,10 +125,10 @@ const KnowledgeFiles: FC<KnowledgeContentProps> = ({ selectedBase, progressMap,
}
const showPreprocessIcon = (item: KnowledgeItem) => {
if (base.preprocessOrOcrProvider && item.isPreprocessed !== false) {
if (base.preprocessProvider && item.isPreprocessed !== false) {
return true
}
if (!base.preprocessOrOcrProvider && item.isPreprocessed === true) {
if (!base.preprocessProvider && item.isPreprocessed === true) {
return true
}
return false
@ -160,15 +163,12 @@ const KnowledgeFiles: FC<KnowledgeContentProps> = ({ selectedBase, progressMap,
{fileItems.length === 0 ? (
<KnowledgeEmptyView />
) : (
<VirtualList
data={fileItems.reverse()}
height={windowHeight - 270}
itemHeight={75}
itemKey="id"
styles={{
verticalScrollBar: { width: 6 },
verticalScrollBarThumb: { background: 'var(--color-scrollbar-thumb)' }
}}>
<DynamicVirtualList
list={fileItems.reverse()}
estimateSize={estimateSize}
overscan={2}
scrollerStyle={{ height: windowHeight - 270 }}
autoHideScrollbar>
{(item) => {
const file = item.content as FileType
return (
@ -218,7 +218,7 @@ const KnowledgeFiles: FC<KnowledgeContentProps> = ({ selectedBase, progressMap,
</div>
)
}}
</VirtualList>
</DynamicVirtualList>
)}
</ItemFlexColumn>
</ItemContainer>

View File

@ -1,6 +1,6 @@
import { DeleteOutlined, EditOutlined } from '@ant-design/icons'
import TextEditPopup from '@renderer/components/Popups/TextEditPopup'
import Scrollbar from '@renderer/components/Scrollbar'
import { DynamicVirtualList } from '@renderer/components/VirtualList'
import { useKnowledge } from '@renderer/hooks/useKnowledge'
import FileItem from '@renderer/pages/files/FileItem'
import { getProviderName } from '@renderer/services/ProviderService'
@ -8,7 +8,7 @@ import { KnowledgeBase, KnowledgeItem } from '@renderer/types'
import { Button } from 'antd'
import dayjs from 'dayjs'
import { Plus } from 'lucide-react'
import { FC } from 'react'
import { FC, useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -34,6 +34,9 @@ const KnowledgeNotes: FC<KnowledgeContentProps> = ({ selectedBase }) => {
const providerName = getProviderName(base?.model.provider || '')
const disabled = !base?.version || !providerName
const reversedItems = useMemo(() => [...noteItems].reverse(), [noteItems])
const estimateSize = useCallback(() => 75, [])
if (!base) {
return null
}
@ -72,34 +75,44 @@ const KnowledgeNotes: FC<KnowledgeContentProps> = ({ selectedBase }) => {
</ItemHeader>
<ItemFlexColumn>
{noteItems.length === 0 && <KnowledgeEmptyView />}
{noteItems.reverse().map((note) => (
<FileItem
key={note.id}
fileInfo={{
name: <span onClick={() => handleEditNote(note)}>{(note.content as string).slice(0, 50)}...</span>,
ext: '.txt',
extra: getDisplayTime(note),
actions: (
<FlexAlignCenter>
<Button type="text" onClick={() => handleEditNote(note)} icon={<EditOutlined />} />
<StatusIconWrapper>
<StatusIcon sourceId={note.id} base={base} getProcessingStatus={getProcessingStatus} type="note" />
</StatusIconWrapper>
<Button type="text" danger onClick={() => removeItem(note)} icon={<DeleteOutlined />} />
</FlexAlignCenter>
)
}}
/>
))}
<DynamicVirtualList
list={reversedItems}
estimateSize={estimateSize}
overscan={2}
scrollerStyle={{ paddingRight: 2 }}
itemContainerStyle={{ paddingBottom: 10 }}
autoHideScrollbar>
{(note) => (
<FileItem
key={note.id}
fileInfo={{
name: <span onClick={() => handleEditNote(note)}>{(note.content as string).slice(0, 50)}...</span>,
ext: '.txt',
extra: getDisplayTime(note),
actions: (
<FlexAlignCenter>
<Button type="text" onClick={() => handleEditNote(note)} icon={<EditOutlined />} />
<StatusIconWrapper>
<StatusIcon
sourceId={note.id}
base={base}
getProcessingStatus={getProcessingStatus}
type="note"
/>
</StatusIconWrapper>
<Button type="text" danger onClick={() => removeItem(note)} icon={<DeleteOutlined />} />
</FlexAlignCenter>
)
}}
/>
)}
</DynamicVirtualList>
</ItemFlexColumn>
</ItemContainer>
)
}
const ItemFlexColumn = styled(Scrollbar)`
display: flex;
flex-direction: column;
gap: 10px;
const ItemFlexColumn = styled.div`
padding: 20px 16px;
height: calc(100vh - 135px);
`

View File

@ -2,7 +2,7 @@ import { DeleteOutlined } from '@ant-design/icons'
import { loggerService } from '@logger'
import Ellipsis from '@renderer/components/Ellipsis'
import PromptPopup from '@renderer/components/Popups/PromptPopup'
import Scrollbar from '@renderer/components/Scrollbar'
import { DynamicVirtualList } from '@renderer/components/VirtualList'
import { useKnowledge } from '@renderer/hooks/useKnowledge'
import FileItem from '@renderer/pages/files/FileItem'
import { getProviderName } from '@renderer/services/ProviderService'
@ -10,7 +10,7 @@ import { KnowledgeBase, KnowledgeItem } from '@renderer/types'
import { Button, message, Tooltip } from 'antd'
import dayjs from 'dayjs'
import { Plus } from 'lucide-react'
import { FC } from 'react'
import { FC, useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -46,6 +46,9 @@ const KnowledgeSitemaps: FC<KnowledgeContentProps> = ({ selectedBase }) => {
const providerName = getProviderName(base?.model.provider || '')
const disabled = !base?.version || !providerName
const reversedItems = useMemo(() => [...sitemapItems].reverse(), [sitemapItems])
const estimateSize = useCallback(() => 75, [])
if (!base) {
return null
}
@ -95,49 +98,54 @@ const KnowledgeSitemaps: FC<KnowledgeContentProps> = ({ selectedBase }) => {
</ItemHeader>
<ItemFlexColumn>
{sitemapItems.length === 0 && <KnowledgeEmptyView />}
{sitemapItems.reverse().map((item) => (
<FileItem
key={item.id}
fileInfo={{
name: (
<ClickableSpan>
<Tooltip title={item.content as string}>
<Ellipsis>
<a href={item.content as string} target="_blank" rel="noopener noreferrer">
{item.content as string}
</a>
</Ellipsis>
</Tooltip>
</ClickableSpan>
),
ext: '.sitemap',
extra: getDisplayTime(item),
actions: (
<FlexAlignCenter>
{item.uniqueId && <Button type="text" icon={<RefreshIcon />} onClick={() => refreshItem(item)} />}
<StatusIconWrapper>
<StatusIcon
sourceId={item.id}
base={base}
getProcessingStatus={getProcessingStatus}
type="sitemap"
/>
</StatusIconWrapper>
<Button type="text" danger onClick={() => removeItem(item)} icon={<DeleteOutlined />} />
</FlexAlignCenter>
)
}}
/>
))}
<DynamicVirtualList
list={reversedItems}
estimateSize={estimateSize}
overscan={2}
scrollerStyle={{ paddingRight: 2 }}
itemContainerStyle={{ paddingBottom: 10 }}
autoHideScrollbar>
{(item) => (
<FileItem
key={item.id}
fileInfo={{
name: (
<ClickableSpan>
<Tooltip title={item.content as string}>
<Ellipsis>
<a href={item.content as string} target="_blank" rel="noopener noreferrer">
{item.content as string}
</a>
</Ellipsis>
</Tooltip>
</ClickableSpan>
),
ext: '.sitemap',
extra: getDisplayTime(item),
actions: (
<FlexAlignCenter>
{item.uniqueId && <Button type="text" icon={<RefreshIcon />} onClick={() => refreshItem(item)} />}
<StatusIconWrapper>
<StatusIcon
sourceId={item.id}
base={base}
getProcessingStatus={getProcessingStatus}
type="sitemap"
/>
</StatusIconWrapper>
<Button type="text" danger onClick={() => removeItem(item)} icon={<DeleteOutlined />} />
</FlexAlignCenter>
)
}}
/>
)}
</DynamicVirtualList>
</ItemFlexColumn>
</ItemContainer>
)
}
const ItemFlexColumn = styled(Scrollbar)`
display: flex;
flex-direction: column;
gap: 10px;
const ItemFlexColumn = styled.div`
padding: 20px 16px;
height: calc(100vh - 135px);
`

View File

@ -1,7 +1,7 @@
import { CopyOutlined, DeleteOutlined, EditOutlined } from '@ant-design/icons'
import Ellipsis from '@renderer/components/Ellipsis'
import PromptPopup from '@renderer/components/Popups/PromptPopup'
import Scrollbar from '@renderer/components/Scrollbar'
import { DynamicVirtualList } from '@renderer/components/VirtualList'
import { useKnowledge } from '@renderer/hooks/useKnowledge'
import FileItem from '@renderer/pages/files/FileItem'
import { getProviderName } from '@renderer/services/ProviderService'
@ -9,7 +9,7 @@ import { KnowledgeBase, KnowledgeItem } from '@renderer/types'
import { Button, Dropdown, Tooltip } from 'antd'
import dayjs from 'dayjs'
import { Plus } from 'lucide-react'
import { FC } from 'react'
import { FC, useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -43,6 +43,9 @@ const KnowledgeUrls: FC<KnowledgeContentProps> = ({ selectedBase }) => {
const providerName = getProviderName(base?.model.provider || '')
const disabled = !base?.version || !providerName
const reversedItems = useMemo(() => [...urlItems].reverse(), [urlItems])
const estimateSize = useCallback(() => 75, [])
if (!base) {
return null
}
@ -123,66 +126,71 @@ const KnowledgeUrls: FC<KnowledgeContentProps> = ({ selectedBase }) => {
</ItemHeader>
<ItemFlexColumn>
{urlItems.length === 0 && <KnowledgeEmptyView />}
{urlItems.reverse().map((item) => (
<FileItem
key={item.id}
fileInfo={{
name: (
<Dropdown
menu={{
items: [
{
key: 'edit',
icon: <EditOutlined />,
label: t('knowledge.edit_remark'),
onClick: () => handleEditRemark(item)
},
{
key: 'copy',
icon: <CopyOutlined />,
label: t('common.copy'),
onClick: () => {
navigator.clipboard.writeText(item.content as string)
window.message.success(t('message.copied'))
<DynamicVirtualList
list={reversedItems}
estimateSize={estimateSize}
overscan={2}
scrollerStyle={{ paddingRight: 2 }}
itemContainerStyle={{ paddingBottom: 10 }}
autoHideScrollbar>
{(item) => (
<FileItem
key={item.id}
fileInfo={{
name: (
<Dropdown
menu={{
items: [
{
key: 'edit',
icon: <EditOutlined />,
label: t('knowledge.edit_remark'),
onClick: () => handleEditRemark(item)
},
{
key: 'copy',
icon: <CopyOutlined />,
label: t('common.copy'),
onClick: () => {
navigator.clipboard.writeText(item.content as string)
window.message.success(t('message.copied'))
}
}
}
]
}}
trigger={['contextMenu']}>
<ClickableSpan>
<Tooltip title={item.content as string}>
<Ellipsis>
<a href={item.content as string} target="_blank" rel="noopener noreferrer">
{item.remark || (item.content as string)}
</a>
</Ellipsis>
</Tooltip>
</ClickableSpan>
</Dropdown>
),
ext: '.url',
extra: getDisplayTime(item),
actions: (
<FlexAlignCenter>
{item.uniqueId && <Button type="text" icon={<RefreshIcon />} onClick={() => refreshItem(item)} />}
<StatusIconWrapper>
<StatusIcon sourceId={item.id} base={base} getProcessingStatus={getProcessingStatus} type="url" />
</StatusIconWrapper>
<Button type="text" danger onClick={() => removeItem(item)} icon={<DeleteOutlined />} />
</FlexAlignCenter>
)
}}
/>
))}
]
}}
trigger={['contextMenu']}>
<ClickableSpan>
<Tooltip title={item.content as string}>
<Ellipsis>
<a href={item.content as string} target="_blank" rel="noopener noreferrer">
{item.remark || (item.content as string)}
</a>
</Ellipsis>
</Tooltip>
</ClickableSpan>
</Dropdown>
),
ext: '.url',
extra: getDisplayTime(item),
actions: (
<FlexAlignCenter>
{item.uniqueId && <Button type="text" icon={<RefreshIcon />} onClick={() => refreshItem(item)} />}
<StatusIconWrapper>
<StatusIcon sourceId={item.id} base={base} getProcessingStatus={getProcessingStatus} type="url" />
</StatusIconWrapper>
<Button type="text" danger onClick={() => removeItem(item)} icon={<DeleteOutlined />} />
</FlexAlignCenter>
)
}}
/>
)}
</DynamicVirtualList>
</ItemFlexColumn>
</ItemContainer>
)
}
const ItemFlexColumn = styled(Scrollbar)`
display: flex;
flex-direction: column;
gap: 10px;
const ItemFlexColumn = styled.div`
padding: 20px 16px;
height: calc(100vh - 135px);
`

View File

@ -7,7 +7,7 @@ import { useMinappPopup } from '@renderer/hooks/useMinappPopup'
import { useRuntime } from '@renderer/hooks/useRuntime'
import { useSettings } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { useAppDispatch } from '@renderer/store'
import { handleSaveData, useAppDispatch } from '@renderer/store'
import { setUpdateState } from '@renderer/store/runtime'
import { ThemeMode } from '@renderer/types'
import { runAsyncFunction } from '@renderer/utils'
@ -41,6 +41,7 @@ const AboutSettings: FC = () => {
}
if (update.downloaded) {
await handleSaveData()
window.api.showUpdateDialog()
return
}

View File

@ -1,423 +0,0 @@
import { useTheme } from '@renderer/context/ThemeProvider'
import { loggerService } from '@renderer/services/LoggerService'
import { RootState, useAppDispatch } from '@renderer/store'
import { setApiServerApiKey, setApiServerEnabled, setApiServerPort } from '@renderer/store/settings'
import { IpcChannel } from '@shared/IpcChannel'
import { Button, Input, InputNumber, Tooltip, Typography } from 'antd'
import { Copy, ExternalLink, Play, RotateCcw, Square } from 'lucide-react'
import { FC, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useSelector } from 'react-redux'
import styled from 'styled-components'
import { v4 as uuidv4 } from 'uuid'
import { SettingContainer } from '..'
const logger = loggerService.withContext('ApiServerSettings')
const { Text, Title } = Typography
const ApiServerSettings: FC = () => {
const { theme } = useTheme()
const dispatch = useAppDispatch()
const { t } = useTranslation()
// API Server state with proper defaults
const apiServerConfig = useSelector((state: RootState) => state.settings.apiServer)
const [apiServerRunning, setApiServerRunning] = useState(false)
const [apiServerLoading, setApiServerLoading] = useState(false)
// API Server functions
const checkApiServerStatus = async () => {
try {
const status = await window.electron.ipcRenderer.invoke(IpcChannel.ApiServer_GetStatus)
setApiServerRunning(status.running)
} catch (error: any) {
logger.error('Failed to check API server status:', error)
}
}
useEffect(() => {
checkApiServerStatus()
}, [])
const handleApiServerToggle = async (enabled: boolean) => {
setApiServerLoading(true)
try {
if (enabled) {
const result = await window.electron.ipcRenderer.invoke(IpcChannel.ApiServer_Start)
if (result.success) {
setApiServerRunning(true)
window.message.success(t('apiServer.messages.startSuccess'))
} else {
window.message.error(t('apiServer.messages.startError') + result.error)
}
} else {
const result = await window.electron.ipcRenderer.invoke(IpcChannel.ApiServer_Stop)
if (result.success) {
setApiServerRunning(false)
window.message.success(t('apiServer.messages.stopSuccess'))
} else {
window.message.error(t('apiServer.messages.stopError') + result.error)
}
}
} catch (error) {
window.message.error(t('apiServer.messages.operationFailed') + (error as Error).message)
} finally {
dispatch(setApiServerEnabled(enabled))
setApiServerLoading(false)
}
}
const handleApiServerRestart = async () => {
setApiServerLoading(true)
try {
const result = await window.electron.ipcRenderer.invoke(IpcChannel.ApiServer_Restart)
if (result.success) {
await checkApiServerStatus()
window.message.success(t('apiServer.messages.restartSuccess'))
} else {
window.message.error(t('apiServer.messages.restartError') + result.error)
}
} catch (error) {
window.message.error(t('apiServer.messages.restartFailed') + (error as Error).message)
} finally {
setApiServerLoading(false)
}
}
const copyApiKey = () => {
navigator.clipboard.writeText(apiServerConfig.apiKey)
window.message.success(t('apiServer.messages.apiKeyCopied'))
}
const regenerateApiKey = () => {
const newApiKey = `cs-sk-${uuidv4()}`
dispatch(setApiServerApiKey(newApiKey))
window.message.success(t('apiServer.messages.apiKeyRegenerated'))
}
const handlePortChange = (value: string) => {
const port = parseInt(value) || 23333
if (port >= 1000 && port <= 65535) {
dispatch(setApiServerPort(port))
}
}
const openApiDocs = () => {
if (apiServerRunning) {
window.open(`http://localhost:${apiServerConfig.port}/api-docs`, '_blank')
}
}
return (
<Container theme={theme}>
{/* Header Section */}
<HeaderSection>
<HeaderContent>
<Title level={3} style={{ margin: 0, marginBottom: 8 }}>
{t('apiServer.title')}
</Title>
<Text type="secondary">{t('apiServer.description')}</Text>
</HeaderContent>
{apiServerRunning && (
<Button type="primary" icon={<ExternalLink size={14} />} onClick={openApiDocs}>
{t('apiServer.documentation.title')}
</Button>
)}
</HeaderSection>
{/* Server Control Panel with integrated configuration */}
<ServerControlPanel $status={apiServerRunning}>
<StatusSection>
<StatusIndicator $status={apiServerRunning} />
<StatusContent>
<StatusText $status={apiServerRunning}>
{apiServerRunning ? t('apiServer.status.running') : t('apiServer.status.stopped')}
</StatusText>
<StatusSubtext>
{apiServerRunning ? `http://localhost:${apiServerConfig.port}` : t('apiServer.fields.port.description')}
</StatusSubtext>
</StatusContent>
</StatusSection>
<ControlSection>
{apiServerRunning && (
<Tooltip title={t('apiServer.actions.restart.tooltip')}>
<RestartButton
$loading={apiServerLoading}
onClick={apiServerLoading ? undefined : handleApiServerRestart}>
<RotateCcw size={14} />
<span>{t('apiServer.actions.restart.button')}</span>
</RestartButton>
</Tooltip>
)}
{/* Port input when server is stopped */}
{!apiServerRunning && (
<StyledInputNumber
value={apiServerConfig.port}
onChange={(value) => handlePortChange(String(value || 23333))}
min={1000}
max={65535}
disabled={apiServerRunning}
placeholder="23333"
size="middle"
/>
)}
<Tooltip title={apiServerRunning ? t('apiServer.actions.stop') : t('apiServer.actions.start')}>
{apiServerRunning ? (
<StopButton
$loading={apiServerLoading}
onClick={apiServerLoading ? undefined : () => handleApiServerToggle(false)}>
<Square size={20} style={{ color: 'var(--color-status-error)' }} />
</StopButton>
) : (
<StartButton
$loading={apiServerLoading}
onClick={apiServerLoading ? undefined : () => handleApiServerToggle(true)}>
<Play size={20} style={{ color: 'var(--color-status-success)' }} />
</StartButton>
)}
</Tooltip>
</ControlSection>
</ServerControlPanel>
{/* API Key Configuration */}
<ConfigurationField>
<FieldLabel>{t('apiServer.fields.apiKey.label')}</FieldLabel>
<FieldDescription>{t('apiServer.fields.apiKey.description')}</FieldDescription>
<StyledInput
value={apiServerConfig.apiKey}
readOnly
placeholder={t('apiServer.fields.apiKey.placeholder')}
size="middle"
suffix={
<InputButtonContainer>
{!apiServerRunning && (
<RegenerateButton onClick={regenerateApiKey} disabled={apiServerRunning} type="link">
{t('apiServer.actions.regenerate')}
</RegenerateButton>
)}
<Tooltip title={t('apiServer.fields.apiKey.copyTooltip')}>
<InputButton icon={<Copy size={14} />} onClick={copyApiKey} disabled={!apiServerConfig.apiKey} />
</Tooltip>
</InputButtonContainer>
}
/>
{/* Authorization header info */}
<AuthHeaderSection>
<FieldLabel>{t('apiServer.authHeader.title')}</FieldLabel>
<StyledInput
style={{ height: 38 }}
value={`Authorization: Bearer ${apiServerConfig.apiKey || 'your-api-key'}`}
readOnly
size="middle"
/>
</AuthHeaderSection>
</ConfigurationField>
</Container>
)
}
// Styled Components
const Container = styled(SettingContainer)`
display: flex;
flex-direction: column;
height: calc(100vh - var(--navbar-height));
`
const HeaderSection = styled.div`
display: flex;
flex-direction: row;
align-items: center;
justify-content: space-between;
margin-bottom: 24px;
`
const HeaderContent = styled.div`
flex: 1;
`
const ServerControlPanel = styled.div<{ $status: boolean }>`
display: flex;
align-items: center;
justify-content: space-between;
padding: 16px 20px;
border-radius: 8px;
background: var(--color-background);
border: 1px solid ${(props) => (props.$status ? 'var(--color-status-success)' : 'var(--color-border)')};
transition: all 0.3s ease;
margin-bottom: 16px;
`
const StatusSection = styled.div`
display: flex;
align-items: center;
gap: 10px;
`
const StatusIndicator = styled.div<{ $status: boolean }>`
position: relative;
width: 10px;
height: 10px;
border-radius: 50%;
background: ${(props) => (props.$status ? 'var(--color-status-success)' : 'var(--color-status-error)')};
&::before {
content: '';
position: absolute;
inset: -3px;
border-radius: 50%;
background: ${(props) => (props.$status ? 'var(--color-status-success)' : 'var(--color-status-error)')};
opacity: 0.2;
animation: ${(props) => (props.$status ? 'pulse 2s infinite' : 'none')};
}
@keyframes pulse {
0%,
100% {
transform: scale(1);
opacity: 0.2;
}
50% {
transform: scale(1.5);
opacity: 0.1;
}
}
`
const StatusContent = styled.div`
display: flex;
flex-direction: column;
gap: 2px;
`
const StatusText = styled.div<{ $status: boolean }>`
font-weight: 600;
font-size: 14px;
color: ${(props) => (props.$status ? 'var(--color-status-success)' : 'var(--color-text-1)')};
margin: 0;
`
const StatusSubtext = styled.div`
font-size: 12px;
color: var(--color-text-3);
margin: 0;
`
const ControlSection = styled.div`
display: flex;
align-items: center;
gap: 12px;
`
const RestartButton = styled.div<{ $loading: boolean }>`
display: flex;
align-items: center;
gap: 4px;
color: var(--color-text-2);
cursor: ${(props) => (props.$loading ? 'not-allowed' : 'pointer')};
opacity: ${(props) => (props.$loading ? 0.5 : 1)};
font-size: 12px;
transition: all 0.2s ease;
&:hover {
color: ${(props) => (props.$loading ? 'var(--color-text-2)' : 'var(--color-primary)')};
}
`
const StyledInputNumber = styled(InputNumber)`
width: 80px;
border-radius: 6px;
border: 1.5px solid var(--color-border);
margin-right: 5px;
`
const StartButton = styled.div<{ $loading: boolean }>`
display: inline-flex;
align-items: center;
justify-content: center;
cursor: ${(props) => (props.$loading ? 'not-allowed' : 'pointer')};
opacity: ${(props) => (props.$loading ? 0.5 : 1)};
transition: all 0.2s ease;
&:hover {
transform: ${(props) => (props.$loading ? 'scale(1)' : 'scale(1.1)')};
}
`
const StopButton = styled.div<{ $loading: boolean }>`
display: inline-flex;
align-items: center;
justify-content: center;
cursor: ${(props) => (props.$loading ? 'not-allowed' : 'pointer')};
opacity: ${(props) => (props.$loading ? 0.5 : 1)};
transition: all 0.2s ease;
&:hover {
transform: ${(props) => (props.$loading ? 'scale(1)' : 'scale(1.1)')};
}
`
const ConfigurationField = styled.div`
display: flex;
flex-direction: column;
gap: 8px;
padding: 16px;
background: var(--color-background);
border-radius: 8px;
border: 1px solid var(--color-border);
`
const FieldLabel = styled.div`
font-size: 14px;
font-weight: 500;
color: var(--color-text-1);
margin: 0;
`
const FieldDescription = styled.div`
font-size: 12px;
color: var(--color-text-3);
margin: 0;
`
const StyledInput = styled(Input)`
width: 100%;
border-radius: 6px;
border: 1.5px solid var(--color-border);
`
const InputButtonContainer = styled.div`
display: flex;
align-items: center;
gap: 4px;
`
const InputButton = styled(Button)`
border: none;
padding: 0 4px;
background: transparent;
`
const RegenerateButton = styled(Button)`
padding: 0 4px;
font-size: 12px;
height: auto;
line-height: 1;
border: none;
background: transparent;
`
const AuthHeaderSection = styled.div`
margin-top: 12px;
display: flex;
flex-direction: column;
gap: 8px;
`
export default ApiServerSettings

View File

@ -1 +0,0 @@
export { default as ApiServerSettings } from './ApiServerSettings'

View File

@ -1,13 +1,15 @@
import 'emoji-picker-element'
import { CloseCircleFilled, QuestionCircleOutlined } from '@ant-design/icons'
import { CloseCircleFilled } from '@ant-design/icons'
import CodeEditor from '@renderer/components/CodeEditor'
import EmojiPicker from '@renderer/components/EmojiPicker'
import { Box, HSpaceBetweenStack, HStack } from '@renderer/components/Layout'
import { usePromptProcessor } from '@renderer/hooks/usePromptProcessor'
import { estimateTextTokens } from '@renderer/services/TokenService'
import { Assistant, AssistantSettings } from '@renderer/types'
import { getLeadingEmoji } from '@renderer/utils'
import { Button, Input, Popover } from 'antd'
import { Edit, Eye, HelpCircle } from 'lucide-react'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ReactMarkdown from 'react-markdown'
@ -28,7 +30,7 @@ const AssistantPromptSettings: React.FC<Props> = ({ assistant, updateAssistant }
const [prompt, setPrompt] = useState(assistant.prompt)
const [tokenCount, setTokenCount] = useState(0)
const { t } = useTranslation()
const [showMarkdown, setShowMarkdown] = useState(prompt.length > 0)
const [showPreview, setShowPreview] = useState(prompt.length > 0)
useEffect(() => {
const updateTokenCount = async () => {
@ -38,9 +40,15 @@ const AssistantPromptSettings: React.FC<Props> = ({ assistant, updateAssistant }
updateTokenCount()
}, [prompt])
const processedPrompt = usePromptProcessor({
prompt,
modelName: assistant.model?.name
})
const onUpdate = () => {
const _assistant = { ...assistant, name: name.trim(), emoji, prompt }
updateAssistant(_assistant)
window.message.success(t('common.saved'))
}
const handleEmojiSelect = (selectedEmoji: string) => {
@ -106,13 +114,13 @@ const AssistantPromptSettings: React.FC<Props> = ({ assistant, updateAssistant }
<HStack mb={8} alignItems="center" gap={4}>
<Box style={{ fontWeight: 'bold' }}>{t('common.prompt')}</Box>
<Popover title={t('agents.add.prompt.variables.tip.title')} content={promptVarsContent}>
<QuestionCircleOutlined size={14} color="var(--color-text-2)" />
<HelpCircle size={14} color="var(--color-text-2)" />
</Popover>
</HStack>
<TextAreaContainer>
{showMarkdown ? (
<MarkdownContainer className="markdown" onClick={() => setShowMarkdown(false)}>
<ReactMarkdown>{prompt}</ReactMarkdown>
{showPreview ? (
<MarkdownContainer className="markdown" onClick={() => setShowPreview(false)}>
<ReactMarkdown>{processedPrompt || prompt}</ReactMarkdown>
<div style={{ height: '30px' }} />
</MarkdownContainer>
) : (
@ -141,8 +149,11 @@ const AssistantPromptSettings: React.FC<Props> = ({ assistant, updateAssistant }
</TextAreaContainer>
<HSpaceBetweenStack width="100%" justifyContent="flex-end" mt="10px">
<TokenCount>Tokens: {tokenCount}</TokenCount>
<Button type="primary" onClick={() => setShowMarkdown((prev) => !prev)}>
{t(showMarkdown ? 'common.edit' : 'common.save')}
<Button
type="primary"
icon={showPreview ? <Edit size={14} /> : <Eye size={14} />}
onClick={() => setShowPreview((prev) => !prev)}>
{showPreview ? t('common.edit') : t('common.preview')}
</Button>
</HSpaceBetweenStack>
</Container>
@ -154,6 +165,10 @@ const Container = styled.div`
flex: 1;
flex-direction: column;
overflow: hidden;
.ant-btn {
line-height: 0;
}
`
const EmojiButtonWrapper = styled.div`

View File

@ -1,4 +1,5 @@
import { InfoCircleOutlined } from '@ant-design/icons'
import { HStack } from '@renderer/components/Layout'
import Selector from '@renderer/components/Selector'
import { useTheme } from '@renderer/context/ThemeProvider'
import { useEnableDeveloperMode, useSettings } from '@renderer/hooks/useSettings'
@ -199,37 +200,6 @@ const GeneralSettings: FC = () => {
/>
</SettingRow>
<SettingDivider />
<SettingRow>
<SettingRowTitle>{t('settings.general.spell_check.label')}</SettingRowTitle>
<Switch checked={enableSpellCheck} onChange={handleSpellCheckChange} />
</SettingRow>
{enableSpellCheck && (
<>
<SettingDivider />
<SettingRow>
<SettingRowTitle>{t('settings.general.spell_check.languages')}</SettingRowTitle>
<Selector<string>
size={14}
multiple
value={spellCheckLanguages}
placeholder={t('settings.general.spell_check.languages')}
onChange={handleSpellCheckLanguagesChange}
options={spellCheckLanguageOptions.map((lang) => ({
value: lang.value,
label: (
<Flex align="center" gap={8}>
<span role="img" aria-label={lang.flag}>
{lang.flag}
</span>
{lang.label}
</Flex>
)
}))}
/>
</SettingRow>
</>
)}
<SettingDivider />
<SettingRow>
<SettingRowTitle>{t('settings.proxy.mode.title')}</SettingRowTitle>
<Selector value={storeProxyMode} onChange={onProxyModeChange} options={proxyModeOptions} />
@ -251,6 +221,33 @@ const GeneralSettings: FC = () => {
</>
)}
<SettingDivider />
<SettingRow>
<HStack justifyContent="space-between" alignItems="center" style={{ flex: 1, marginRight: 16 }}>
<SettingRowTitle>{t('settings.general.spell_check.label')}</SettingRowTitle>
{enableSpellCheck && (
<Selector<string>
size={14}
multiple
value={spellCheckLanguages}
placeholder={t('settings.general.spell_check.languages')}
onChange={handleSpellCheckLanguagesChange}
options={spellCheckLanguageOptions.map((lang) => ({
value: lang.value,
label: (
<Flex align="center" gap={8}>
<span role="img" aria-label={lang.flag}>
{lang.flag}
</span>
{lang.label}
</Flex>
)
}))}
/>
)}
</HStack>
<Switch checked={enableSpellCheck} onChange={handleSpellCheckChange} />
</SettingRow>
<SettingDivider />
<SettingRow>
<SettingRowTitle>{t('settings.hardware_acceleration.title')}</SettingRowTitle>
<Switch checked={disableHardwareAcceleration} onChange={handleHardwareAccelerationChange} />

View File

@ -1,7 +1,6 @@
import { DeleteOutlined, EditOutlined, PlusOutlined } from '@ant-design/icons'
import { DragDropContext, Draggable, Droppable, DropResult } from '@hello-pangea/dnd'
import { loggerService } from '@logger'
import Scrollbar from '@renderer/components/Scrollbar'
import { DraggableVirtualList } from '@renderer/components/DraggableList'
import { getProviderLogo } from '@renderer/config/providers'
import { useAllProviders, useProviders } from '@renderer/hooks/useProvider'
import { getProviderLabel } from '@renderer/i18n/label'
@ -9,7 +8,6 @@ import ImageStorage from '@renderer/services/ImageStorage'
import { INITIAL_PROVIDERS } from '@renderer/store/llm'
import { Provider, ProviderType } from '@renderer/types'
import {
droppableReorder,
generateColorFromChar,
getFancyProviderName,
getFirstCharacter,
@ -30,6 +28,8 @@ import ProviderSetting from './ProviderSetting'
const logger = loggerService.withContext('ProvidersList')
const BUTTON_WRAPPER_HEIGHT = 50
const ProvidersList: FC = () => {
const [searchParams] = useSearchParams()
const providers = useAllProviders()
@ -272,14 +272,9 @@ const ProvidersList: FC = () => {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [searchParams])
const onDragEnd = (result: DropResult) => {
const handleUpdateProviders = (reorderProviders: Provider[]) => {
setDragging(false)
if (result.destination) {
const sourceIndex = result.source.index
const destIndex = result.destination.index
const reorderProviders = droppableReorder<Provider>(providers, sourceIndex, destIndex)
updateProviders(reorderProviders)
}
updateProviders(reorderProviders)
}
const onAddProvider = async () => {
@ -462,50 +457,37 @@ const ProvidersList: FC = () => {
disabled={dragging}
/>
</AddButtonWrapper>
<Scrollbar>
<ProviderList>
<DragDropContext onDragStart={() => setDragging(true)} onDragEnd={onDragEnd}>
<Droppable droppableId="droppable">
{(provided) => (
<div {...provided.droppableProps} ref={provided.innerRef}>
{filteredProviders.map((provider, index) => (
<Draggable
key={`draggable_${provider.id}_${index}`}
draggableId={provider.id}
index={index}
isDragDisabled={searchText.length > 0}>
{(provided) => (
<div
ref={provided.innerRef}
{...provided.draggableProps}
{...provided.dragHandleProps}
style={{ ...provided.draggableProps.style, marginBottom: 5 }}>
<Dropdown menu={{ items: getDropdownMenus(provider) }} trigger={['contextMenu']}>
<ProviderListItem
key={JSON.stringify(provider)}
className={provider.id === selectedProvider?.id ? 'active' : ''}
onClick={() => setSelectedProvider(provider)}>
{getProviderAvatar(provider)}
<ProviderItemName className="text-nowrap">
{getFancyProviderName(provider)}
</ProviderItemName>
{provider.enabled && (
<Tag color="green" style={{ marginLeft: 'auto', marginRight: 0, borderRadius: 16 }}>
ON
</Tag>
)}
</ProviderListItem>
</Dropdown>
</div>
)}
</Draggable>
))}
</div>
<DraggableVirtualList
list={filteredProviders}
onUpdate={handleUpdateProviders}
onDragStart={() => setDragging(true)}
estimateSize={useCallback(() => 40, [])}
overscan={3}
style={{
height: `calc(100% - 2 * ${BUTTON_WRAPPER_HEIGHT}px)`
}}
scrollerStyle={{
padding: 8,
paddingRight: 5
}}
itemContainerStyle={{ paddingBottom: 5 }}>
{(provider) => (
<Dropdown menu={{ items: getDropdownMenus(provider) }} trigger={['contextMenu']}>
<ProviderListItem
key={JSON.stringify(provider)}
className={provider.id === selectedProvider?.id ? 'active' : ''}
onClick={() => setSelectedProvider(provider)}>
{getProviderAvatar(provider)}
<ProviderItemName className="text-nowrap">{getFancyProviderName(provider)}</ProviderItemName>
{provider.enabled && (
<Tag color="green" style={{ marginLeft: 'auto', marginRight: 0, borderRadius: 16 }}>
ON
</Tag>
)}
</Droppable>
</DragDropContext>
</ProviderList>
</Scrollbar>
</ProviderListItem>
</Dropdown>
)}
</DraggableVirtualList>
<AddButtonWrapper>
<Button
style={{ width: '100%', borderRadius: 'var(--list-item-border-radius)' }}
@ -536,14 +518,6 @@ const ProviderListContainer = styled.div`
border-right: 0.5px solid var(--color-border);
`
const ProviderList = styled.div`
display: flex;
flex: 1;
flex-direction: column;
padding: 8px;
padding-right: 5px;
`
const ProviderListItem = styled.div`
display: flex;
flex-direction: row;
@ -575,7 +549,7 @@ const ProviderItemName = styled.div`
`
const AddButtonWrapper = styled.div`
height: 50px;
height: ${BUTTON_WRAPPER_HEIGHT}px;
flex-direction: row;
justify-content: center;
align-items: center;

View File

@ -8,8 +8,8 @@ import {
Info,
MonitorCog,
Package,
PencilRuler,
Rocket,
Server,
Settings2,
SquareTerminal,
TextCursorInput,
@ -21,7 +21,6 @@ import { Link, Route, Routes, useLocation } from 'react-router-dom'
import styled from 'styled-components'
import AboutSettings from './AboutSettings'
import { ApiServerSettings } from './ApiServerSettings'
import DataSettings from './DataSettings/DataSettings'
import DisplaySettings from './DisplaySettings/DisplaySettings'
import GeneralSettings from './GeneralSettings'
@ -77,18 +76,18 @@ const SettingsPage: FC = () => {
{t('settings.mcp.title')}
</MenuItem>
</MenuItemLink>
<MenuItemLink to="/settings/api-server">
<MenuItem className={isRoute('/settings/api-server')}>
<Server size={18} />
{t('apiServer.title')}
</MenuItem>
</MenuItemLink>
<MenuItemLink to="/settings/memory">
<MenuItem className={isRoute('/settings/memory')}>
<Brain size={18} />
{t('memory.title')}
</MenuItem>
</MenuItemLink>
<MenuItemLink to="/settings/tool">
<MenuItem className={isRoute('/settings/tool')}>
<PencilRuler size={18} />
{t('settings.tool.title')}
</MenuItem>
</MenuItemLink>
<MenuItemLink to="/settings/shortcut">
<MenuItem className={isRoute('/settings/shortcut')}>
<Command size={18} />
@ -133,7 +132,6 @@ const SettingsPage: FC = () => {
<Route path="tool/*" element={<ToolSettings />} />
<Route path="mcp/*" element={<MCPSettings />} />
<Route path="memory" element={<MemorySettings />} />
<Route path="api-server" element={<ApiServerSettings />} />
<Route path="general/*" element={<GeneralSettings />} />
<Route path="display" element={<DisplaySettings />} />
<Route path="shortcut" element={<ShortcutSettings />} />

View File

@ -1,168 +0,0 @@
import { ExportOutlined } from '@ant-design/icons'
import { getOcrProviderLogo, OCR_PROVIDER_CONFIG } from '@renderer/config/ocrProviders'
import { useOcrProvider } from '@renderer/hooks/useOcr'
import { OcrProvider } from '@renderer/types'
import { formatApiKeys, hasObjectKey } from '@renderer/utils'
import { Avatar, Divider, Flex, Input, InputNumber, Segmented } from 'antd'
import Link from 'antd/es/typography/Link'
import { FC, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import {
SettingDivider,
SettingHelpLink,
SettingHelpText,
SettingHelpTextRow,
SettingRow,
SettingRowTitle,
SettingSubtitle,
SettingTitle
} from '../..'
interface Props {
provider: OcrProvider
}
const OcrProviderSettings: FC<Props> = ({ provider: _provider }) => {
const { provider: ocrProvider, updateOcrProvider } = useOcrProvider(_provider.id)
const { t } = useTranslation()
const [apiKey, setApiKey] = useState(ocrProvider.apiKey || '')
const [apiHost, setApiHost] = useState(ocrProvider.apiHost || '')
const [options, setOptions] = useState(ocrProvider.options || {})
const ocrProviderConfig = OCR_PROVIDER_CONFIG[ocrProvider.id]
const apiKeyWebsite = ocrProviderConfig?.websites?.apiKey
const officialWebsite = ocrProviderConfig?.websites?.official
useEffect(() => {
setApiKey(ocrProvider.apiKey ?? '')
setApiHost(ocrProvider.apiHost ?? '')
setOptions(ocrProvider.options ?? {})
}, [ocrProvider.apiKey, ocrProvider.apiHost, ocrProvider.options])
const onUpdateApiKey = () => {
if (apiKey !== ocrProvider.apiKey) {
updateOcrProvider({ ...ocrProvider, apiKey })
}
}
const onUpdateApiHost = () => {
let trimmedHost = apiHost?.trim() || ''
if (trimmedHost.endsWith('/')) {
trimmedHost = trimmedHost.slice(0, -1)
}
if (trimmedHost !== ocrProvider.apiHost) {
updateOcrProvider({ ...ocrProvider, apiHost: trimmedHost })
} else {
setApiHost(ocrProvider.apiHost || '')
}
}
const onUpdateOptions = (key: string, value: any) => {
const newOptions = { ...options, [key]: value }
setOptions(newOptions)
updateOcrProvider({ ...ocrProvider, options: newOptions })
}
return (
<>
<SettingTitle>
<Flex align="center" gap={8}>
<ProviderLogo shape="square" src={getOcrProviderLogo(ocrProvider.id)} size={16} />
<ProviderName> {ocrProvider.name}</ProviderName>
{officialWebsite && ocrProviderConfig?.websites && (
<Link target="_blank" href={ocrProviderConfig.websites.official}>
<ExportOutlined style={{ color: 'var(--color-text)', fontSize: '12px' }} />
</Link>
)}
</Flex>
</SettingTitle>
<Divider style={{ width: '100%', margin: '10px 0' }} />
{hasObjectKey(ocrProvider, 'apiKey') && (
<>
<SettingSubtitle style={{ marginTop: 5, marginBottom: 10 }}>
{t('settings.provider.api_key.label')}
</SettingSubtitle>
<Flex gap={8}>
<Input.Password
value={apiKey}
placeholder={t('settings.provider.api_key.label')}
onChange={(e) => setApiKey(formatApiKeys(e.target.value))}
onBlur={onUpdateApiKey}
spellCheck={false}
type="password"
autoFocus={apiKey === ''}
/>
</Flex>
<SettingHelpTextRow style={{ justifyContent: 'space-between', marginTop: 5 }}>
<SettingHelpLink target="_blank" href={apiKeyWebsite}>
{t('settings.provider.get_api_key')}
</SettingHelpLink>
<SettingHelpText>{t('settings.provider.api_key.tip')}</SettingHelpText>
</SettingHelpTextRow>
</>
)}
{hasObjectKey(ocrProvider, 'apiHost') && (
<>
<SettingSubtitle style={{ marginTop: 5, marginBottom: 10 }}>
{t('settings.provider.api_host')}
</SettingSubtitle>
<Flex>
<Input
value={apiHost}
placeholder={t('settings.provider.api_host')}
onChange={(e) => setApiHost(e.target.value)}
onBlur={onUpdateApiHost}
/>
</Flex>
</>
)}
{hasObjectKey(ocrProvider, 'options') && ocrProvider.id === 'system' && (
<>
<SettingRow>
<SettingRowTitle>{t('settings.tool.ocr.mac_system_ocr_options.mode.title')}</SettingRowTitle>
<Segmented
options={[
{
label: t('settings.tool.ocr.mac_system_ocr_options.mode.accurate'),
value: 1
},
{
label: t('settings.tool.ocr.mac_system_ocr_options.mode.fast'),
value: 0
}
]}
value={options.recognitionLevel}
onChange={(value) => onUpdateOptions('recognitionLevel', value)}
/>
</SettingRow>
<SettingDivider style={{ marginTop: 15, marginBottom: 12 }} />
<SettingRow>
<SettingRowTitle>{t('settings.tool.ocr.mac_system_ocr_options.min_confidence')}</SettingRowTitle>
<InputNumber
value={options.minConfidence}
onChange={(value) => onUpdateOptions('minConfidence', value)}
min={0}
max={1}
step={0.1}
/>
</SettingRow>
</>
)}
</>
)
}
const ProviderName = styled.span`
font-size: 14px;
font-weight: 500;
`
const ProviderLogo = styled(Avatar)`
border: 0.5px solid var(--color-border);
`
export default OcrProviderSettings

View File

@ -1,58 +0,0 @@
import { isMac } from '@renderer/config/constant'
import { useTheme } from '@renderer/context/ThemeProvider'
import { useDefaultOcrProvider, useOcrProviders } from '@renderer/hooks/useOcr'
import { PreprocessProvider } from '@renderer/types'
import { Select } from 'antd'
import { FC, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { SettingContainer, SettingDivider, SettingGroup, SettingRow, SettingRowTitle, SettingTitle } from '../..'
import OcrProviderSettings from './OcrSettings'
const OcrSettings: FC = () => {
const { ocrProviders } = useOcrProviders()
const { provider: defaultProvider, setDefaultOcrProvider } = useDefaultOcrProvider()
const { t } = useTranslation()
const [selectedProvider, setSelectedProvider] = useState<PreprocessProvider | undefined>(defaultProvider)
const { theme: themeMode } = useTheme()
function updateSelectedOcrProvider(providerId: string) {
const provider = ocrProviders.find((p) => p.id === providerId)
if (!provider) {
return
}
setDefaultOcrProvider(provider)
setSelectedProvider(provider)
}
return (
<SettingContainer theme={themeMode}>
<SettingGroup theme={themeMode}>
<SettingTitle>{t('settings.tool.ocr.title')}</SettingTitle>
<SettingDivider />
<SettingRow>
<SettingRowTitle>{t('settings.tool.ocr.provider')}</SettingRowTitle>
<div style={{ display: 'flex', gap: '8px' }}>
<Select
value={selectedProvider?.id}
style={{ width: '200px' }}
onChange={(value: string) => updateSelectedOcrProvider(value)}
placeholder={t('settings.tool.ocr.provider_placeholder')}
options={ocrProviders.map((p) => ({
value: p.id,
label: p.name,
disabled: !isMac && p.id === 'system' // 在非 Mac 系统下禁用 system 选项
}))}
/>
</div>
</SettingRow>
</SettingGroup>
{selectedProvider && (
<SettingGroup theme={themeMode}>
<OcrProviderSettings provider={selectedProvider} />
</SettingGroup>
)}
</SettingContainer>
)
}
export default OcrSettings

View File

@ -1,5 +1,4 @@
import { GlobalOutlined } from '@ant-design/icons'
import OcrIcon from '@renderer/components/Icons/OcrIcon'
import { HStack } from '@renderer/components/Layout'
import ListItem from '@renderer/components/ListItem'
import { FileCode } from 'lucide-react'
@ -7,18 +6,21 @@ import { FC, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import OcrSettings from './OcrSettings'
import PreprocessSettings from './PreprocessSettings'
import WebSearchSettings from './WebSearchSettings'
let _menu: string = 'web-search'
const ToolSettings: FC = () => {
const { t } = useTranslation()
const [menu, setMenu] = useState<string>('web-search')
const [menu, setMenu] = useState<string>(_menu)
const menuItems = [
{ key: 'web-search', title: 'settings.tool.websearch.title', icon: <GlobalOutlined style={{ fontSize: 16 }} /> },
{ key: 'preprocess', title: 'settings.tool.preprocess.title', icon: <FileCode size={16} /> },
{ key: 'ocr', title: 'settings.tool.ocr.title', icon: <OcrIcon /> }
{ key: 'preprocess', title: 'settings.tool.preprocess.title', icon: <FileCode size={16} /> }
]
_menu = menu
return (
<Container>
<MenuList>
@ -35,7 +37,6 @@ const ToolSettings: FC = () => {
</MenuList>
{menu == 'web-search' && <WebSearchSettings />}
{menu == 'preprocess' && <PreprocessSettings />}
{menu == 'ocr' && <OcrSettings />}
</Container>
)
}

View File

@ -195,7 +195,7 @@ class KnowledgeQueue {
updateBaseItemIsPreprocessed({
baseId,
itemId: item.id,
isPreprocessed: !!base.preprocessOrOcrProvider
isPreprocessed: !!base.preprocessProvider
})
)
}

View File

@ -41,7 +41,12 @@ import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { isAbortError } from '@renderer/utils/error'
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPromptWithThinkTool, buildSystemPromptWithTools } from '@renderer/utils/prompt'
import {
buildSystemPromptWithThinkTool,
buildSystemPromptWithTools,
containsSupportedVariables,
replacePromptVariables
} from '@renderer/utils/prompt'
import { findLast, isEmpty, takeRight } from 'lodash'
import AiProvider from '../aiCore'
@ -375,8 +380,8 @@ async function fetchExternalTool(
.map((result) => result.value)
.flat()
// 添加内置工具
const { BUILT_IN_TOOLS } = await import('../tools')
mcpTools.push(...BUILT_IN_TOOLS)
// const { BUILT_IN_TOOLS } = await import('../tools')
// mcpTools.push(...BUILT_IN_TOOLS)
// 根据toolUseMode决定如何构建系统提示词
const basePrompt = assistant.prompt
@ -426,6 +431,10 @@ export async function fetchChatCompletion({
}) {
logger.debug('fetchChatCompletion', messages, assistant)
if (assistant.prompt && containsSupportedVariables(assistant.prompt)) {
assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name)
}
const provider = getAssistantProvider(assistant)
const AI = new AiProvider(provider)
@ -643,9 +652,13 @@ export async function fetchTranslate({ content, assistant, onResponse }: FetchTr
}
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
const prompt = (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
let prompt = (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
const model = getTopNamingModel() || assistant.model || getDefaultModel()
if (prompt && containsSupportedVariables(prompt)) {
prompt = await replacePromptVariables(prompt, model.name)
}
// 总结上下文总是取最后5条消息
const contextMessages = takeRight(messages, 5)

View File

@ -57,7 +57,7 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
apiKey: rerankAiProvider.getApiKey() || 'secret',
baseURL: rerankHost
},
preprocessOrOcrProvider: base.preprocessOrOcrProvider,
preprocessProvider: base.preprocessProvider,
documentCount: base.documentCount
}
}

View File

@ -1,4 +1,5 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit'
import { loggerService } from '@renderer/services/LoggerService'
import { useDispatch, useSelector, useStore } from 'react-redux'
import { FLUSH, PAUSE, PERSIST, persistReducer, persistStore, PURGE, REGISTER, REHYDRATE } from 'redux-persist'
import storage from 'redux-persist/lib/storage'
@ -18,7 +19,6 @@ import migrate from './migrate'
import minapps from './minapps'
import newMessagesReducer from './newMessage'
import nutstore from './nutstore'
import ocr from './ocr'
import paintings from './paintings'
import preprocess from './preprocess'
import runtime from './runtime'
@ -29,6 +29,8 @@ import tabs from './tabs'
import translate from './translate'
import websearch from './websearch'
const logger = loggerService.withContext('Store')
const rootReducer = combineReducers({
assistants,
agents,
@ -38,7 +40,6 @@ const rootReducer = combineReducers({
llm,
settings,
runtime,
ocr,
shortcuts,
knowledge,
minapps,
@ -48,7 +49,6 @@ const rootReducer = combineReducers({
copilot,
selectionStore,
tabs,
// messages: messagesReducer,
preprocess,
messages: newMessagesReducer,
messageBlocks: messageBlocksReducer,
@ -60,7 +60,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 126,
version: 127,
blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs'],
migrate
},
@ -104,4 +104,10 @@ export const useAppSelector = useSelector.withTypes<RootState>()
export const useAppStore = useStore.withTypes<typeof store>()
window.store = store
export async function handleSaveData() {
logger.info('Flushing redux persistor data')
await persistor.flush()
logger.info('Flushed redux persistor data')
}
export default store

View File

@ -1938,6 +1938,28 @@ const migrateConfig = {
}
},
'126': (state: RootState) => {
try {
state.knowledge.bases.forEach((base) => {
// @ts-ignore eslint-disable-next-line
if (base.preprocessOrOcrProvider) {
// @ts-ignore eslint-disable-next-line
base.preprocessProvider = base.preprocessOrOcrProvider
// @ts-ignore eslint-disable-next-line
delete base.preprocessOrOcrProvider
// @ts-ignore eslint-disable-next-line
if (base.preprocessProvider.type === 'ocr') {
// @ts-ignore eslint-disable-next-line
delete base.preprocessProvider
}
}
})
return state
} catch (error) {
logger.error('migrate 126 error', error as Error)
return state
}
},
'127': (state: RootState) => {
try {
const visibleIcons = state.settings.sidebarIcons.visible
if (visibleIcons.includes('discover')) {
@ -1955,6 +1977,7 @@ const migrateConfig = {
}
}
} catch (error) {
logger.error('migrate 127 error', error as Error)
return state
}
}

View File

@ -1,46 +0,0 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'
import { OcrProvider } from '@renderer/types'
export interface OcrState {
providers: OcrProvider[]
defaultProvider: string
}
const initialState: OcrState = {
providers: [
{
id: 'system',
name: 'System(Mac Only)',
options: {
recognitionLevel: 0,
minConfidence: 0.5
}
}
],
defaultProvider: ''
}
const ocrSlice = createSlice({
name: 'ocr',
initialState,
reducers: {
setDefaultOcrProvider(state, action: PayloadAction<string>) {
state.defaultProvider = action.payload
},
setOcrProviders(state, action: PayloadAction<OcrProvider[]>) {
state.providers = action.payload
},
updateOcrProviders(state, action: PayloadAction<OcrProvider[]>) {
state.providers = action.payload
},
updateOcrProvider(state, action: PayloadAction<OcrProvider>) {
const index = state.providers.findIndex((provider) => provider.id === action.payload.id)
if (index !== -1) {
state.providers[index] = action.payload
}
}
}
})
export const { updateOcrProviders, updateOcrProvider, setDefaultOcrProvider, setOcrProviders } = ocrSlice.actions
export default ocrSlice.reducer

Some files were not shown because too many files have changed in this diff Show More