Add streaming file upload and download actions

Introduces new OneBot actions for streaming file upload and download, including chunked file transfer with memory/disk management and SHA256 verification. Adds CleanStreamTempFile, DownloadFileStream, UploadFileStream, and TestStreamDownload actions, updates action routing and network adapters to support streaming via HTTP and WebSocket, and provides Python test scripts for concurrent upload testing.
This commit is contained in:
手瓜一十雪 2025-09-16 23:24:00 +08:00
parent 66f30e1ebf
commit 890d032794
14 changed files with 1163 additions and 17 deletions

View File

@ -29,7 +29,9 @@ export class OB11Response {
return this.createResponse(null, 'failed', retcode, err, echo);
}
}
export interface OneBotRequestToolkit<T = unknown> {
send: (data: T) => Promise<void>;
}
export abstract class OneBotAction<PayloadType, ReturnDataType> {
actionName: typeof ActionName[keyof typeof ActionName] = ActionName.Unknown;
core: NapCatCore;
@ -57,27 +59,27 @@ export abstract class OneBotAction<PayloadType, ReturnDataType> {
return { valid: true };
}
public async handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig): Promise<OB11Return<ReturnDataType | null>> {
public async handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit = { send: async () => { } }, echo?: string): Promise<OB11Return<ReturnDataType | null>> {
const result = await this.check(payload);
if (!result.valid) {
return OB11Response.error(result.message, 400);
}
try {
const resData = await this._handle(payload, adaptername, config);
return OB11Response.ok(resData);
const resData = await this._handle(payload, adaptername, config, req);
return OB11Response.ok(resData, echo);
} catch (e: unknown) {
this.core.context.logger.logError('发生错误', e);
return OB11Response.error((e as Error).message.toString() || (e as Error)?.stack?.toString() || '未知错误,可能操作超时', 200);
return OB11Response.error((e as Error).message.toString() || (e as Error)?.stack?.toString() || '未知错误,可能操作超时', 200, echo);
}
}
public async websocketHandle(payload: PayloadType, echo: unknown, adaptername: string, config: NetworkAdapterConfig): Promise<OB11Return<ReturnDataType | null>> {
public async websocketHandle(payload: PayloadType, echo: unknown, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit = { send: async () => { } }): Promise<OB11Return<ReturnDataType | null>> {
const result = await this.check(payload);
if (!result.valid) {
return OB11Response.error(result.message, 1400, echo);
}
try {
const resData = await this._handle(payload, adaptername, config);
const resData = await this._handle(payload, adaptername, config, req);
return OB11Response.ok(resData, echo);
} catch (e: unknown) {
this.core.context.logger.logError('发生错误', e);
@ -85,5 +87,5 @@ export abstract class OneBotAction<PayloadType, ReturnDataType> {
}
}
abstract _handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig): Promise<ReturnDataType>;
abstract _handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit): Promise<ReturnDataType>;
}

View File

@ -41,12 +41,12 @@ export class GetFileBase extends OneBotAction<GetFilePayload, GetFileResponse> {
let url = '';
if (mixElement?.picElement && rawMessage) {
const tempData =
await this.obContext.apis.MsgApi.rawToOb11Converters.picElement?.(mixElement?.picElement, rawMessage, mixElement, { parseMultMsg: false }) as OB11MessageImage | undefined;
await this.obContext.apis.MsgApi.rawToOb11Converters.picElement?.(mixElement?.picElement, rawMessage, mixElement, { parseMultMsg: false, disableGetUrl: false, quick_reply: true }) as OB11MessageImage | undefined;
url = tempData?.data.url ?? '';
}
if (mixElement?.videoElement && rawMessage) {
const tempData =
await this.obContext.apis.MsgApi.rawToOb11Converters.videoElement?.(mixElement?.videoElement, rawMessage, mixElement, { parseMultMsg: false }) as OB11MessageVideo | undefined;
await this.obContext.apis.MsgApi.rawToOb11Converters.videoElement?.(mixElement?.videoElement, rawMessage, mixElement, { parseMultMsg: false, disableGetUrl: false, quick_reply: true }) as OB11MessageVideo | undefined;
url = tempData?.data.url ?? '';
}
const res: GetFileResponse = {

View File

@ -130,10 +130,18 @@ import { DoGroupAlbumComment } from './extends/DoGroupAlbumComment';
import { GetGroupAlbumMediaList } from './extends/GetGroupAlbumMediaList';
import { SetGroupAlbumMediaLike } from './extends/SetGroupAlbumMediaLike';
import { DelGroupAlbumMedia } from './extends/DelGroupAlbumMedia';
import { CleanStreamTempFile } from './stream/CleanStreamTempFile';
import { DownloadFileStream } from './stream/DownloadFileStream';
import { TestStreamDownload } from './stream/TestStreamDownload';
import { UploadFileStream } from './stream/UploadFileStream';
export function createActionMap(obContext: NapCatOneBot11Adapter, core: NapCatCore) {
const actionHandlers = [
new CleanStreamTempFile(obContext, core),
new DownloadFileStream(obContext, core),
new TestStreamDownload(obContext, core),
new UploadFileStream(obContext, core),
new DelGroupAlbumMedia(obContext, core),
new SetGroupAlbumMediaLike(obContext, core),
new DoGroupAlbumComment(obContext, core),

View File

@ -10,6 +10,10 @@ export interface InvalidCheckResult {
}
export const ActionName = {
CleanStreamTempFile: 'clean_stream_temp_file',
TestStreamDownload: 'test_stream_download',
UploadFileStream: 'upload_file_stream',
DownloadFileStream: 'download_file_stream',
DelGroupAlbumMedia: 'del_group_album_media',
SetGroupAlbumMediaLike: 'set_group_album_media_like',
DoGroupAlbumComment: 'do_group_album_comment',

View File

@ -0,0 +1,33 @@
import { ActionName } from '@/onebot/action/router';
import { OneBotAction } from '@/onebot/action/OneBotAction';
import { join } from 'node:path';
import { readdir, unlink } from 'node:fs/promises';
export class CleanStreamTempFile extends OneBotAction<void, void> {
override actionName = ActionName.CleanStreamTempFile;
async _handle(_payload: void): Promise<void> {
try {
// 获取临时文件夹路径
const tempPath = this.core.NapCatTempPath;
// 读取文件夹中的所有文件
const files = await readdir(tempPath);
// 删除每个文件
const deletePromises = files.map(async (file) => {
const filePath = join(tempPath, file);
try {
await unlink(filePath);
this.core.context.logger.log(`已删除文件: ${filePath}`);
} catch (err: unknown) {
this.core.context.logger.log(`删除文件 ${filePath} 失败: ${(err as Error).message}`);
}
});
await Promise.all(deletePromises);
} catch (err: unknown) {
this.core.context.logger.log(`清理流临时文件失败: ${(err as Error).message}`);
}
}
}

View File

@ -0,0 +1,101 @@
import { ActionName } from '@/onebot/action/router';
import { OneBotAction, OneBotRequestToolkit } from '@/onebot/action/OneBotAction';
import { Static, Type } from '@sinclair/typebox';
import { NetworkAdapterConfig } from '@/onebot/config/config';
import fs from 'fs';
import { FileNapCatOneBotUUID } from '@/common/file-uuid';
const SchemaData = Type.Object({
file: Type.Optional(Type.String()),
file_id: Type.Optional(Type.String()),
chunk_size: Type.Optional(Type.Number({ default: 64 * 1024 })) // 默认64KB分块
});
type Payload = Static<typeof SchemaData>;
export class DownloadFileStream extends OneBotAction<Payload, void> {
override actionName = ActionName.DownloadFileStream;
override payloadSchema = SchemaData;
async _handle(payload: Payload, _adaptername: string, _config: NetworkAdapterConfig, req: OneBotRequestToolkit) {
payload.file ||= payload.file_id || '';
const chunkSize = payload.chunk_size || 64 * 1024;
let downloadPath = '';
let fileName = '';
let fileSize = 0;
//接收消息标记模式
const contextMsgFile = FileNapCatOneBotUUID.decode(payload.file);
if (contextMsgFile && contextMsgFile.msgId && contextMsgFile.elementId) {
const { peer, msgId, elementId } = contextMsgFile;
downloadPath = await this.core.apis.FileApi.downloadMedia(msgId, peer.chatType, peer.peerUid, elementId, '', '');
const rawMessage = (await this.core.apis.MsgApi.getMsgsByMsgId(peer, [msgId]))?.msgList
.find(msg => msg.msgId === msgId);
const mixElement = rawMessage?.elements.find(e => e.elementId === elementId);
const mixElementInner = mixElement?.videoElement ?? mixElement?.fileElement ?? mixElement?.pttElement ?? mixElement?.picElement;
if (!mixElementInner) throw new Error('element not found');
fileSize = parseInt(mixElementInner.fileSize?.toString() ?? '0');
fileName = mixElementInner.fileName ?? '';
}
//群文件模式
else if (FileNapCatOneBotUUID.decodeModelId(payload.file)) {
const contextModelIdFile = FileNapCatOneBotUUID.decodeModelId(payload.file);
if (contextModelIdFile && contextModelIdFile.modelId) {
const { peer, modelId } = contextModelIdFile;
downloadPath = await this.core.apis.FileApi.downloadFileForModelId(peer, modelId, '');
}
}
//搜索名字模式
else {
const searchResult = (await this.core.apis.FileApi.searchForFile([payload.file]));
if (searchResult) {
downloadPath = await this.core.apis.FileApi.downloadFileById(searchResult.id, parseInt(searchResult.fileSize));
fileSize = parseInt(searchResult.fileSize);
fileName = searchResult.fileName;
}
}
if (!downloadPath) {
throw new Error('file not found');
}
// 获取文件大小
const stats = await fs.promises.stat(downloadPath);
const totalSize = fileSize || stats.size;
// 发送文件信息
req.send({
type: 'file_info',
file_name: fileName,
file_size: totalSize,
chunk_size: chunkSize
});
// 创建读取流并分块发送
const readStream = fs.createReadStream(downloadPath, { highWaterMark: chunkSize });
let chunkIndex = 0;
let bytesRead = 0;
for await (const chunk of readStream) {
const base64Chunk = chunk.toString('base64');
bytesRead += chunk.length;
req.send({
type: 'chunk',
index: chunkIndex,
data: base64Chunk,
size: chunk.length,
progress: Math.round((bytesRead / totalSize) * 100),
base64_size: base64Chunk.length
});
chunkIndex++;
}
// 发送完成信号
req.send({
type: 'complete',
total_chunks: chunkIndex,
total_bytes: bytesRead
});
}
}

View File

@ -0,0 +1,23 @@
import { ActionName } from '@/onebot/action/router';
import { OneBotAction, OneBotRequestToolkit } from '@/onebot/action/OneBotAction';
import { Static, Type } from '@sinclair/typebox';
import { NetworkAdapterConfig } from '@/onebot/config/config';
const SchemaData = Type.Object({
});
type Payload = Static<typeof SchemaData>;
export class TestStreamDownload extends OneBotAction<Payload, string> {
override actionName = ActionName.TestStreamDownload;
override payloadSchema = SchemaData;
async _handle(_payload: Payload, _adaptername: string, _config: NetworkAdapterConfig, req: OneBotRequestToolkit) {
for (let i = 0; i < 10; i++) {
req.send({ index: i });
await new Promise(resolve => setTimeout(resolve, 100));
}
return 'done';
}
}

View File

@ -0,0 +1,368 @@
import asyncio
import websockets
import json
import base64
import hashlib
import os
from typing import Optional, Dict, Set
import time
from dataclasses import dataclass
@dataclass
class ChunkInfo:
index: int
data: bytes
size: int
retry_count: int = 0
uploaded: bool = False
class FileUploadTester:
def __init__(self, ws_uri: str, file_path: str, max_concurrent: int = 5):
self.ws_uri = ws_uri
self.file_path = file_path
self.chunk_size = 64 * 1024 # 64KB per chunk
self.max_concurrent = max_concurrent # 最大并发数
self.max_retries = 3 # 最大重试次数
self.stream_id = None
self.chunks: Dict[int, ChunkInfo] = {}
self.upload_semaphore = asyncio.Semaphore(max_concurrent)
self.failed_chunks: Set[int] = set()
# 消息路由机制
self.response_futures: Dict[str, asyncio.Future] = {}
self.message_receiver_task = None
async def connect_and_upload(self):
"""连接到WebSocket并上传文件"""
async with websockets.connect(self.ws_uri) as ws:
print(f"已连接到 {self.ws_uri}")
# 启动消息接收器
self.message_receiver_task = asyncio.create_task(self._message_receiver(ws))
try:
# 准备文件数据
file_info = self.prepare_file()
if not file_info:
return
print(f"文件信息: {file_info['filename']}, 大小: {file_info['file_size']} bytes, 块数: {file_info['total_chunks']}")
print(f"并发设置: 最大 {self.max_concurrent} 个并发上传")
# 生成stream_id
self.stream_id = f"upload_{hash(file_info['filename'] + str(file_info['file_size']))}"
print(f"Stream ID: {self.stream_id}")
# 重置流(如果存在)
await self.reset_stream(ws)
# 准备所有分片
self.prepare_chunks(file_info)
# 并行上传分片
await self.upload_chunks_parallel(ws, file_info)
# 重试失败的分片
if self.failed_chunks:
await self.retry_failed_chunks(ws, file_info)
# 完成上传
await self.complete_upload(ws)
# 等待一段时间确保所有响应都收到
await asyncio.sleep(2)
finally:
# 取消消息接收器
if self.message_receiver_task:
self.message_receiver_task.cancel()
try:
await self.message_receiver_task
except asyncio.CancelledError:
pass
# 清理未完成的Future
for future in self.response_futures.values():
if not future.done():
future.cancel()
async def _message_receiver(self, ws):
"""专门的消息接收协程负责分发响应到对应的Future"""
try:
while True:
message = await ws.recv()
try:
data = json.loads(message)
echo = data.get('echo', 'unknown')
# 查找对应的Future
if echo in self.response_futures:
future = self.response_futures[echo]
if not future.done():
future.set_result(data)
else:
# 处理未预期的响应
print(f"📨 未预期响应 [{echo}]: {data}")
except json.JSONDecodeError as e:
print(f"⚠️ JSON解析错误: {e}")
except Exception as e:
print(f"⚠️ 消息处理错误: {e}")
except asyncio.CancelledError:
print("🔄 消息接收器已停止")
raise
except Exception as e:
print(f"💥 消息接收器异常: {e}")
async def _send_and_wait_response(self, ws, request: dict, timeout: float = 10.0) -> Optional[dict]:
"""发送请求并等待响应"""
echo = request.get('echo', 'unknown')
# 创建Future用于接收响应
future = asyncio.Future()
self.response_futures[echo] = future
try:
# 发送请求
await ws.send(json.dumps(request))
# 等待响应
response = await asyncio.wait_for(future, timeout=timeout)
return response
except asyncio.TimeoutError:
print(f"⏰ 请求超时: {echo}")
return None
except Exception as e:
print(f"💥 请求异常: {echo} - {e}")
return None
finally:
# 清理Future
if echo in self.response_futures:
del self.response_futures[echo]
def prepare_file(self):
"""准备文件信息"""
if not os.path.exists(self.file_path):
print(f"文件不存在: {self.file_path}")
return None
file_size = os.path.getsize(self.file_path)
filename = os.path.basename(self.file_path)
total_chunks = (file_size + self.chunk_size - 1) // self.chunk_size
# 计算SHA256
print("计算文件SHA256...")
sha256_hash = hashlib.sha256()
with open(self.file_path, 'rb') as f:
for chunk in iter(lambda: f.read(8192), b""):
sha256_hash.update(chunk)
expected_sha256 = sha256_hash.hexdigest()
return {
'filename': filename,
'file_size': file_size,
'total_chunks': total_chunks,
'expected_sha256': expected_sha256
}
def prepare_chunks(self, file_info):
"""预读取所有分片数据"""
print("预读取分片数据...")
with open(self.file_path, 'rb') as f:
for chunk_index in range(file_info['total_chunks']):
chunk_data = f.read(self.chunk_size)
self.chunks[chunk_index] = ChunkInfo(
index=chunk_index,
data=chunk_data,
size=len(chunk_data)
)
print(f"已准备 {len(self.chunks)} 个分片")
async def reset_stream(self, ws):
"""重置流"""
req = {
"action": "upload_file_stream",
"params": {
"stream_id": self.stream_id,
"reset": True
},
"echo": "reset"
}
print("发送重置请求...")
response = await self._send_and_wait_response(ws, req, timeout=5.0)
if response and response.get('echo') == 'reset':
print("✅ 流重置完成")
else:
print(f"⚠️ 重置响应异常: {response}")
async def upload_chunks_parallel(self, ws, file_info):
"""并行上传所有分片"""
print(f"\n开始并行上传 {len(self.chunks)} 个分片...")
start_time = time.time()
# 创建上传任务
tasks = []
for chunk_index in range(file_info['total_chunks']):
task = asyncio.create_task(
self.upload_single_chunk(ws, chunk_index, file_info)
)
tasks.append(task)
# 等待所有任务完成
results = await asyncio.gather(*tasks, return_exceptions=True)
# 统计结果
successful = sum(1 for r in results if r is True)
failed = sum(1 for r in results if r is not True)
elapsed = time.time() - start_time
speed = file_info['file_size'] / elapsed / 1024 / 1024 # MB/s
print(f"\n📊 并行上传完成:")
print(f" 成功: {successful}/{len(self.chunks)}")
print(f" 失败: {failed}")
print(f" 耗时: {elapsed:.2f}")
print(f" 速度: {speed:.2f}MB/s")
if failed > 0:
print(f"⚠️ {failed} 个分片上传失败,将进行重试")
async def upload_single_chunk(self, ws, chunk_index: int, file_info) -> bool:
"""上传单个分片"""
async with self.upload_semaphore: # 限制并发数
chunk = self.chunks[chunk_index]
try:
chunk_base64 = base64.b64encode(chunk.data).decode('utf-8')
req = {
"action": "upload_file_stream",
"params": {
"stream_id": self.stream_id,
"chunk_data": chunk_base64,
"chunk_index": chunk_index,
"total_chunks": file_info['total_chunks'],
"file_size": file_info['file_size'],
"filename": file_info['filename'],
"expected_sha256": file_info['expected_sha256']
},
"echo": f"chunk_{chunk_index}"
}
# 使用统一的发送和接收方法
response = await self._send_and_wait_response(ws, req, timeout=10.0)
if response and response.get('echo') == f"chunk_{chunk_index}":
if response.get('status') == 'ok':
chunk.uploaded = True
data = response.get('data', {})
progress = data.get('received_chunks', 0)
total = data.get('total_chunks', file_info['total_chunks'])
print(f"✅ 块 {chunk_index + 1:3d}/{total} ({chunk.size:5d}B) - 进度: {progress}/{total}")
return True
else:
error_msg = response.get('message', 'Unknown error')
print(f"❌ 块 {chunk_index + 1} 失败: {error_msg}")
self.failed_chunks.add(chunk_index)
return False
else:
print(f"⚠️ 块 {chunk_index + 1} 响应异常或超时")
self.failed_chunks.add(chunk_index)
return False
except Exception as e:
print(f"💥 块 {chunk_index + 1} 异常: {e}")
self.failed_chunks.add(chunk_index)
return False
async def retry_failed_chunks(self, ws, file_info):
"""重试失败的分片"""
print(f"\n🔄 开始重试 {len(self.failed_chunks)} 个失败分片...")
for retry_round in range(self.max_retries):
if not self.failed_chunks:
break
print(f"{retry_round + 1} 轮重试,剩余 {len(self.failed_chunks)} 个分片")
current_failed = self.failed_chunks.copy()
self.failed_chunks.clear()
# 重试当前失败的分片
retry_tasks = []
for chunk_index in current_failed:
task = asyncio.create_task(
self.upload_single_chunk(ws, chunk_index, file_info)
)
retry_tasks.append(task)
retry_results = await asyncio.gather(*retry_tasks, return_exceptions=True)
successful_retries = sum(1 for r in retry_results if r is True)
print(f"重试结果: {successful_retries}/{len(current_failed)} 成功")
if not self.failed_chunks:
print("✅ 所有分片重试成功!")
break
else:
await asyncio.sleep(1) # 重试间隔
if self.failed_chunks:
print(f"❌ 仍有 {len(self.failed_chunks)} 个分片失败: {sorted(self.failed_chunks)}")
async def complete_upload(self, ws):
"""完成上传"""
req = {
"action": "upload_file_stream",
"params": {
"stream_id": self.stream_id,
"is_complete": True
},
"echo": "complete"
}
print("\n发送完成请求...")
response = await self._send_and_wait_response(ws, req, timeout=10.0)
if response:
if response.get('status') == 'ok':
data = response.get('data', {})
print(f"✅ 上传完成!")
print(f" 文件路径: {data.get('file_path')}")
print(f" 文件大小: {data.get('file_size')} bytes")
print(f" SHA256: {data.get('sha256')}")
print(f" 状态: {data.get('status')}")
else:
print(f"❌ 上传失败: {response.get('message')}")
else:
print("⚠️ 完成请求超时或失败")
async def main():
# 配置
WS_URI = "ws://localhost:3001" # 修改为你的WebSocket地址
FILE_PATH = r"C:\Users\nanaeo\Pictures\CatPicture.zip" #!!!!!!!!!!!
MAX_CONCURRENT = 8 # 最大并发上传数,可根据服务器性能调整
# 创建测试文件(如果不存在)
if not os.path.exists(FILE_PATH):
with open(FILE_PATH, 'w', encoding='utf-8') as f:
f.write("这是一个测试文件,用于演示并行文件分片上传功能。\n" * 100)
print(f"✅ 创建测试文件: {FILE_PATH}")
print("=== 并行文件流上传测试 ===")
print(f"WebSocket URI: {WS_URI}")
print(f"文件路径: {FILE_PATH}")
print(f"最大并发数: {MAX_CONCURRENT}")
try:
tester = FileUploadTester(WS_URI, FILE_PATH, MAX_CONCURRENT)
await tester.connect_and_upload()
print("🎉 测试完成!")
except Exception as e:
print(f"💥 测试出错: {e}")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,384 @@
import { ActionName } from '@/onebot/action/router';
import { OneBotAction } from '@/onebot/action/OneBotAction';
import { Static, Type } from '@sinclair/typebox';
import fs from 'fs';
import { join as joinPath } from 'node:path';
import { randomUUID } from 'crypto';
import { createHash } from 'crypto';
// 简化配置
const CONFIG = {
TIMEOUT: 10 * 60 * 1000, // 10分钟超时
MEMORY_THRESHOLD: 10 * 1024 * 1024, // 10MB超过使用磁盘
MEMORY_LIMIT: 100 * 1024 * 1024 // 100MB内存总限制
} as const;
const SchemaData = Type.Object({
stream_id: Type.String(),
chunk_data: Type.Optional(Type.String()),
chunk_index: Type.Optional(Type.Number()),
total_chunks: Type.Optional(Type.Number()),
file_size: Type.Optional(Type.Number()),
expected_sha256: Type.Optional(Type.String()),
is_complete: Type.Optional(Type.Boolean()),
filename: Type.Optional(Type.String()),
reset: Type.Optional(Type.Boolean()),
verify_only: Type.Optional(Type.Boolean())
});
type Payload = Static<typeof SchemaData>;
// 简化流状态接口
interface StreamState {
id: string;
filename: string;
totalChunks: number;
receivedChunks: number;
missingChunks: Set<number>;
// 可选属性
fileSize?: number;
expectedSha256?: string;
// 存储策略
useMemory: boolean;
memoryChunks?: Map<number, Buffer>;
tempDir?: string;
finalPath?: string;
// 管理
createdAt: number;
timeoutId: NodeJS.Timeout;
}
interface StreamResult {
stream_id: string;
status: 'receiving' | 'completed' | 'error' | 'ready';
received_chunks: number;
total_chunks?: number;
missing_chunks?: number[];
file_path?: string;
file_size?: number;
sha256?: string;
message?: string;
}
export class UploadFileStream extends OneBotAction<Payload, StreamResult> {
override actionName = ActionName.UploadFileStream;
override payloadSchema = SchemaData;
private static streams = new Map<string, StreamState>();
private static memoryUsage = 0;
async _handle(payload: Payload): Promise<StreamResult> {
const { stream_id, reset, verify_only } = payload;
try {
if (reset) return this.resetStream(stream_id);
if (verify_only) return this.verifyStream(stream_id);
const stream = this.getOrCreateStream(payload);
if (payload.chunk_data && payload.chunk_index !== undefined) {
const result = await this.processChunk(stream, payload.chunk_data, payload.chunk_index);
if (result.status === 'error') return result;
}
if (payload.is_complete || stream.receivedChunks === stream.totalChunks) {
return await this.completeStream(stream);
}
return this.getStreamStatus(stream);
} catch (error) {
// 确保在任何错误情况下都清理资源
this.cleanupStream(stream_id, true);
return this.errorResult(stream_id, error);
}
}
private resetStream(streamId: string): StreamResult {
this.cleanupStream(streamId);
return {
stream_id: streamId,
status: 'ready',
received_chunks: 0,
message: 'Stream reset'
};
}
private verifyStream(streamId: string): StreamResult {
const stream = UploadFileStream.streams.get(streamId);
if (!stream) {
return this.errorResult(streamId, new Error('Stream not found'));
}
return this.getStreamStatus(stream);
}
private getOrCreateStream(payload: Payload): StreamState {
let stream = UploadFileStream.streams.get(payload.stream_id);
if (!stream) {
if (!payload.total_chunks) {
throw new Error('total_chunks required for new stream');
}
stream = this.createStream(payload);
}
return stream;
}
private createStream(payload: Payload): StreamState {
const { stream_id, total_chunks, file_size, filename, expected_sha256 } = payload;
const useMemory = this.shouldUseMemory(file_size);
if (useMemory && file_size && (UploadFileStream.memoryUsage + file_size) > CONFIG.MEMORY_LIMIT) {
throw new Error('Memory limit exceeded');
}
const stream: StreamState = {
id: stream_id,
filename: filename || `upload_${randomUUID()}`,
totalChunks: total_chunks!,
receivedChunks: 0,
missingChunks: new Set(Array.from({ length: total_chunks! }, (_, i) => i)),
fileSize: file_size,
expectedSha256: expected_sha256,
useMemory,
createdAt: Date.now(),
timeoutId: this.setupTimeout(stream_id)
};
try {
if (useMemory) {
stream.memoryChunks = new Map();
if (file_size) UploadFileStream.memoryUsage += file_size;
} else {
this.setupDiskStorage(stream);
}
UploadFileStream.streams.set(stream_id, stream);
return stream;
} catch (error) {
// 如果设置存储失败,清理已创建的资源
clearTimeout(stream.timeoutId);
if (stream.tempDir && fs.existsSync(stream.tempDir)) {
try {
fs.rmSync(stream.tempDir, { recursive: true, force: true });
} catch (cleanupError) {
console.error(`Failed to cleanup temp dir during creation error:`, cleanupError);
}
}
throw error;
}
}
private shouldUseMemory(fileSize?: number): boolean {
return fileSize !== undefined && fileSize <= CONFIG.MEMORY_THRESHOLD;
}
private setupDiskStorage(stream: StreamState): void {
const tempDir = joinPath(this.core.NapCatTempPath, `upload_${stream.id}`);
const finalPath = joinPath(this.core.NapCatTempPath, stream.filename);
fs.mkdirSync(tempDir, { recursive: true });
stream.tempDir = tempDir;
stream.finalPath = finalPath;
}
private setupTimeout(streamId: string): NodeJS.Timeout {
return setTimeout(() => {
console.log(`Stream ${streamId} timeout`);
this.cleanupStream(streamId);
}, CONFIG.TIMEOUT);
}
private async processChunk(stream: StreamState, chunkData: string, chunkIndex: number): Promise<StreamResult> {
// 验证索引
if (chunkIndex < 0 || chunkIndex >= stream.totalChunks) {
throw new Error(`Invalid chunk index: ${chunkIndex}`);
}
// 检查重复
if (!stream.missingChunks.has(chunkIndex)) {
return this.getStreamStatus(stream, `Chunk ${chunkIndex} already received`);
}
try {
const buffer = Buffer.from(chunkData, 'base64');
// 存储分片
if (stream.useMemory) {
stream.memoryChunks!.set(chunkIndex, buffer);
} else {
const chunkPath = joinPath(stream.tempDir!, `${chunkIndex}.chunk`);
await fs.promises.writeFile(chunkPath, buffer);
}
// 更新状态
stream.missingChunks.delete(chunkIndex);
stream.receivedChunks++;
this.refreshTimeout(stream);
return this.getStreamStatus(stream);
} catch (error) {
throw new Error(`Chunk processing failed: ${error instanceof Error ? error.message : 'Unknown error'}`);
}
}
private refreshTimeout(stream: StreamState): void {
clearTimeout(stream.timeoutId);
stream.timeoutId = this.setupTimeout(stream.id);
}
private getStreamStatus(stream: StreamState, message?: string): StreamResult {
const missingChunks = Array.from(stream.missingChunks).sort();
return {
stream_id: stream.id,
status: 'receiving',
received_chunks: stream.receivedChunks,
total_chunks: stream.totalChunks,
missing_chunks: missingChunks.length > 0 ? missingChunks : undefined,
file_size: stream.fileSize,
message
};
}
private async completeStream(stream: StreamState): Promise<StreamResult> {
try {
// 合并分片
const finalBuffer = stream.useMemory ?
await this.mergeMemoryChunks(stream) :
await this.mergeDiskChunks(stream);
// 验证SHA256
const sha256 = this.validateSha256(stream, finalBuffer);
// 保存文件
const finalPath = stream.finalPath || joinPath(this.core.NapCatTempPath, stream.filename);
await fs.promises.writeFile(finalPath, finalBuffer);
// 清理资源但保留文件
this.cleanupStream(stream.id, false);
return {
stream_id: stream.id,
status: 'completed',
received_chunks: stream.receivedChunks,
total_chunks: stream.totalChunks,
file_path: finalPath,
file_size: finalBuffer.length,
sha256,
message: 'Upload completed'
};
} catch (error) {
throw new Error(`Stream completion failed: ${error instanceof Error ? error.message : 'Unknown error'}`);
}
}
private async mergeMemoryChunks(stream: StreamState): Promise<Buffer> {
const chunks: Buffer[] = [];
for (let i = 0; i < stream.totalChunks; i++) {
const chunk = stream.memoryChunks!.get(i);
if (!chunk) throw new Error(`Missing memory chunk ${i}`);
chunks.push(chunk);
}
return Buffer.concat(chunks);
}
private async mergeDiskChunks(stream: StreamState): Promise<Buffer> {
const chunks: Buffer[] = [];
for (let i = 0; i < stream.totalChunks; i++) {
const chunkPath = joinPath(stream.tempDir!, `${i}.chunk`);
if (!fs.existsSync(chunkPath)) throw new Error(`Missing chunk file ${i}`);
chunks.push(await fs.promises.readFile(chunkPath));
}
return Buffer.concat(chunks);
}
private validateSha256(stream: StreamState, buffer: Buffer): string | undefined {
if (!stream.expectedSha256) return undefined;
const actualSha256 = createHash('sha256').update(buffer).digest('hex');
if (actualSha256 !== stream.expectedSha256) {
throw new Error(`SHA256 mismatch. Expected: ${stream.expectedSha256}, Got: ${actualSha256}`);
}
return actualSha256;
}
private cleanupStream(streamId: string, deleteFinalFile = true): void {
const stream = UploadFileStream.streams.get(streamId);
if (!stream) return;
try {
// 清理超时
clearTimeout(stream.timeoutId);
// 清理内存
if (stream.useMemory) {
if (stream.fileSize) {
UploadFileStream.memoryUsage = Math.max(0, UploadFileStream.memoryUsage - stream.fileSize);
}
stream.memoryChunks?.clear();
}
// 清理临时文件夹及其所有内容
if (stream.tempDir) {
try {
if (fs.existsSync(stream.tempDir)) {
fs.rmSync(stream.tempDir, { recursive: true, force: true });
console.log(`Cleaned up temp directory: ${stream.tempDir}`);
}
} catch (error) {
console.error(`Failed to cleanup temp directory ${stream.tempDir}:`, error);
}
}
// 删除最终文件(如果需要)
if (deleteFinalFile && stream.finalPath) {
try {
if (fs.existsSync(stream.finalPath)) {
fs.unlinkSync(stream.finalPath);
console.log(`Deleted final file: ${stream.finalPath}`);
}
} catch (error) {
console.error(`Failed to delete final file ${stream.finalPath}:`, error);
}
}
} catch (error) {
console.error(`Cleanup error for stream ${streamId}:`, error);
} finally {
UploadFileStream.streams.delete(streamId);
console.log(`Stream ${streamId} cleaned up`);
}
}
private errorResult(streamId: string, error: any): StreamResult {
return {
stream_id: streamId,
status: 'error',
received_chunks: 0,
message: error instanceof Error ? error.message : 'Unknown error'
};
}
// 全局状态查询
static getGlobalStatus() {
return {
activeStreams: this.streams.size,
memoryUsageMB: Math.round(this.memoryUsage / 1024 / 1024 * 100) / 100,
streams: Array.from(this.streams.values()).map(stream => ({
streamId: stream.id,
filename: stream.filename,
progress: `${stream.receivedChunks}/${stream.totalChunks}`,
useMemory: stream.useMemory,
createdAt: new Date(stream.createdAt).toISOString()
}))
};
}
}

View File

@ -0,0 +1,201 @@
import asyncio
import websockets
import json
import base64
import hashlib
import os
from typing import Optional
class FileUploadTester:
def __init__(self, ws_uri: str, file_path: str):
self.ws_uri = ws_uri
self.file_path = file_path
self.chunk_size = 64 * 1024 # 64KB per chunk
self.stream_id = None
async def connect_and_upload(self):
"""连接到WebSocket并上传文件"""
async with websockets.connect(self.ws_uri) as ws:
print(f"已连接到 {self.ws_uri}")
# 准备文件数据
file_info = self.prepare_file()
if not file_info:
return
print(f"文件信息: {file_info['filename']}, 大小: {file_info['file_size']} bytes, 块数: {file_info['total_chunks']}")
# 生成stream_id
self.stream_id = f"upload_{hash(file_info['filename'] + str(file_info['file_size']))}"
print(f"Stream ID: {self.stream_id}")
# 重置流(如果存在)
await self.reset_stream(ws)
# 开始分块上传
await self.upload_chunks(ws, file_info)
# 完成上传
await self.complete_upload(ws)
# 等待一些响应
await self.listen_for_responses(ws)
def prepare_file(self):
"""准备文件信息"""
if not os.path.exists(self.file_path):
print(f"文件不存在: {self.file_path}")
return None
file_size = os.path.getsize(self.file_path)
filename = os.path.basename(self.file_path)
total_chunks = (file_size + self.chunk_size - 1) // self.chunk_size
# 计算SHA256
sha256_hash = hashlib.sha256()
with open(self.file_path, 'rb') as f:
for chunk in iter(lambda: f.read(8192), b""):
sha256_hash.update(chunk)
expected_sha256 = sha256_hash.hexdigest()
return {
'filename': filename,
'file_size': file_size,
'total_chunks': total_chunks,
'expected_sha256': expected_sha256
}
async def reset_stream(self, ws):
"""重置流"""
req = {
"action": "upload_file_stream",
"params": {
"stream_id": self.stream_id,
"reset": True
},
"echo": "reset"
}
await ws.send(json.dumps(req))
print("发送重置请求...")
async def upload_chunks(self, ws, file_info):
"""上传文件块"""
with open(self.file_path, 'rb') as f:
for chunk_index in range(file_info['total_chunks']):
# 读取块数据
chunk_data = f.read(self.chunk_size)
chunk_base64 = base64.b64encode(chunk_data).decode('utf-8')
# 准备请求
req = {
"action": "upload_file_stream",
"params": {
"stream_id": self.stream_id,
"chunk_data": chunk_base64,
"chunk_index": chunk_index,
"total_chunks": file_info['total_chunks'],
"file_size": file_info['file_size'],
"filename": file_info['filename'],
#"expected_sha256": file_info['expected_sha256']
},
"echo": f"chunk_{chunk_index}"
}
await ws.send(json.dumps(req))
print(f"发送块 {chunk_index + 1}/{file_info['total_chunks']} ({len(chunk_data)} bytes)")
# 等待响应
try:
response = await asyncio.wait_for(ws.recv(), timeout=5.0)
resp_data = json.loads(response)
if resp_data.get('echo') == f"chunk_{chunk_index}":
if resp_data.get('status') == 'ok':
data = resp_data.get('data', {})
print(f" -> 状态: {data.get('status')}, 已接收: {data.get('received_chunks')}")
else:
print(f" -> 错误: {resp_data.get('message')}")
except asyncio.TimeoutError:
print(f" -> 块 {chunk_index} 响应超时")
# 小延迟避免过快发送
await asyncio.sleep(0.1)
async def complete_upload(self, ws):
"""完成上传"""
req = {
"action": "upload_file_stream",
"params": {
"stream_id": self.stream_id,
"is_complete": True
},
"echo": "complete"
}
await ws.send(json.dumps(req))
print("发送完成请求...")
async def verify_stream(self, ws):
"""验证流状态"""
req = {
"action": "upload_file_stream",
"params": {
"stream_id": self.stream_id,
"verify_only": True
},
"echo": "verify"
}
await ws.send(json.dumps(req))
print("发送验证请求...")
async def listen_for_responses(self, ws, duration=10):
"""监听响应"""
print(f"监听响应 {duration} 秒...")
try:
end_time = asyncio.get_event_loop().time() + duration
while asyncio.get_event_loop().time() < end_time:
try:
msg = await asyncio.wait_for(ws.recv(), timeout=1.0)
resp_data = json.loads(msg)
echo = resp_data.get('echo', 'unknown')
if echo == "complete":
if resp_data.get('status') == 'ok':
data = resp_data.get('data', {})
print(f"✅ 上传完成!")
print(f" 文件路径: {data.get('file_path')}")
print(f" 文件大小: {data.get('file_size')} bytes")
print(f" SHA256: {data.get('sha256')}")
print(f" 状态: {data.get('status')}")
else:
print(f"❌ 上传失败: {resp_data.get('message')}")
elif echo == "verify":
if resp_data.get('status') == 'ok':
data = resp_data.get('data', {})
print(f"🔍 验证结果: {data}")
elif echo == "reset":
print(f"🔄 重置完成: {resp_data}")
else:
print(f"📨 收到响应 [{echo}]: {resp_data}")
except asyncio.TimeoutError:
continue
except Exception as e:
print(f"监听出错: {e}")
async def main():
# 配置
WS_URI = "ws://localhost:3001" # 修改为你的WebSocket地址
FILE_PATH = "C:\\Users\\nanaeo\\Pictures\\CatPicture.zip"
print("=== 文件流上传测试 ===")
print(f"WebSocket URI: {WS_URI}")
print(f"文件路径: {FILE_PATH}")
try:
tester = FileUploadTester(WS_URI, FILE_PATH)
await tester.connect_and_upload()
except Exception as e:
print(f"测试出错: {e}")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -9,7 +9,7 @@ export class OB11HttpSSEServerAdapter extends OB11HttpServerAdapter {
if (req.path === '/_events') {
this.createSseSupport(req, res);
} else {
super.httpApiRequest(req, res);
super.httpApiRequest(req, res, true);
}
}

View File

@ -104,7 +104,7 @@ export class OB11HttpServerAdapter extends IOB11NetworkAdapter<HttpServerConfig>
}
}
async httpApiRequest(req: Request, res: Response) {
async httpApiRequest(req: Request, res: Response, request_sse: boolean = false) {
let payload = req.body;
if (req.method == 'get') {
payload = req.query;
@ -117,17 +117,31 @@ export class OB11HttpServerAdapter extends IOB11NetworkAdapter<HttpServerConfig>
return res.json(hello);
}
const actionName = req.path.split('/')[1];
const payload_echo = payload['echo'];
const real_echo = payload_echo ?? Math.random().toString(36).substring(2, 15);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const action = this.actions.get(actionName as any);
if (action) {
try {
const result = await action.handle(payload, this.name, this.config);
let stream = false;
const result = await action.handle(payload, this.name, this.config, {
send: request_sse ? async (data: object) => {
this.onEvent({ ...OB11Response.ok(data, real_echo), type: 'sse-action' } as unknown as OB11EmitEventContent);
} : async (data: object) => {
stream = true;
res.write(JSON.stringify({ ...OB11Response.ok(data, real_echo), type: 'stream-action' }) + "\r\n\r\n");
}
}, real_echo);
if (stream) {
res.write(JSON.stringify({ ...result, type: 'stream-action' }) + "\r\n\r\n");
return res.end();
};
return res.json(result);
} catch (error: unknown) {
return res.json(OB11Response.error((error as Error)?.stack?.toString() || (error as Error)?.message || 'Error Handle', 200));
return res.json(OB11Response.error((error as Error)?.stack?.toString() || (error as Error)?.message || 'Error Handle', 200, real_echo));
}
} else {
return res.json(OB11Response.error('不支持的Api ' + actionName, 200));
return res.json(OB11Response.error('不支持的Api ' + actionName, 200, real_echo));
}
}

View File

@ -151,7 +151,11 @@ export class OB11WebSocketClientAdapter extends IOB11NetworkAdapter<WebsocketCli
this.checkStateAndReply<unknown>(OB11Response.error('不支持的Api ' + receiveData.action, 1404, echo));
return;
}
const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config);
const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config, {
send: async (data: object) => {
this.checkStateAndReply<unknown>({ ...OB11Response.ok(data, echo ?? '') });
}
});
this.checkStateAndReply<unknown>({ ...retdata });
}
async reload(newConfig: WebsocketClientConfig) {

View File

@ -186,7 +186,11 @@ export class OB11WebSocketServerAdapter extends IOB11NetworkAdapter<WebsocketSer
this.checkStateAndReply<unknown>(OB11Response.error('不支持的API ' + receiveData.action, 1404, echo), wsClient);
return;
}
const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config);
const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config, {
send: async (data: object) => {
this.checkStateAndReply<unknown>({ ...OB11Response.ok(data, echo ?? '') }, wsClient);
}
});
this.checkStateAndReply<unknown>({ ...retdata }, wsClient);
}