diff --git a/src/onebot/action/OneBotAction.ts b/src/onebot/action/OneBotAction.ts index 818169e2..cdd1dd22 100644 --- a/src/onebot/action/OneBotAction.ts +++ b/src/onebot/action/OneBotAction.ts @@ -29,7 +29,9 @@ export class OB11Response { return this.createResponse(null, 'failed', retcode, err, echo); } } - +export interface OneBotRequestToolkit { + send: (data: T) => Promise; +} export abstract class OneBotAction { actionName: typeof ActionName[keyof typeof ActionName] = ActionName.Unknown; core: NapCatCore; @@ -57,27 +59,27 @@ export abstract class OneBotAction { return { valid: true }; } - public async handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig): Promise> { + public async handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit = { send: async () => { } }, echo?: string): Promise> { const result = await this.check(payload); if (!result.valid) { return OB11Response.error(result.message, 400); } try { - const resData = await this._handle(payload, adaptername, config); - return OB11Response.ok(resData); + const resData = await this._handle(payload, adaptername, config, req); + return OB11Response.ok(resData, echo); } catch (e: unknown) { this.core.context.logger.logError('发生错误', e); - return OB11Response.error((e as Error).message.toString() || (e as Error)?.stack?.toString() || '未知错误,可能操作超时', 200); + return OB11Response.error((e as Error).message.toString() || (e as Error)?.stack?.toString() || '未知错误,可能操作超时', 200, echo); } } - public async websocketHandle(payload: PayloadType, echo: unknown, adaptername: string, config: NetworkAdapterConfig): Promise> { + public async websocketHandle(payload: PayloadType, echo: unknown, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit = { send: async () => { } }): Promise> { const result = await this.check(payload); if (!result.valid) { return OB11Response.error(result.message, 1400, echo); } try { - const resData = await this._handle(payload, adaptername, config); + const resData = await this._handle(payload, adaptername, config, req); return OB11Response.ok(resData, echo); } catch (e: unknown) { this.core.context.logger.logError('发生错误', e); @@ -85,5 +87,5 @@ export abstract class OneBotAction { } } - abstract _handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig): Promise; + abstract _handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit): Promise; } diff --git a/src/onebot/action/file/GetFile.ts b/src/onebot/action/file/GetFile.ts index 853261cd..ceb904ee 100644 --- a/src/onebot/action/file/GetFile.ts +++ b/src/onebot/action/file/GetFile.ts @@ -41,12 +41,12 @@ export class GetFileBase extends OneBotAction { let url = ''; if (mixElement?.picElement && rawMessage) { const tempData = - await this.obContext.apis.MsgApi.rawToOb11Converters.picElement?.(mixElement?.picElement, rawMessage, mixElement, { parseMultMsg: false }) as OB11MessageImage | undefined; + await this.obContext.apis.MsgApi.rawToOb11Converters.picElement?.(mixElement?.picElement, rawMessage, mixElement, { parseMultMsg: false, disableGetUrl: false, quick_reply: true }) as OB11MessageImage | undefined; url = tempData?.data.url ?? ''; } if (mixElement?.videoElement && rawMessage) { const tempData = - await this.obContext.apis.MsgApi.rawToOb11Converters.videoElement?.(mixElement?.videoElement, rawMessage, mixElement, { parseMultMsg: false }) as OB11MessageVideo | undefined; + await this.obContext.apis.MsgApi.rawToOb11Converters.videoElement?.(mixElement?.videoElement, rawMessage, mixElement, { parseMultMsg: false, disableGetUrl: false, quick_reply: true }) as OB11MessageVideo | undefined; url = tempData?.data.url ?? ''; } const res: GetFileResponse = { diff --git a/src/onebot/action/index.ts b/src/onebot/action/index.ts index d6187237..28bf3a84 100644 --- a/src/onebot/action/index.ts +++ b/src/onebot/action/index.ts @@ -130,10 +130,18 @@ import { DoGroupAlbumComment } from './extends/DoGroupAlbumComment'; import { GetGroupAlbumMediaList } from './extends/GetGroupAlbumMediaList'; import { SetGroupAlbumMediaLike } from './extends/SetGroupAlbumMediaLike'; import { DelGroupAlbumMedia } from './extends/DelGroupAlbumMedia'; +import { CleanStreamTempFile } from './stream/CleanStreamTempFile'; +import { DownloadFileStream } from './stream/DownloadFileStream'; +import { TestStreamDownload } from './stream/TestStreamDownload'; +import { UploadFileStream } from './stream/UploadFileStream'; export function createActionMap(obContext: NapCatOneBot11Adapter, core: NapCatCore) { const actionHandlers = [ + new CleanStreamTempFile(obContext, core), + new DownloadFileStream(obContext, core), + new TestStreamDownload(obContext, core), + new UploadFileStream(obContext, core), new DelGroupAlbumMedia(obContext, core), new SetGroupAlbumMediaLike(obContext, core), new DoGroupAlbumComment(obContext, core), diff --git a/src/onebot/action/router.ts b/src/onebot/action/router.ts index c2ed5a00..57b9de36 100644 --- a/src/onebot/action/router.ts +++ b/src/onebot/action/router.ts @@ -10,6 +10,10 @@ export interface InvalidCheckResult { } export const ActionName = { + CleanStreamTempFile: 'clean_stream_temp_file', + TestStreamDownload: 'test_stream_download', + UploadFileStream: 'upload_file_stream', + DownloadFileStream: 'download_file_stream', DelGroupAlbumMedia: 'del_group_album_media', SetGroupAlbumMediaLike: 'set_group_album_media_like', DoGroupAlbumComment: 'do_group_album_comment', diff --git a/src/onebot/action/stream/CleanStreamTempFile.ts b/src/onebot/action/stream/CleanStreamTempFile.ts new file mode 100644 index 00000000..48928fcd --- /dev/null +++ b/src/onebot/action/stream/CleanStreamTempFile.ts @@ -0,0 +1,33 @@ +import { ActionName } from '@/onebot/action/router'; +import { OneBotAction } from '@/onebot/action/OneBotAction'; +import { join } from 'node:path'; +import { readdir, unlink } from 'node:fs/promises'; + +export class CleanStreamTempFile extends OneBotAction { + override actionName = ActionName.CleanStreamTempFile; + + async _handle(_payload: void): Promise { + try { + // 获取临时文件夹路径 + const tempPath = this.core.NapCatTempPath; + + // 读取文件夹中的所有文件 + const files = await readdir(tempPath); + + // 删除每个文件 + const deletePromises = files.map(async (file) => { + const filePath = join(tempPath, file); + try { + await unlink(filePath); + this.core.context.logger.log(`已删除文件: ${filePath}`); + } catch (err: unknown) { + this.core.context.logger.log(`删除文件 ${filePath} 失败: ${(err as Error).message}`); + + } + }); + await Promise.all(deletePromises); + } catch (err: unknown) { + this.core.context.logger.log(`清理流临时文件失败: ${(err as Error).message}`); + } + } +} diff --git a/src/onebot/action/stream/DownloadFileStream.ts b/src/onebot/action/stream/DownloadFileStream.ts new file mode 100644 index 00000000..a18e3e14 --- /dev/null +++ b/src/onebot/action/stream/DownloadFileStream.ts @@ -0,0 +1,101 @@ +import { ActionName } from '@/onebot/action/router'; +import { OneBotAction, OneBotRequestToolkit } from '@/onebot/action/OneBotAction'; +import { Static, Type } from '@sinclair/typebox'; +import { NetworkAdapterConfig } from '@/onebot/config/config'; +import fs from 'fs'; +import { FileNapCatOneBotUUID } from '@/common/file-uuid'; +const SchemaData = Type.Object({ + file: Type.Optional(Type.String()), + file_id: Type.Optional(Type.String()), + chunk_size: Type.Optional(Type.Number({ default: 64 * 1024 })) // 默认64KB分块 +}); + +type Payload = Static; + +export class DownloadFileStream extends OneBotAction { + override actionName = ActionName.DownloadFileStream; + override payloadSchema = SchemaData; + + async _handle(payload: Payload, _adaptername: string, _config: NetworkAdapterConfig, req: OneBotRequestToolkit) { + payload.file ||= payload.file_id || ''; + const chunkSize = payload.chunk_size || 64 * 1024; + let downloadPath = ''; + let fileName = ''; + let fileSize = 0; + + //接收消息标记模式 + const contextMsgFile = FileNapCatOneBotUUID.decode(payload.file); + if (contextMsgFile && contextMsgFile.msgId && contextMsgFile.elementId) { + const { peer, msgId, elementId } = contextMsgFile; + downloadPath = await this.core.apis.FileApi.downloadMedia(msgId, peer.chatType, peer.peerUid, elementId, '', ''); + const rawMessage = (await this.core.apis.MsgApi.getMsgsByMsgId(peer, [msgId]))?.msgList + .find(msg => msg.msgId === msgId); + const mixElement = rawMessage?.elements.find(e => e.elementId === elementId); + const mixElementInner = mixElement?.videoElement ?? mixElement?.fileElement ?? mixElement?.pttElement ?? mixElement?.picElement; + if (!mixElementInner) throw new Error('element not found'); + fileSize = parseInt(mixElementInner.fileSize?.toString() ?? '0'); + fileName = mixElementInner.fileName ?? ''; + } + //群文件模式 + else if (FileNapCatOneBotUUID.decodeModelId(payload.file)) { + const contextModelIdFile = FileNapCatOneBotUUID.decodeModelId(payload.file); + if (contextModelIdFile && contextModelIdFile.modelId) { + const { peer, modelId } = contextModelIdFile; + downloadPath = await this.core.apis.FileApi.downloadFileForModelId(peer, modelId, ''); + } + } + //搜索名字模式 + else { + const searchResult = (await this.core.apis.FileApi.searchForFile([payload.file])); + if (searchResult) { + downloadPath = await this.core.apis.FileApi.downloadFileById(searchResult.id, parseInt(searchResult.fileSize)); + fileSize = parseInt(searchResult.fileSize); + fileName = searchResult.fileName; + } + } + + if (!downloadPath) { + throw new Error('file not found'); + } + + // 获取文件大小 + const stats = await fs.promises.stat(downloadPath); + const totalSize = fileSize || stats.size; + + // 发送文件信息 + req.send({ + type: 'file_info', + file_name: fileName, + file_size: totalSize, + chunk_size: chunkSize + }); + + // 创建读取流并分块发送 + const readStream = fs.createReadStream(downloadPath, { highWaterMark: chunkSize }); + let chunkIndex = 0; + let bytesRead = 0; + + for await (const chunk of readStream) { + const base64Chunk = chunk.toString('base64'); + bytesRead += chunk.length; + + req.send({ + type: 'chunk', + index: chunkIndex, + data: base64Chunk, + size: chunk.length, + progress: Math.round((bytesRead / totalSize) * 100), + base64_size: base64Chunk.length + }); + + chunkIndex++; + } + + // 发送完成信号 + req.send({ + type: 'complete', + total_chunks: chunkIndex, + total_bytes: bytesRead + }); + } +} diff --git a/src/onebot/action/stream/TestStreamDownload.ts b/src/onebot/action/stream/TestStreamDownload.ts new file mode 100644 index 00000000..b0362e04 --- /dev/null +++ b/src/onebot/action/stream/TestStreamDownload.ts @@ -0,0 +1,23 @@ +import { ActionName } from '@/onebot/action/router'; +import { OneBotAction, OneBotRequestToolkit } from '@/onebot/action/OneBotAction'; +import { Static, Type } from '@sinclair/typebox'; +import { NetworkAdapterConfig } from '@/onebot/config/config'; + +const SchemaData = Type.Object({ + +}); + +type Payload = Static; + +export class TestStreamDownload extends OneBotAction { + override actionName = ActionName.TestStreamDownload; + override payloadSchema = SchemaData; + + async _handle(_payload: Payload, _adaptername: string, _config: NetworkAdapterConfig, req: OneBotRequestToolkit) { + for (let i = 0; i < 10; i++) { + req.send({ index: i }); + await new Promise(resolve => setTimeout(resolve, 100)); + } + return 'done'; + } +} diff --git a/src/onebot/action/stream/UploadFileStream.py b/src/onebot/action/stream/UploadFileStream.py new file mode 100644 index 00000000..f982ca5d --- /dev/null +++ b/src/onebot/action/stream/UploadFileStream.py @@ -0,0 +1,368 @@ +import asyncio +import websockets +import json +import base64 +import hashlib +import os +from typing import Optional, Dict, Set +import time +from dataclasses import dataclass + +@dataclass +class ChunkInfo: + index: int + data: bytes + size: int + retry_count: int = 0 + uploaded: bool = False + +class FileUploadTester: + def __init__(self, ws_uri: str, file_path: str, max_concurrent: int = 5): + self.ws_uri = ws_uri + self.file_path = file_path + self.chunk_size = 64 * 1024 # 64KB per chunk + self.max_concurrent = max_concurrent # 最大并发数 + self.max_retries = 3 # 最大重试次数 + self.stream_id = None + self.chunks: Dict[int, ChunkInfo] = {} + self.upload_semaphore = asyncio.Semaphore(max_concurrent) + self.failed_chunks: Set[int] = set() + + # 消息路由机制 + self.response_futures: Dict[str, asyncio.Future] = {} + self.message_receiver_task = None + + async def connect_and_upload(self): + """连接到WebSocket并上传文件""" + async with websockets.connect(self.ws_uri) as ws: + print(f"已连接到 {self.ws_uri}") + + # 启动消息接收器 + self.message_receiver_task = asyncio.create_task(self._message_receiver(ws)) + + try: + # 准备文件数据 + file_info = self.prepare_file() + if not file_info: + return + + print(f"文件信息: {file_info['filename']}, 大小: {file_info['file_size']} bytes, 块数: {file_info['total_chunks']}") + print(f"并发设置: 最大 {self.max_concurrent} 个并发上传") + + # 生成stream_id + self.stream_id = f"upload_{hash(file_info['filename'] + str(file_info['file_size']))}" + print(f"Stream ID: {self.stream_id}") + + # 重置流(如果存在) + await self.reset_stream(ws) + + # 准备所有分片 + self.prepare_chunks(file_info) + + # 并行上传分片 + await self.upload_chunks_parallel(ws, file_info) + + # 重试失败的分片 + if self.failed_chunks: + await self.retry_failed_chunks(ws, file_info) + + # 完成上传 + await self.complete_upload(ws) + + # 等待一段时间确保所有响应都收到 + await asyncio.sleep(2) + + finally: + # 取消消息接收器 + if self.message_receiver_task: + self.message_receiver_task.cancel() + try: + await self.message_receiver_task + except asyncio.CancelledError: + pass + + # 清理未完成的Future + for future in self.response_futures.values(): + if not future.done(): + future.cancel() + + async def _message_receiver(self, ws): + """专门的消息接收协程,负责分发响应到对应的Future""" + try: + while True: + message = await ws.recv() + try: + data = json.loads(message) + echo = data.get('echo', 'unknown') + + # 查找对应的Future + if echo in self.response_futures: + future = self.response_futures[echo] + if not future.done(): + future.set_result(data) + else: + # 处理未预期的响应 + print(f"📨 未预期响应 [{echo}]: {data}") + + except json.JSONDecodeError as e: + print(f"⚠️ JSON解析错误: {e}") + except Exception as e: + print(f"⚠️ 消息处理错误: {e}") + + except asyncio.CancelledError: + print("🔄 消息接收器已停止") + raise + except Exception as e: + print(f"💥 消息接收器异常: {e}") + + async def _send_and_wait_response(self, ws, request: dict, timeout: float = 10.0) -> Optional[dict]: + """发送请求并等待响应""" + echo = request.get('echo', 'unknown') + + # 创建Future用于接收响应 + future = asyncio.Future() + self.response_futures[echo] = future + + try: + # 发送请求 + await ws.send(json.dumps(request)) + + # 等待响应 + response = await asyncio.wait_for(future, timeout=timeout) + return response + + except asyncio.TimeoutError: + print(f"⏰ 请求超时: {echo}") + return None + except Exception as e: + print(f"💥 请求异常: {echo} - {e}") + return None + finally: + # 清理Future + if echo in self.response_futures: + del self.response_futures[echo] + + def prepare_file(self): + """准备文件信息""" + if not os.path.exists(self.file_path): + print(f"文件不存在: {self.file_path}") + return None + + file_size = os.path.getsize(self.file_path) + filename = os.path.basename(self.file_path) + total_chunks = (file_size + self.chunk_size - 1) // self.chunk_size + + # 计算SHA256 + print("计算文件SHA256...") + sha256_hash = hashlib.sha256() + with open(self.file_path, 'rb') as f: + for chunk in iter(lambda: f.read(8192), b""): + sha256_hash.update(chunk) + expected_sha256 = sha256_hash.hexdigest() + + return { + 'filename': filename, + 'file_size': file_size, + 'total_chunks': total_chunks, + 'expected_sha256': expected_sha256 + } + + def prepare_chunks(self, file_info): + """预读取所有分片数据""" + print("预读取分片数据...") + with open(self.file_path, 'rb') as f: + for chunk_index in range(file_info['total_chunks']): + chunk_data = f.read(self.chunk_size) + self.chunks[chunk_index] = ChunkInfo( + index=chunk_index, + data=chunk_data, + size=len(chunk_data) + ) + print(f"已准备 {len(self.chunks)} 个分片") + + async def reset_stream(self, ws): + """重置流""" + req = { + "action": "upload_file_stream", + "params": { + "stream_id": self.stream_id, + "reset": True + }, + "echo": "reset" + } + + print("发送重置请求...") + response = await self._send_and_wait_response(ws, req, timeout=5.0) + + if response and response.get('echo') == 'reset': + print("✅ 流重置完成") + else: + print(f"⚠️ 重置响应异常: {response}") + + async def upload_chunks_parallel(self, ws, file_info): + """并行上传所有分片""" + print(f"\n开始并行上传 {len(self.chunks)} 个分片...") + start_time = time.time() + + # 创建上传任务 + tasks = [] + for chunk_index in range(file_info['total_chunks']): + task = asyncio.create_task( + self.upload_single_chunk(ws, chunk_index, file_info) + ) + tasks.append(task) + + # 等待所有任务完成 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 统计结果 + successful = sum(1 for r in results if r is True) + failed = sum(1 for r in results if r is not True) + + elapsed = time.time() - start_time + speed = file_info['file_size'] / elapsed / 1024 / 1024 # MB/s + + print(f"\n📊 并行上传完成:") + print(f" 成功: {successful}/{len(self.chunks)}") + print(f" 失败: {failed}") + print(f" 耗时: {elapsed:.2f}秒") + print(f" 速度: {speed:.2f}MB/s") + + if failed > 0: + print(f"⚠️ {failed} 个分片上传失败,将进行重试") + + async def upload_single_chunk(self, ws, chunk_index: int, file_info) -> bool: + """上传单个分片""" + async with self.upload_semaphore: # 限制并发数 + chunk = self.chunks[chunk_index] + + try: + chunk_base64 = base64.b64encode(chunk.data).decode('utf-8') + + req = { + "action": "upload_file_stream", + "params": { + "stream_id": self.stream_id, + "chunk_data": chunk_base64, + "chunk_index": chunk_index, + "total_chunks": file_info['total_chunks'], + "file_size": file_info['file_size'], + "filename": file_info['filename'], + "expected_sha256": file_info['expected_sha256'] + }, + "echo": f"chunk_{chunk_index}" + } + + # 使用统一的发送和接收方法 + response = await self._send_and_wait_response(ws, req, timeout=10.0) + + if response and response.get('echo') == f"chunk_{chunk_index}": + if response.get('status') == 'ok': + chunk.uploaded = True + data = response.get('data', {}) + progress = data.get('received_chunks', 0) + total = data.get('total_chunks', file_info['total_chunks']) + print(f"✅ 块 {chunk_index + 1:3d}/{total} ({chunk.size:5d}B) - 进度: {progress}/{total}") + return True + else: + error_msg = response.get('message', 'Unknown error') + print(f"❌ 块 {chunk_index + 1} 失败: {error_msg}") + self.failed_chunks.add(chunk_index) + return False + else: + print(f"⚠️ 块 {chunk_index + 1} 响应异常或超时") + self.failed_chunks.add(chunk_index) + return False + + except Exception as e: + print(f"💥 块 {chunk_index + 1} 异常: {e}") + self.failed_chunks.add(chunk_index) + return False + + async def retry_failed_chunks(self, ws, file_info): + """重试失败的分片""" + print(f"\n🔄 开始重试 {len(self.failed_chunks)} 个失败分片...") + + for retry_round in range(self.max_retries): + if not self.failed_chunks: + break + + print(f"第 {retry_round + 1} 轮重试,剩余 {len(self.failed_chunks)} 个分片") + current_failed = self.failed_chunks.copy() + self.failed_chunks.clear() + + # 重试当前失败的分片 + retry_tasks = [] + for chunk_index in current_failed: + task = asyncio.create_task( + self.upload_single_chunk(ws, chunk_index, file_info) + ) + retry_tasks.append(task) + + retry_results = await asyncio.gather(*retry_tasks, return_exceptions=True) + successful_retries = sum(1 for r in retry_results if r is True) + + print(f"重试结果: {successful_retries}/{len(current_failed)} 成功") + + if not self.failed_chunks: + print("✅ 所有分片重试成功!") + break + else: + await asyncio.sleep(1) # 重试间隔 + + if self.failed_chunks: + print(f"❌ 仍有 {len(self.failed_chunks)} 个分片失败: {sorted(self.failed_chunks)}") + + async def complete_upload(self, ws): + """完成上传""" + req = { + "action": "upload_file_stream", + "params": { + "stream_id": self.stream_id, + "is_complete": True + }, + "echo": "complete" + } + + print("\n发送完成请求...") + response = await self._send_and_wait_response(ws, req, timeout=10.0) + + if response: + if response.get('status') == 'ok': + data = response.get('data', {}) + print(f"✅ 上传完成!") + print(f" 文件路径: {data.get('file_path')}") + print(f" 文件大小: {data.get('file_size')} bytes") + print(f" SHA256: {data.get('sha256')}") + print(f" 状态: {data.get('status')}") + else: + print(f"❌ 上传失败: {response.get('message')}") + else: + print("⚠️ 完成请求超时或失败") + +async def main(): + # 配置 + WS_URI = "ws://localhost:3001" # 修改为你的WebSocket地址 + FILE_PATH = r"C:\Users\nanaeo\Pictures\CatPicture.zip" #!!!!!!!!!!! + MAX_CONCURRENT = 8 # 最大并发上传数,可根据服务器性能调整 + + # 创建测试文件(如果不存在) + if not os.path.exists(FILE_PATH): + with open(FILE_PATH, 'w', encoding='utf-8') as f: + f.write("这是一个测试文件,用于演示并行文件分片上传功能。\n" * 100) + print(f"✅ 创建测试文件: {FILE_PATH}") + + print("=== 并行文件流上传测试 ===") + print(f"WebSocket URI: {WS_URI}") + print(f"文件路径: {FILE_PATH}") + print(f"最大并发数: {MAX_CONCURRENT}") + + try: + tester = FileUploadTester(WS_URI, FILE_PATH, MAX_CONCURRENT) + await tester.connect_and_upload() + print("🎉 测试完成!") + except Exception as e: + print(f"💥 测试出错: {e}") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/onebot/action/stream/UploadFileStream.ts b/src/onebot/action/stream/UploadFileStream.ts new file mode 100644 index 00000000..24a5d8fa --- /dev/null +++ b/src/onebot/action/stream/UploadFileStream.ts @@ -0,0 +1,384 @@ +import { ActionName } from '@/onebot/action/router'; +import { OneBotAction } from '@/onebot/action/OneBotAction'; +import { Static, Type } from '@sinclair/typebox'; +import fs from 'fs'; +import { join as joinPath } from 'node:path'; +import { randomUUID } from 'crypto'; +import { createHash } from 'crypto'; + +// 简化配置 +const CONFIG = { + TIMEOUT: 10 * 60 * 1000, // 10分钟超时 + MEMORY_THRESHOLD: 10 * 1024 * 1024, // 10MB,超过使用磁盘 + MEMORY_LIMIT: 100 * 1024 * 1024 // 100MB内存总限制 +} as const; + +const SchemaData = Type.Object({ + stream_id: Type.String(), + chunk_data: Type.Optional(Type.String()), + chunk_index: Type.Optional(Type.Number()), + total_chunks: Type.Optional(Type.Number()), + file_size: Type.Optional(Type.Number()), + expected_sha256: Type.Optional(Type.String()), + is_complete: Type.Optional(Type.Boolean()), + filename: Type.Optional(Type.String()), + reset: Type.Optional(Type.Boolean()), + verify_only: Type.Optional(Type.Boolean()) +}); + +type Payload = Static; + +// 简化流状态接口 +interface StreamState { + id: string; + filename: string; + totalChunks: number; + receivedChunks: number; + missingChunks: Set; + + // 可选属性 + fileSize?: number; + expectedSha256?: string; + + // 存储策略 + useMemory: boolean; + memoryChunks?: Map; + tempDir?: string; + finalPath?: string; + + // 管理 + createdAt: number; + timeoutId: NodeJS.Timeout; +} + +interface StreamResult { + stream_id: string; + status: 'receiving' | 'completed' | 'error' | 'ready'; + received_chunks: number; + total_chunks?: number; + missing_chunks?: number[]; + file_path?: string; + file_size?: number; + sha256?: string; + message?: string; +} + +export class UploadFileStream extends OneBotAction { + override actionName = ActionName.UploadFileStream; + override payloadSchema = SchemaData; + + private static streams = new Map(); + private static memoryUsage = 0; + + async _handle(payload: Payload): Promise { + const { stream_id, reset, verify_only } = payload; + + try { + if (reset) return this.resetStream(stream_id); + if (verify_only) return this.verifyStream(stream_id); + + const stream = this.getOrCreateStream(payload); + + if (payload.chunk_data && payload.chunk_index !== undefined) { + const result = await this.processChunk(stream, payload.chunk_data, payload.chunk_index); + if (result.status === 'error') return result; + } + + if (payload.is_complete || stream.receivedChunks === stream.totalChunks) { + return await this.completeStream(stream); + } + + return this.getStreamStatus(stream); + + } catch (error) { + // 确保在任何错误情况下都清理资源 + this.cleanupStream(stream_id, true); + return this.errorResult(stream_id, error); + } + } + + private resetStream(streamId: string): StreamResult { + this.cleanupStream(streamId); + return { + stream_id: streamId, + status: 'ready', + received_chunks: 0, + message: 'Stream reset' + }; + } + + private verifyStream(streamId: string): StreamResult { + const stream = UploadFileStream.streams.get(streamId); + if (!stream) { + return this.errorResult(streamId, new Error('Stream not found')); + } + return this.getStreamStatus(stream); + } + + private getOrCreateStream(payload: Payload): StreamState { + let stream = UploadFileStream.streams.get(payload.stream_id); + + if (!stream) { + if (!payload.total_chunks) { + throw new Error('total_chunks required for new stream'); + } + stream = this.createStream(payload); + } + + return stream; + } + + private createStream(payload: Payload): StreamState { + const { stream_id, total_chunks, file_size, filename, expected_sha256 } = payload; + + const useMemory = this.shouldUseMemory(file_size); + if (useMemory && file_size && (UploadFileStream.memoryUsage + file_size) > CONFIG.MEMORY_LIMIT) { + throw new Error('Memory limit exceeded'); + } + + const stream: StreamState = { + id: stream_id, + filename: filename || `upload_${randomUUID()}`, + totalChunks: total_chunks!, + receivedChunks: 0, + missingChunks: new Set(Array.from({ length: total_chunks! }, (_, i) => i)), + fileSize: file_size, + expectedSha256: expected_sha256, + useMemory, + createdAt: Date.now(), + timeoutId: this.setupTimeout(stream_id) + }; + + try { + if (useMemory) { + stream.memoryChunks = new Map(); + if (file_size) UploadFileStream.memoryUsage += file_size; + } else { + this.setupDiskStorage(stream); + } + + UploadFileStream.streams.set(stream_id, stream); + return stream; + } catch (error) { + // 如果设置存储失败,清理已创建的资源 + clearTimeout(stream.timeoutId); + if (stream.tempDir && fs.existsSync(stream.tempDir)) { + try { + fs.rmSync(stream.tempDir, { recursive: true, force: true }); + } catch (cleanupError) { + console.error(`Failed to cleanup temp dir during creation error:`, cleanupError); + } + } + throw error; + } + } + + private shouldUseMemory(fileSize?: number): boolean { + return fileSize !== undefined && fileSize <= CONFIG.MEMORY_THRESHOLD; + } + + private setupDiskStorage(stream: StreamState): void { + const tempDir = joinPath(this.core.NapCatTempPath, `upload_${stream.id}`); + const finalPath = joinPath(this.core.NapCatTempPath, stream.filename); + + fs.mkdirSync(tempDir, { recursive: true }); + + stream.tempDir = tempDir; + stream.finalPath = finalPath; + } + + private setupTimeout(streamId: string): NodeJS.Timeout { + return setTimeout(() => { + console.log(`Stream ${streamId} timeout`); + this.cleanupStream(streamId); + }, CONFIG.TIMEOUT); + } + + private async processChunk(stream: StreamState, chunkData: string, chunkIndex: number): Promise { + // 验证索引 + if (chunkIndex < 0 || chunkIndex >= stream.totalChunks) { + throw new Error(`Invalid chunk index: ${chunkIndex}`); + } + + // 检查重复 + if (!stream.missingChunks.has(chunkIndex)) { + return this.getStreamStatus(stream, `Chunk ${chunkIndex} already received`); + } + + try { + const buffer = Buffer.from(chunkData, 'base64'); + + // 存储分片 + if (stream.useMemory) { + stream.memoryChunks!.set(chunkIndex, buffer); + } else { + const chunkPath = joinPath(stream.tempDir!, `${chunkIndex}.chunk`); + await fs.promises.writeFile(chunkPath, buffer); + } + + // 更新状态 + stream.missingChunks.delete(chunkIndex); + stream.receivedChunks++; + this.refreshTimeout(stream); + + return this.getStreamStatus(stream); + + } catch (error) { + throw new Error(`Chunk processing failed: ${error instanceof Error ? error.message : 'Unknown error'}`); + } + } + + private refreshTimeout(stream: StreamState): void { + clearTimeout(stream.timeoutId); + stream.timeoutId = this.setupTimeout(stream.id); + } + + private getStreamStatus(stream: StreamState, message?: string): StreamResult { + const missingChunks = Array.from(stream.missingChunks).sort(); + + return { + stream_id: stream.id, + status: 'receiving', + received_chunks: stream.receivedChunks, + total_chunks: stream.totalChunks, + missing_chunks: missingChunks.length > 0 ? missingChunks : undefined, + file_size: stream.fileSize, + message + }; + } + + private async completeStream(stream: StreamState): Promise { + try { + // 合并分片 + const finalBuffer = stream.useMemory ? + await this.mergeMemoryChunks(stream) : + await this.mergeDiskChunks(stream); + + // 验证SHA256 + const sha256 = this.validateSha256(stream, finalBuffer); + + // 保存文件 + const finalPath = stream.finalPath || joinPath(this.core.NapCatTempPath, stream.filename); + await fs.promises.writeFile(finalPath, finalBuffer); + + // 清理资源但保留文件 + this.cleanupStream(stream.id, false); + + return { + stream_id: stream.id, + status: 'completed', + received_chunks: stream.receivedChunks, + total_chunks: stream.totalChunks, + file_path: finalPath, + file_size: finalBuffer.length, + sha256, + message: 'Upload completed' + }; + + } catch (error) { + throw new Error(`Stream completion failed: ${error instanceof Error ? error.message : 'Unknown error'}`); + } + } + + private async mergeMemoryChunks(stream: StreamState): Promise { + const chunks: Buffer[] = []; + for (let i = 0; i < stream.totalChunks; i++) { + const chunk = stream.memoryChunks!.get(i); + if (!chunk) throw new Error(`Missing memory chunk ${i}`); + chunks.push(chunk); + } + return Buffer.concat(chunks); + } + + private async mergeDiskChunks(stream: StreamState): Promise { + const chunks: Buffer[] = []; + for (let i = 0; i < stream.totalChunks; i++) { + const chunkPath = joinPath(stream.tempDir!, `${i}.chunk`); + if (!fs.existsSync(chunkPath)) throw new Error(`Missing chunk file ${i}`); + chunks.push(await fs.promises.readFile(chunkPath)); + } + return Buffer.concat(chunks); + } + + private validateSha256(stream: StreamState, buffer: Buffer): string | undefined { + if (!stream.expectedSha256) return undefined; + + const actualSha256 = createHash('sha256').update(buffer).digest('hex'); + if (actualSha256 !== stream.expectedSha256) { + throw new Error(`SHA256 mismatch. Expected: ${stream.expectedSha256}, Got: ${actualSha256}`); + } + return actualSha256; + } + + private cleanupStream(streamId: string, deleteFinalFile = true): void { + const stream = UploadFileStream.streams.get(streamId); + if (!stream) return; + + try { + // 清理超时 + clearTimeout(stream.timeoutId); + + // 清理内存 + if (stream.useMemory) { + if (stream.fileSize) { + UploadFileStream.memoryUsage = Math.max(0, UploadFileStream.memoryUsage - stream.fileSize); + } + stream.memoryChunks?.clear(); + } + + // 清理临时文件夹及其所有内容 + if (stream.tempDir) { + try { + if (fs.existsSync(stream.tempDir)) { + fs.rmSync(stream.tempDir, { recursive: true, force: true }); + console.log(`Cleaned up temp directory: ${stream.tempDir}`); + } + } catch (error) { + console.error(`Failed to cleanup temp directory ${stream.tempDir}:`, error); + } + } + + // 删除最终文件(如果需要) + if (deleteFinalFile && stream.finalPath) { + try { + if (fs.existsSync(stream.finalPath)) { + fs.unlinkSync(stream.finalPath); + console.log(`Deleted final file: ${stream.finalPath}`); + } + } catch (error) { + console.error(`Failed to delete final file ${stream.finalPath}:`, error); + } + } + + } catch (error) { + console.error(`Cleanup error for stream ${streamId}:`, error); + } finally { + UploadFileStream.streams.delete(streamId); + console.log(`Stream ${streamId} cleaned up`); + } + } + + private errorResult(streamId: string, error: any): StreamResult { + return { + stream_id: streamId, + status: 'error', + received_chunks: 0, + message: error instanceof Error ? error.message : 'Unknown error' + }; + } + + // 全局状态查询 + static getGlobalStatus() { + return { + activeStreams: this.streams.size, + memoryUsageMB: Math.round(this.memoryUsage / 1024 / 1024 * 100) / 100, + streams: Array.from(this.streams.values()).map(stream => ({ + streamId: stream.id, + filename: stream.filename, + progress: `${stream.receivedChunks}/${stream.totalChunks}`, + useMemory: stream.useMemory, + createdAt: new Date(stream.createdAt).toISOString() + })) + }; + } +} diff --git a/src/onebot/action/stream/UploadFileStreanConcurent.py b/src/onebot/action/stream/UploadFileStreanConcurent.py new file mode 100644 index 00000000..38ac932e --- /dev/null +++ b/src/onebot/action/stream/UploadFileStreanConcurent.py @@ -0,0 +1,201 @@ +import asyncio +import websockets +import json +import base64 +import hashlib +import os +from typing import Optional + +class FileUploadTester: + def __init__(self, ws_uri: str, file_path: str): + self.ws_uri = ws_uri + self.file_path = file_path + self.chunk_size = 64 * 1024 # 64KB per chunk + self.stream_id = None + + async def connect_and_upload(self): + """连接到WebSocket并上传文件""" + async with websockets.connect(self.ws_uri) as ws: + print(f"已连接到 {self.ws_uri}") + + # 准备文件数据 + file_info = self.prepare_file() + if not file_info: + return + + print(f"文件信息: {file_info['filename']}, 大小: {file_info['file_size']} bytes, 块数: {file_info['total_chunks']}") + + # 生成stream_id + self.stream_id = f"upload_{hash(file_info['filename'] + str(file_info['file_size']))}" + print(f"Stream ID: {self.stream_id}") + + # 重置流(如果存在) + await self.reset_stream(ws) + + # 开始分块上传 + await self.upload_chunks(ws, file_info) + + # 完成上传 + await self.complete_upload(ws) + + # 等待一些响应 + await self.listen_for_responses(ws) + + def prepare_file(self): + """准备文件信息""" + if not os.path.exists(self.file_path): + print(f"文件不存在: {self.file_path}") + return None + + file_size = os.path.getsize(self.file_path) + filename = os.path.basename(self.file_path) + total_chunks = (file_size + self.chunk_size - 1) // self.chunk_size + + # 计算SHA256 + sha256_hash = hashlib.sha256() + with open(self.file_path, 'rb') as f: + for chunk in iter(lambda: f.read(8192), b""): + sha256_hash.update(chunk) + expected_sha256 = sha256_hash.hexdigest() + + return { + 'filename': filename, + 'file_size': file_size, + 'total_chunks': total_chunks, + 'expected_sha256': expected_sha256 + } + + async def reset_stream(self, ws): + """重置流""" + req = { + "action": "upload_file_stream", + "params": { + "stream_id": self.stream_id, + "reset": True + }, + "echo": "reset" + } + await ws.send(json.dumps(req)) + print("发送重置请求...") + + async def upload_chunks(self, ws, file_info): + """上传文件块""" + with open(self.file_path, 'rb') as f: + for chunk_index in range(file_info['total_chunks']): + # 读取块数据 + chunk_data = f.read(self.chunk_size) + chunk_base64 = base64.b64encode(chunk_data).decode('utf-8') + + # 准备请求 + req = { + "action": "upload_file_stream", + "params": { + "stream_id": self.stream_id, + "chunk_data": chunk_base64, + "chunk_index": chunk_index, + "total_chunks": file_info['total_chunks'], + "file_size": file_info['file_size'], + "filename": file_info['filename'], + #"expected_sha256": file_info['expected_sha256'] + }, + "echo": f"chunk_{chunk_index}" + } + + await ws.send(json.dumps(req)) + print(f"发送块 {chunk_index + 1}/{file_info['total_chunks']} ({len(chunk_data)} bytes)") + + # 等待响应 + try: + response = await asyncio.wait_for(ws.recv(), timeout=5.0) + resp_data = json.loads(response) + if resp_data.get('echo') == f"chunk_{chunk_index}": + if resp_data.get('status') == 'ok': + data = resp_data.get('data', {}) + print(f" -> 状态: {data.get('status')}, 已接收: {data.get('received_chunks')}") + else: + print(f" -> 错误: {resp_data.get('message')}") + except asyncio.TimeoutError: + print(f" -> 块 {chunk_index} 响应超时") + + # 小延迟避免过快发送 + await asyncio.sleep(0.1) + + async def complete_upload(self, ws): + """完成上传""" + req = { + "action": "upload_file_stream", + "params": { + "stream_id": self.stream_id, + "is_complete": True + }, + "echo": "complete" + } + await ws.send(json.dumps(req)) + print("发送完成请求...") + + async def verify_stream(self, ws): + """验证流状态""" + req = { + "action": "upload_file_stream", + "params": { + "stream_id": self.stream_id, + "verify_only": True + }, + "echo": "verify" + } + await ws.send(json.dumps(req)) + print("发送验证请求...") + + async def listen_for_responses(self, ws, duration=10): + """监听响应""" + print(f"监听响应 {duration} 秒...") + try: + end_time = asyncio.get_event_loop().time() + duration + while asyncio.get_event_loop().time() < end_time: + try: + msg = await asyncio.wait_for(ws.recv(), timeout=1.0) + resp_data = json.loads(msg) + echo = resp_data.get('echo', 'unknown') + + if echo == "complete": + if resp_data.get('status') == 'ok': + data = resp_data.get('data', {}) + print(f"✅ 上传完成!") + print(f" 文件路径: {data.get('file_path')}") + print(f" 文件大小: {data.get('file_size')} bytes") + print(f" SHA256: {data.get('sha256')}") + print(f" 状态: {data.get('status')}") + else: + print(f"❌ 上传失败: {resp_data.get('message')}") + elif echo == "verify": + if resp_data.get('status') == 'ok': + data = resp_data.get('data', {}) + print(f"🔍 验证结果: {data}") + elif echo == "reset": + print(f"🔄 重置完成: {resp_data}") + else: + print(f"📨 收到响应 [{echo}]: {resp_data}") + + except asyncio.TimeoutError: + continue + + except Exception as e: + print(f"监听出错: {e}") + +async def main(): + # 配置 + WS_URI = "ws://localhost:3001" # 修改为你的WebSocket地址 + FILE_PATH = "C:\\Users\\nanaeo\\Pictures\\CatPicture.zip" + + print("=== 文件流上传测试 ===") + print(f"WebSocket URI: {WS_URI}") + print(f"文件路径: {FILE_PATH}") + + try: + tester = FileUploadTester(WS_URI, FILE_PATH) + await tester.connect_and_upload() + except Exception as e: + print(f"测试出错: {e}") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/onebot/network/http-server-sse.ts b/src/onebot/network/http-server-sse.ts index 47414871..da3181fc 100644 --- a/src/onebot/network/http-server-sse.ts +++ b/src/onebot/network/http-server-sse.ts @@ -9,7 +9,7 @@ export class OB11HttpSSEServerAdapter extends OB11HttpServerAdapter { if (req.path === '/_events') { this.createSseSupport(req, res); } else { - super.httpApiRequest(req, res); + super.httpApiRequest(req, res, true); } } diff --git a/src/onebot/network/http-server.ts b/src/onebot/network/http-server.ts index b197bdd2..d30496f3 100644 --- a/src/onebot/network/http-server.ts +++ b/src/onebot/network/http-server.ts @@ -104,7 +104,7 @@ export class OB11HttpServerAdapter extends IOB11NetworkAdapter } } - async httpApiRequest(req: Request, res: Response) { + async httpApiRequest(req: Request, res: Response, request_sse: boolean = false) { let payload = req.body; if (req.method == 'get') { payload = req.query; @@ -117,17 +117,31 @@ export class OB11HttpServerAdapter extends IOB11NetworkAdapter return res.json(hello); } const actionName = req.path.split('/')[1]; + const payload_echo = payload['echo']; + const real_echo = payload_echo ?? Math.random().toString(36).substring(2, 15); // eslint-disable-next-line @typescript-eslint/no-explicit-any const action = this.actions.get(actionName as any); if (action) { try { - const result = await action.handle(payload, this.name, this.config); + let stream = false; + const result = await action.handle(payload, this.name, this.config, { + send: request_sse ? async (data: object) => { + this.onEvent({ ...OB11Response.ok(data, real_echo), type: 'sse-action' } as unknown as OB11EmitEventContent); + } : async (data: object) => { + stream = true; + res.write(JSON.stringify({ ...OB11Response.ok(data, real_echo), type: 'stream-action' }) + "\r\n\r\n"); + } + }, real_echo); + if (stream) { + res.write(JSON.stringify({ ...result, type: 'stream-action' }) + "\r\n\r\n"); + return res.end(); + }; return res.json(result); } catch (error: unknown) { - return res.json(OB11Response.error((error as Error)?.stack?.toString() || (error as Error)?.message || 'Error Handle', 200)); + return res.json(OB11Response.error((error as Error)?.stack?.toString() || (error as Error)?.message || 'Error Handle', 200, real_echo)); } } else { - return res.json(OB11Response.error('不支持的Api ' + actionName, 200)); + return res.json(OB11Response.error('不支持的Api ' + actionName, 200, real_echo)); } } diff --git a/src/onebot/network/websocket-client.ts b/src/onebot/network/websocket-client.ts index 132a868e..26e4f343 100644 --- a/src/onebot/network/websocket-client.ts +++ b/src/onebot/network/websocket-client.ts @@ -151,7 +151,11 @@ export class OB11WebSocketClientAdapter extends IOB11NetworkAdapter(OB11Response.error('不支持的Api ' + receiveData.action, 1404, echo)); return; } - const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config); + const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config, { + send: async (data: object) => { + this.checkStateAndReply({ ...OB11Response.ok(data, echo ?? '') }); + } + }); this.checkStateAndReply({ ...retdata }); } async reload(newConfig: WebsocketClientConfig) { diff --git a/src/onebot/network/websocket-server.ts b/src/onebot/network/websocket-server.ts index e96157bf..6bdf38eb 100644 --- a/src/onebot/network/websocket-server.ts +++ b/src/onebot/network/websocket-server.ts @@ -186,7 +186,11 @@ export class OB11WebSocketServerAdapter extends IOB11NetworkAdapter(OB11Response.error('不支持的API ' + receiveData.action, 1404, echo), wsClient); return; } - const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config); + const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config, { + send: async (data: object) => { + this.checkStateAndReply({ ...OB11Response.ok(data, echo ?? '') }, wsClient); + } + }); this.checkStateAndReply({ ...retdata }, wsClient); }