mirror of
https://github.com/NapNeko/NapCatQQ.git
synced 2025-12-19 05:05:44 +08:00
feat: 标准化
This commit is contained in:
parent
32bba007cd
commit
2161ec5fa7
@ -4,9 +4,10 @@ import { NapCatCore } from '@/core';
|
||||
import { NapCatOneBot11Adapter, OB11Return } from '@/onebot';
|
||||
import { NetworkAdapterConfig } from '../config/config';
|
||||
import { TSchema } from '@sinclair/typebox';
|
||||
import { StreamPacket, StreamPacketBasic, StreamStatus } from './stream/StreamBasic';
|
||||
|
||||
export class OB11Response {
|
||||
private static createResponse<T>(data: T, status: string, retcode: number, message: string = '', echo: unknown = null): OB11Return<T> {
|
||||
private static createResponse<T>(data: T, status: string, retcode: number, message: string = '', echo: unknown = null, useStream: boolean = false): OB11Return<T> {
|
||||
return {
|
||||
status,
|
||||
retcode,
|
||||
@ -14,23 +15,24 @@ export class OB11Response {
|
||||
message,
|
||||
wording: message,
|
||||
echo,
|
||||
stream: useStream ? 'stream-action' : 'normal-action'
|
||||
};
|
||||
}
|
||||
|
||||
static res<T>(data: T, status: string, retcode: number, message: string = ''): OB11Return<T> {
|
||||
return this.createResponse(data, status, retcode, message);
|
||||
static res<T>(data: T, status: string, retcode: number, message: string = '', echo: unknown = null, useStream: boolean = false): OB11Return<T> {
|
||||
return this.createResponse(data, status, retcode, message, echo, useStream);
|
||||
}
|
||||
|
||||
static ok<T>(data: T, echo: unknown = null): OB11Return<T> {
|
||||
return this.createResponse(data, 'ok', 0, '', echo);
|
||||
static ok<T>(data: T, echo: unknown = null, useStream: boolean = false): OB11Return<T> {
|
||||
return this.createResponse(data, 'ok', 0, '', echo, useStream);
|
||||
}
|
||||
|
||||
static error(err: string, retcode: number, echo: unknown = null): OB11Return<null> {
|
||||
return this.createResponse(null, 'failed', retcode, err, echo);
|
||||
static error(err: string, retcode: number, echo: unknown = null, useStream: boolean = false): OB11Return<null | StreamPacketBasic> {
|
||||
return this.createResponse(useStream ? { type: StreamStatus.Error, data_type: 'error' } : null, 'failed', retcode, err, echo, useStream);
|
||||
}
|
||||
}
|
||||
export interface OneBotRequestToolkit<T = unknown> {
|
||||
send: (data: T) => Promise<void>;
|
||||
export abstract class OneBotRequestToolkit {
|
||||
abstract send<T>(packet: StreamPacket<T>): Promise<void>;
|
||||
}
|
||||
export abstract class OneBotAction<PayloadType, ReturnDataType> {
|
||||
actionName: typeof ActionName[keyof typeof ActionName] = ActionName.Unknown;
|
||||
@ -38,6 +40,7 @@ export abstract class OneBotAction<PayloadType, ReturnDataType> {
|
||||
private validate?: ValidateFunction<unknown> = undefined;
|
||||
payloadSchema?: TSchema = undefined;
|
||||
obContext: NapCatOneBot11Adapter;
|
||||
useStream: boolean = false;
|
||||
|
||||
constructor(obContext: NapCatOneBot11Adapter, core: NapCatCore) {
|
||||
this.obContext = obContext;
|
||||
@ -59,31 +62,31 @@ export abstract class OneBotAction<PayloadType, ReturnDataType> {
|
||||
return { valid: true };
|
||||
}
|
||||
|
||||
public async handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit = { send: async () => { } }, echo?: string): Promise<OB11Return<ReturnDataType | null>> {
|
||||
public async handle(payload: PayloadType, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit = { send: async () => { } }, echo?: string): Promise<OB11Return<ReturnDataType | StreamPacketBasic | null>> {
|
||||
const result = await this.check(payload);
|
||||
if (!result.valid) {
|
||||
return OB11Response.error(result.message, 400);
|
||||
}
|
||||
try {
|
||||
const resData = await this._handle(payload, adaptername, config, req);
|
||||
return OB11Response.ok(resData, echo);
|
||||
return OB11Response.ok(resData, echo, this.useStream);
|
||||
} catch (e: unknown) {
|
||||
this.core.context.logger.logError('发生错误', e);
|
||||
return OB11Response.error((e as Error).message.toString() || (e as Error)?.stack?.toString() || '未知错误,可能操作超时', 200, echo);
|
||||
return OB11Response.error((e as Error).message.toString() || (e as Error)?.stack?.toString() || '未知错误,可能操作超时', 200, echo, this.useStream);
|
||||
}
|
||||
}
|
||||
|
||||
public async websocketHandle(payload: PayloadType, echo: unknown, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit = { send: async () => { } }): Promise<OB11Return<ReturnDataType | null>> {
|
||||
public async websocketHandle(payload: PayloadType, echo: unknown, adaptername: string, config: NetworkAdapterConfig, req: OneBotRequestToolkit = { send: async () => { } }): Promise<OB11Return<ReturnDataType | StreamPacketBasic | null>> {
|
||||
const result = await this.check(payload);
|
||||
if (!result.valid) {
|
||||
return OB11Response.error(result.message, 1400, echo);
|
||||
return OB11Response.error(result.message, 1400, echo, this.useStream);
|
||||
}
|
||||
try {
|
||||
const resData = await this._handle(payload, adaptername, config, req);
|
||||
return OB11Response.ok(resData, echo);
|
||||
return OB11Response.ok(resData, echo, this.useStream);
|
||||
} catch (e: unknown) {
|
||||
this.core.context.logger.logError('发生错误', e);
|
||||
return OB11Response.error(((e as Error).message.toString() || (e as Error).stack?.toString()) ?? 'Error', 1200, echo);
|
||||
return OB11Response.error(((e as Error).message.toString() || (e as Error).stack?.toString()) ?? 'Error', 1200, echo, this.useStream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -132,7 +132,7 @@ import { SetGroupAlbumMediaLike } from './extends/SetGroupAlbumMediaLike';
|
||||
import { DelGroupAlbumMedia } from './extends/DelGroupAlbumMedia';
|
||||
import { CleanStreamTempFile } from './stream/CleanStreamTempFile';
|
||||
import { DownloadFileStream } from './stream/DownloadFileStream';
|
||||
import { TestStreamDownload } from './stream/TestStreamDownload';
|
||||
import { TestDownloadStream } from './stream/TestStreamDownload';
|
||||
import { UploadFileStream } from './stream/UploadFileStream';
|
||||
|
||||
export function createActionMap(obContext: NapCatOneBot11Adapter, core: NapCatCore) {
|
||||
@ -140,7 +140,7 @@ export function createActionMap(obContext: NapCatOneBot11Adapter, core: NapCatCo
|
||||
const actionHandlers = [
|
||||
new CleanStreamTempFile(obContext, core),
|
||||
new DownloadFileStream(obContext, core),
|
||||
new TestStreamDownload(obContext, core),
|
||||
new TestDownloadStream(obContext, core),
|
||||
new UploadFileStream(obContext, core),
|
||||
new DelGroupAlbumMedia(obContext, core),
|
||||
new SetGroupAlbumMediaLike(obContext, core),
|
||||
|
||||
@ -10,10 +10,14 @@ export interface InvalidCheckResult {
|
||||
}
|
||||
|
||||
export const ActionName = {
|
||||
// 所有 Normal Stream Api 表示并未流传输 表示与流传输有关
|
||||
CleanStreamTempFile: 'clean_stream_temp_file',
|
||||
TestStreamDownload: 'test_stream_download',
|
||||
|
||||
// 所有 Upload/Download Stream Api 应当 _stream 结尾
|
||||
TestDownloadStream: 'test_download_stream',
|
||||
UploadFileStream: 'upload_file_stream',
|
||||
DownloadFileStream: 'download_file_stream',
|
||||
|
||||
DelGroupAlbumMedia: 'del_group_album_media',
|
||||
SetGroupAlbumMediaLike: 'set_group_album_media_like',
|
||||
DoGroupAlbumComment: 'do_group_album_comment',
|
||||
|
||||
@ -2,6 +2,7 @@ import { ActionName } from '@/onebot/action/router';
|
||||
import { OneBotAction, OneBotRequestToolkit } from '@/onebot/action/OneBotAction';
|
||||
import { Static, Type } from '@sinclair/typebox';
|
||||
import { NetworkAdapterConfig } from '@/onebot/config/config';
|
||||
import { StreamPacket, StreamStatus } from './StreamBasic';
|
||||
import fs from 'fs';
|
||||
import { FileNapCatOneBotUUID } from '@/common/file-uuid';
|
||||
const SchemaData = Type.Object({
|
||||
@ -12,90 +13,121 @@ const SchemaData = Type.Object({
|
||||
|
||||
type Payload = Static<typeof SchemaData>;
|
||||
|
||||
export class DownloadFileStream extends OneBotAction<Payload, void> {
|
||||
// 下载结果类型
|
||||
interface DownloadResult {
|
||||
// 文件信息
|
||||
file_name?: string;
|
||||
file_size?: number;
|
||||
chunk_size?: number;
|
||||
|
||||
// 分片数据
|
||||
index?: number;
|
||||
data?: string;
|
||||
size?: number;
|
||||
progress?: number;
|
||||
base64_size?: number;
|
||||
|
||||
// 完成信息
|
||||
total_chunks?: number;
|
||||
total_bytes?: number;
|
||||
message?: string;
|
||||
data_type?: 'file_info' | 'file_chunk' | 'file_complete';
|
||||
}
|
||||
|
||||
export class DownloadFileStream extends OneBotAction<Payload, StreamPacket<DownloadResult>> {
|
||||
override actionName = ActionName.DownloadFileStream;
|
||||
override payloadSchema = SchemaData;
|
||||
override useStream = true;
|
||||
|
||||
async _handle(payload: Payload, _adaptername: string, _config: NetworkAdapterConfig, req: OneBotRequestToolkit) {
|
||||
payload.file ||= payload.file_id || '';
|
||||
const chunkSize = payload.chunk_size || 64 * 1024;
|
||||
let downloadPath = '';
|
||||
let fileName = '';
|
||||
let fileSize = 0;
|
||||
async _handle(payload: Payload, _adaptername: string, _config: NetworkAdapterConfig, req: OneBotRequestToolkit): Promise<StreamPacket<DownloadResult>> {
|
||||
try {
|
||||
payload.file ||= payload.file_id || '';
|
||||
const chunkSize = payload.chunk_size || 64 * 1024;
|
||||
let downloadPath = '';
|
||||
let fileName = '';
|
||||
let fileSize = 0;
|
||||
|
||||
//接收消息标记模式
|
||||
const contextMsgFile = FileNapCatOneBotUUID.decode(payload.file);
|
||||
if (contextMsgFile && contextMsgFile.msgId && contextMsgFile.elementId) {
|
||||
const { peer, msgId, elementId } = contextMsgFile;
|
||||
downloadPath = await this.core.apis.FileApi.downloadMedia(msgId, peer.chatType, peer.peerUid, elementId, '', '');
|
||||
const rawMessage = (await this.core.apis.MsgApi.getMsgsByMsgId(peer, [msgId]))?.msgList
|
||||
.find(msg => msg.msgId === msgId);
|
||||
const mixElement = rawMessage?.elements.find(e => e.elementId === elementId);
|
||||
const mixElementInner = mixElement?.videoElement ?? mixElement?.fileElement ?? mixElement?.pttElement ?? mixElement?.picElement;
|
||||
if (!mixElementInner) throw new Error('element not found');
|
||||
fileSize = parseInt(mixElementInner.fileSize?.toString() ?? '0');
|
||||
fileName = mixElementInner.fileName ?? '';
|
||||
}
|
||||
//群文件模式
|
||||
else if (FileNapCatOneBotUUID.decodeModelId(payload.file)) {
|
||||
const contextModelIdFile = FileNapCatOneBotUUID.decodeModelId(payload.file);
|
||||
if (contextModelIdFile && contextModelIdFile.modelId) {
|
||||
const { peer, modelId } = contextModelIdFile;
|
||||
downloadPath = await this.core.apis.FileApi.downloadFileForModelId(peer, modelId, '');
|
||||
//接收消息标记模式
|
||||
const contextMsgFile = FileNapCatOneBotUUID.decode(payload.file);
|
||||
if (contextMsgFile && contextMsgFile.msgId && contextMsgFile.elementId) {
|
||||
const { peer, msgId, elementId } = contextMsgFile;
|
||||
downloadPath = await this.core.apis.FileApi.downloadMedia(msgId, peer.chatType, peer.peerUid, elementId, '', '');
|
||||
const rawMessage = (await this.core.apis.MsgApi.getMsgsByMsgId(peer, [msgId]))?.msgList
|
||||
.find(msg => msg.msgId === msgId);
|
||||
const mixElement = rawMessage?.elements.find(e => e.elementId === elementId);
|
||||
const mixElementInner = mixElement?.videoElement ?? mixElement?.fileElement ?? mixElement?.pttElement ?? mixElement?.picElement;
|
||||
if (!mixElementInner) throw new Error('element not found');
|
||||
fileSize = parseInt(mixElementInner.fileSize?.toString() ?? '0');
|
||||
fileName = mixElementInner.fileName ?? '';
|
||||
}
|
||||
}
|
||||
//搜索名字模式
|
||||
else {
|
||||
const searchResult = (await this.core.apis.FileApi.searchForFile([payload.file]));
|
||||
if (searchResult) {
|
||||
downloadPath = await this.core.apis.FileApi.downloadFileById(searchResult.id, parseInt(searchResult.fileSize));
|
||||
fileSize = parseInt(searchResult.fileSize);
|
||||
fileName = searchResult.fileName;
|
||||
//群文件模式
|
||||
else if (FileNapCatOneBotUUID.decodeModelId(payload.file)) {
|
||||
const contextModelIdFile = FileNapCatOneBotUUID.decodeModelId(payload.file);
|
||||
if (contextModelIdFile && contextModelIdFile.modelId) {
|
||||
const { peer, modelId } = contextModelIdFile;
|
||||
downloadPath = await this.core.apis.FileApi.downloadFileForModelId(peer, modelId, '');
|
||||
}
|
||||
}
|
||||
//搜索名字模式
|
||||
else {
|
||||
const searchResult = (await this.core.apis.FileApi.searchForFile([payload.file]));
|
||||
if (searchResult) {
|
||||
downloadPath = await this.core.apis.FileApi.downloadFileById(searchResult.id, parseInt(searchResult.fileSize));
|
||||
fileSize = parseInt(searchResult.fileSize);
|
||||
fileName = searchResult.fileName;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!downloadPath) {
|
||||
throw new Error('file not found');
|
||||
}
|
||||
if (!downloadPath) {
|
||||
throw new Error('file not found');
|
||||
}
|
||||
|
||||
// 获取文件大小
|
||||
const stats = await fs.promises.stat(downloadPath);
|
||||
const totalSize = fileSize || stats.size;
|
||||
|
||||
// 发送文件信息
|
||||
req.send({
|
||||
type: 'file_info',
|
||||
file_name: fileName,
|
||||
file_size: totalSize,
|
||||
chunk_size: chunkSize
|
||||
});
|
||||
|
||||
// 创建读取流并分块发送
|
||||
const readStream = fs.createReadStream(downloadPath, { highWaterMark: chunkSize });
|
||||
let chunkIndex = 0;
|
||||
let bytesRead = 0;
|
||||
|
||||
for await (const chunk of readStream) {
|
||||
const base64Chunk = chunk.toString('base64');
|
||||
bytesRead += chunk.length;
|
||||
// 获取文件大小
|
||||
const stats = await fs.promises.stat(downloadPath);
|
||||
const totalSize = fileSize || stats.size;
|
||||
|
||||
// 发送文件信息
|
||||
req.send({
|
||||
type: 'chunk',
|
||||
index: chunkIndex,
|
||||
data: base64Chunk,
|
||||
size: chunk.length,
|
||||
progress: Math.round((bytesRead / totalSize) * 100),
|
||||
base64_size: base64Chunk.length
|
||||
type: StreamStatus.Stream,
|
||||
data_type: 'file_info',
|
||||
file_name: fileName,
|
||||
file_size: totalSize,
|
||||
chunk_size: chunkSize
|
||||
});
|
||||
|
||||
chunkIndex++;
|
||||
}
|
||||
// 创建读取流并分块发送
|
||||
const readStream = fs.createReadStream(downloadPath, { highWaterMark: chunkSize });
|
||||
let chunkIndex = 0;
|
||||
let bytesRead = 0;
|
||||
|
||||
// 发送完成信号
|
||||
req.send({
|
||||
type: 'complete',
|
||||
total_chunks: chunkIndex,
|
||||
total_bytes: bytesRead
|
||||
});
|
||||
for await (const chunk of readStream) {
|
||||
const base64Chunk = chunk.toString('base64');
|
||||
bytesRead += chunk.length;
|
||||
|
||||
req.send({
|
||||
type: StreamStatus.Stream,
|
||||
data_type: 'file_chunk',
|
||||
index: chunkIndex,
|
||||
data: base64Chunk,
|
||||
size: chunk.length,
|
||||
progress: Math.round((bytesRead / totalSize) * 100),
|
||||
base64_size: base64Chunk.length
|
||||
});
|
||||
|
||||
chunkIndex++;
|
||||
}
|
||||
|
||||
// 返回完成状态
|
||||
return {
|
||||
type: StreamStatus.Response,
|
||||
data_type: 'file_complete',
|
||||
total_chunks: chunkIndex,
|
||||
total_bytes: bytesRead,
|
||||
message: 'Download completed'
|
||||
};
|
||||
|
||||
} catch (error) {
|
||||
throw new Error(`Download failed: ${(error as Error).message}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,9 +1,3 @@
|
||||
# Stream-Api
|
||||
|
||||
## 流式接口总览
|
||||
clean_stream_temp_file 清理流临时文件
|
||||
test_stream_download 测试流传输
|
||||
download_file_stream 下载文件流传输 get_file替代
|
||||
upload_file_stream 上传文件流
|
||||
|
||||
## 使用
|
||||
## 流式接口
|
||||
16
src/onebot/action/stream/StreamBasic.ts
Normal file
16
src/onebot/action/stream/StreamBasic.ts
Normal file
@ -0,0 +1,16 @@
|
||||
import { OneBotAction, OneBotRequestToolkit } from "../OneBotAction";
|
||||
import { NetworkAdapterConfig } from "@/onebot/config/config";
|
||||
export type StreamPacketBasic = {
|
||||
type: StreamStatus;
|
||||
data_type?: string;
|
||||
};
|
||||
export type StreamPacket<T> = T & StreamPacketBasic;
|
||||
export enum StreamStatus {
|
||||
Stream = 'stream', // 分片流数据包
|
||||
Response = 'response', // 流最终响应
|
||||
Reset = 'reset', // 重置流
|
||||
Error = 'error' // 流错误
|
||||
}
|
||||
export abstract class BasicStream<T, R> extends OneBotAction<T, StreamPacket<R>> {
|
||||
abstract override _handle(_payload: T, _adaptername: string, _config: NetworkAdapterConfig, req: OneBotRequestToolkit): Promise<StreamPacket<R>>;
|
||||
}
|
||||
@ -2,6 +2,7 @@ import { ActionName } from '@/onebot/action/router';
|
||||
import { OneBotAction, OneBotRequestToolkit } from '@/onebot/action/OneBotAction';
|
||||
import { Static, Type } from '@sinclair/typebox';
|
||||
import { NetworkAdapterConfig } from '@/onebot/config/config';
|
||||
import { StreamPacket, StreamStatus } from './StreamBasic';
|
||||
|
||||
const SchemaData = Type.Object({
|
||||
|
||||
@ -9,15 +10,20 @@ const SchemaData = Type.Object({
|
||||
|
||||
type Payload = Static<typeof SchemaData>;
|
||||
|
||||
export class TestStreamDownload extends OneBotAction<Payload, string> {
|
||||
override actionName = ActionName.TestStreamDownload;
|
||||
export class TestDownloadStream extends OneBotAction<Payload, StreamPacket<{ data: string }>> {
|
||||
override actionName = ActionName.TestDownloadStream;
|
||||
override payloadSchema = SchemaData;
|
||||
override useStream = true;
|
||||
|
||||
async _handle(_payload: Payload, _adaptername: string, _config: NetworkAdapterConfig, req: OneBotRequestToolkit) {
|
||||
for (let i = 0; i < 10; i++) {
|
||||
req.send({ index: i });
|
||||
req.send({ type: StreamStatus.Stream, data: `这是第 ${i + 1} 片流数据`, data_type: 'data_chunk' });
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
}
|
||||
return 'done';
|
||||
return {
|
||||
type: StreamStatus.Response,
|
||||
data_type: 'data_complete',
|
||||
data: '流传输完成'
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,368 +0,0 @@
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from typing import Optional, Dict, Set
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class ChunkInfo:
|
||||
index: int
|
||||
data: bytes
|
||||
size: int
|
||||
retry_count: int = 0
|
||||
uploaded: bool = False
|
||||
|
||||
class FileUploadTester:
|
||||
def __init__(self, ws_uri: str, file_path: str, max_concurrent: int = 5):
|
||||
self.ws_uri = ws_uri
|
||||
self.file_path = file_path
|
||||
self.chunk_size = 64 * 1024 # 64KB per chunk
|
||||
self.max_concurrent = max_concurrent # 最大并发数
|
||||
self.max_retries = 3 # 最大重试次数
|
||||
self.stream_id = None
|
||||
self.chunks: Dict[int, ChunkInfo] = {}
|
||||
self.upload_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
self.failed_chunks: Set[int] = set()
|
||||
|
||||
# 消息路由机制
|
||||
self.response_futures: Dict[str, asyncio.Future] = {}
|
||||
self.message_receiver_task = None
|
||||
|
||||
async def connect_and_upload(self):
|
||||
"""连接到WebSocket并上传文件"""
|
||||
async with websockets.connect(self.ws_uri) as ws:
|
||||
print(f"已连接到 {self.ws_uri}")
|
||||
|
||||
# 启动消息接收器
|
||||
self.message_receiver_task = asyncio.create_task(self._message_receiver(ws))
|
||||
|
||||
try:
|
||||
# 准备文件数据
|
||||
file_info = self.prepare_file()
|
||||
if not file_info:
|
||||
return
|
||||
|
||||
print(f"文件信息: {file_info['filename']}, 大小: {file_info['file_size']} bytes, 块数: {file_info['total_chunks']}")
|
||||
print(f"并发设置: 最大 {self.max_concurrent} 个并发上传")
|
||||
|
||||
# 生成stream_id
|
||||
self.stream_id = f"upload_{hash(file_info['filename'] + str(file_info['file_size']))}"
|
||||
print(f"Stream ID: {self.stream_id}")
|
||||
|
||||
# 重置流(如果存在)
|
||||
await self.reset_stream(ws)
|
||||
|
||||
# 准备所有分片
|
||||
self.prepare_chunks(file_info)
|
||||
|
||||
# 并行上传分片
|
||||
await self.upload_chunks_parallel(ws, file_info)
|
||||
|
||||
# 重试失败的分片
|
||||
if self.failed_chunks:
|
||||
await self.retry_failed_chunks(ws, file_info)
|
||||
|
||||
# 完成上传
|
||||
await self.complete_upload(ws)
|
||||
|
||||
# 等待一段时间确保所有响应都收到
|
||||
await asyncio.sleep(2)
|
||||
|
||||
finally:
|
||||
# 取消消息接收器
|
||||
if self.message_receiver_task:
|
||||
self.message_receiver_task.cancel()
|
||||
try:
|
||||
await self.message_receiver_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 清理未完成的Future
|
||||
for future in self.response_futures.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
|
||||
async def _message_receiver(self, ws):
|
||||
"""专门的消息接收协程,负责分发响应到对应的Future"""
|
||||
try:
|
||||
while True:
|
||||
message = await ws.recv()
|
||||
try:
|
||||
data = json.loads(message)
|
||||
echo = data.get('echo', 'unknown')
|
||||
|
||||
# 查找对应的Future
|
||||
if echo in self.response_futures:
|
||||
future = self.response_futures[echo]
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
else:
|
||||
# 处理未预期的响应
|
||||
print(f"📨 未预期响应 [{echo}]: {data}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"⚠️ JSON解析错误: {e}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 消息处理错误: {e}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
print("🔄 消息接收器已停止")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"💥 消息接收器异常: {e}")
|
||||
|
||||
async def _send_and_wait_response(self, ws, request: dict, timeout: float = 10.0) -> Optional[dict]:
|
||||
"""发送请求并等待响应"""
|
||||
echo = request.get('echo', 'unknown')
|
||||
|
||||
# 创建Future用于接收响应
|
||||
future = asyncio.Future()
|
||||
self.response_futures[echo] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
await ws.send(json.dumps(request))
|
||||
|
||||
# 等待响应
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print(f"⏰ 请求超时: {echo}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"💥 请求异常: {echo} - {e}")
|
||||
return None
|
||||
finally:
|
||||
# 清理Future
|
||||
if echo in self.response_futures:
|
||||
del self.response_futures[echo]
|
||||
|
||||
def prepare_file(self):
|
||||
"""准备文件信息"""
|
||||
if not os.path.exists(self.file_path):
|
||||
print(f"文件不存在: {self.file_path}")
|
||||
return None
|
||||
|
||||
file_size = os.path.getsize(self.file_path)
|
||||
filename = os.path.basename(self.file_path)
|
||||
total_chunks = (file_size + self.chunk_size - 1) // self.chunk_size
|
||||
|
||||
# 计算SHA256
|
||||
print("计算文件SHA256...")
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(self.file_path, 'rb') as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
sha256_hash.update(chunk)
|
||||
expected_sha256 = sha256_hash.hexdigest()
|
||||
|
||||
return {
|
||||
'filename': filename,
|
||||
'file_size': file_size,
|
||||
'total_chunks': total_chunks,
|
||||
'expected_sha256': expected_sha256
|
||||
}
|
||||
|
||||
def prepare_chunks(self, file_info):
|
||||
"""预读取所有分片数据"""
|
||||
print("预读取分片数据...")
|
||||
with open(self.file_path, 'rb') as f:
|
||||
for chunk_index in range(file_info['total_chunks']):
|
||||
chunk_data = f.read(self.chunk_size)
|
||||
self.chunks[chunk_index] = ChunkInfo(
|
||||
index=chunk_index,
|
||||
data=chunk_data,
|
||||
size=len(chunk_data)
|
||||
)
|
||||
print(f"已准备 {len(self.chunks)} 个分片")
|
||||
|
||||
async def reset_stream(self, ws):
|
||||
"""重置流"""
|
||||
req = {
|
||||
"action": "upload_file_stream",
|
||||
"params": {
|
||||
"stream_id": self.stream_id,
|
||||
"reset": True
|
||||
},
|
||||
"echo": "reset"
|
||||
}
|
||||
|
||||
print("发送重置请求...")
|
||||
response = await self._send_and_wait_response(ws, req, timeout=5.0)
|
||||
|
||||
if response and response.get('echo') == 'reset':
|
||||
print("✅ 流重置完成")
|
||||
else:
|
||||
print(f"⚠️ 重置响应异常: {response}")
|
||||
|
||||
async def upload_chunks_parallel(self, ws, file_info):
|
||||
"""并行上传所有分片"""
|
||||
print(f"\n开始并行上传 {len(self.chunks)} 个分片...")
|
||||
start_time = time.time()
|
||||
|
||||
# 创建上传任务
|
||||
tasks = []
|
||||
for chunk_index in range(file_info['total_chunks']):
|
||||
task = asyncio.create_task(
|
||||
self.upload_single_chunk(ws, chunk_index, file_info)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# 等待所有任务完成
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 统计结果
|
||||
successful = sum(1 for r in results if r is True)
|
||||
failed = sum(1 for r in results if r is not True)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
speed = file_info['file_size'] / elapsed / 1024 / 1024 # MB/s
|
||||
|
||||
print(f"\n📊 并行上传完成:")
|
||||
print(f" 成功: {successful}/{len(self.chunks)}")
|
||||
print(f" 失败: {failed}")
|
||||
print(f" 耗时: {elapsed:.2f}秒")
|
||||
print(f" 速度: {speed:.2f}MB/s")
|
||||
|
||||
if failed > 0:
|
||||
print(f"⚠️ {failed} 个分片上传失败,将进行重试")
|
||||
|
||||
async def upload_single_chunk(self, ws, chunk_index: int, file_info) -> bool:
|
||||
"""上传单个分片"""
|
||||
async with self.upload_semaphore: # 限制并发数
|
||||
chunk = self.chunks[chunk_index]
|
||||
|
||||
try:
|
||||
chunk_base64 = base64.b64encode(chunk.data).decode('utf-8')
|
||||
|
||||
req = {
|
||||
"action": "upload_file_stream",
|
||||
"params": {
|
||||
"stream_id": self.stream_id,
|
||||
"chunk_data": chunk_base64,
|
||||
"chunk_index": chunk_index,
|
||||
"total_chunks": file_info['total_chunks'],
|
||||
"file_size": file_info['file_size'],
|
||||
"filename": file_info['filename'],
|
||||
"expected_sha256": file_info['expected_sha256']
|
||||
},
|
||||
"echo": f"chunk_{chunk_index}"
|
||||
}
|
||||
|
||||
# 使用统一的发送和接收方法
|
||||
response = await self._send_and_wait_response(ws, req, timeout=10.0)
|
||||
|
||||
if response and response.get('echo') == f"chunk_{chunk_index}":
|
||||
if response.get('status') == 'ok':
|
||||
chunk.uploaded = True
|
||||
data = response.get('data', {})
|
||||
progress = data.get('received_chunks', 0)
|
||||
total = data.get('total_chunks', file_info['total_chunks'])
|
||||
print(f"✅ 块 {chunk_index + 1:3d}/{total} ({chunk.size:5d}B) - 进度: {progress}/{total}")
|
||||
return True
|
||||
else:
|
||||
error_msg = response.get('message', 'Unknown error')
|
||||
print(f"❌ 块 {chunk_index + 1} 失败: {error_msg}")
|
||||
self.failed_chunks.add(chunk_index)
|
||||
return False
|
||||
else:
|
||||
print(f"⚠️ 块 {chunk_index + 1} 响应异常或超时")
|
||||
self.failed_chunks.add(chunk_index)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"💥 块 {chunk_index + 1} 异常: {e}")
|
||||
self.failed_chunks.add(chunk_index)
|
||||
return False
|
||||
|
||||
async def retry_failed_chunks(self, ws, file_info):
|
||||
"""重试失败的分片"""
|
||||
print(f"\n🔄 开始重试 {len(self.failed_chunks)} 个失败分片...")
|
||||
|
||||
for retry_round in range(self.max_retries):
|
||||
if not self.failed_chunks:
|
||||
break
|
||||
|
||||
print(f"第 {retry_round + 1} 轮重试,剩余 {len(self.failed_chunks)} 个分片")
|
||||
current_failed = self.failed_chunks.copy()
|
||||
self.failed_chunks.clear()
|
||||
|
||||
# 重试当前失败的分片
|
||||
retry_tasks = []
|
||||
for chunk_index in current_failed:
|
||||
task = asyncio.create_task(
|
||||
self.upload_single_chunk(ws, chunk_index, file_info)
|
||||
)
|
||||
retry_tasks.append(task)
|
||||
|
||||
retry_results = await asyncio.gather(*retry_tasks, return_exceptions=True)
|
||||
successful_retries = sum(1 for r in retry_results if r is True)
|
||||
|
||||
print(f"重试结果: {successful_retries}/{len(current_failed)} 成功")
|
||||
|
||||
if not self.failed_chunks:
|
||||
print("✅ 所有分片重试成功!")
|
||||
break
|
||||
else:
|
||||
await asyncio.sleep(1) # 重试间隔
|
||||
|
||||
if self.failed_chunks:
|
||||
print(f"❌ 仍有 {len(self.failed_chunks)} 个分片失败: {sorted(self.failed_chunks)}")
|
||||
|
||||
async def complete_upload(self, ws):
|
||||
"""完成上传"""
|
||||
req = {
|
||||
"action": "upload_file_stream",
|
||||
"params": {
|
||||
"stream_id": self.stream_id,
|
||||
"is_complete": True
|
||||
},
|
||||
"echo": "complete"
|
||||
}
|
||||
|
||||
print("\n发送完成请求...")
|
||||
response = await self._send_and_wait_response(ws, req, timeout=10.0)
|
||||
|
||||
if response:
|
||||
if response.get('status') == 'ok':
|
||||
data = response.get('data', {})
|
||||
print(f"✅ 上传完成!")
|
||||
print(f" 文件路径: {data.get('file_path')}")
|
||||
print(f" 文件大小: {data.get('file_size')} bytes")
|
||||
print(f" SHA256: {data.get('sha256')}")
|
||||
print(f" 状态: {data.get('status')}")
|
||||
else:
|
||||
print(f"❌ 上传失败: {response.get('message')}")
|
||||
else:
|
||||
print("⚠️ 完成请求超时或失败")
|
||||
|
||||
async def main():
|
||||
# 配置
|
||||
WS_URI = "ws://localhost:3001" # 修改为你的WebSocket地址
|
||||
FILE_PATH = r"C:\Users\nanaeo\Pictures\CatPicture.zip" #!!!!!!!!!!!
|
||||
MAX_CONCURRENT = 8 # 最大并发上传数,可根据服务器性能调整
|
||||
|
||||
# 创建测试文件(如果不存在)
|
||||
if not os.path.exists(FILE_PATH):
|
||||
with open(FILE_PATH, 'w', encoding='utf-8') as f:
|
||||
f.write("这是一个测试文件,用于演示并行文件分片上传功能。\n" * 100)
|
||||
print(f"✅ 创建测试文件: {FILE_PATH}")
|
||||
|
||||
print("=== 并行文件流上传测试 ===")
|
||||
print(f"WebSocket URI: {WS_URI}")
|
||||
print(f"文件路径: {FILE_PATH}")
|
||||
print(f"最大并发数: {MAX_CONCURRENT}")
|
||||
|
||||
try:
|
||||
tester = FileUploadTester(WS_URI, FILE_PATH, MAX_CONCURRENT)
|
||||
await tester.connect_and_upload()
|
||||
print("🎉 测试完成!")
|
||||
except Exception as e:
|
||||
print(f"💥 测试出错: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@ -1,6 +1,8 @@
|
||||
import { ActionName } from '@/onebot/action/router';
|
||||
import { OneBotAction } from '@/onebot/action/OneBotAction';
|
||||
import { Static, Type } from '@sinclair/typebox';
|
||||
import { NetworkAdapterConfig } from '@/onebot/config/config';
|
||||
import { StreamPacket, StreamStatus } from './StreamBasic';
|
||||
import fs from 'fs';
|
||||
import { join as joinPath } from 'node:path';
|
||||
import { randomUUID } from 'crypto';
|
||||
@ -53,68 +55,51 @@ interface StreamState {
|
||||
|
||||
interface StreamResult {
|
||||
stream_id: string;
|
||||
status: 'receiving' | 'completed' | 'error' | 'ready';
|
||||
status: 'file_created' | 'chunk_received' | 'file_complete';
|
||||
received_chunks: number;
|
||||
total_chunks?: number;
|
||||
missing_chunks?: number[];
|
||||
total_chunks: number;
|
||||
file_path?: string;
|
||||
file_size?: number;
|
||||
sha256?: string;
|
||||
message?: string;
|
||||
}
|
||||
|
||||
export class UploadFileStream extends OneBotAction<Payload, StreamResult> {
|
||||
export class UploadFileStream extends OneBotAction<Payload, StreamPacket<StreamResult>> {
|
||||
override actionName = ActionName.UploadFileStream;
|
||||
override payloadSchema = SchemaData;
|
||||
override useStream = true;
|
||||
|
||||
private static streams = new Map<string, StreamState>();
|
||||
private static memoryUsage = 0;
|
||||
|
||||
async _handle(payload: Payload): Promise<StreamResult> {
|
||||
async _handle(payload: Payload, _adaptername: string, _config: NetworkAdapterConfig): Promise<StreamPacket<StreamResult>> {
|
||||
const { stream_id, reset, verify_only } = payload;
|
||||
|
||||
try {
|
||||
if (reset) return this.resetStream(stream_id);
|
||||
if (verify_only) return this.verifyStream(stream_id);
|
||||
|
||||
const stream = this.getOrCreateStream(payload);
|
||||
|
||||
if (payload.chunk_data && payload.chunk_index !== undefined) {
|
||||
const result = await this.processChunk(stream, payload.chunk_data, payload.chunk_index);
|
||||
if (result.status === 'error') return result;
|
||||
}
|
||||
|
||||
if (payload.is_complete || stream.receivedChunks === stream.totalChunks) {
|
||||
return await this.completeStream(stream);
|
||||
}
|
||||
if (reset) {
|
||||
this.cleanupStream(stream_id);
|
||||
throw new Error('Stream reset completed');
|
||||
}
|
||||
|
||||
if (verify_only) {
|
||||
const stream = UploadFileStream.streams.get(stream_id);
|
||||
if (!stream) throw new Error('Stream not found');
|
||||
return this.getStreamStatus(stream);
|
||||
|
||||
} catch (error) {
|
||||
// 确保在任何错误情况下都清理资源
|
||||
this.cleanupStream(stream_id, true);
|
||||
return this.errorResult(stream_id, error);
|
||||
}
|
||||
}
|
||||
|
||||
private resetStream(streamId: string): StreamResult {
|
||||
this.cleanupStream(streamId);
|
||||
return {
|
||||
stream_id: streamId,
|
||||
status: 'ready',
|
||||
received_chunks: 0,
|
||||
message: 'Stream reset'
|
||||
};
|
||||
}
|
||||
const stream = this.getOrCreateStream(payload);
|
||||
|
||||
private verifyStream(streamId: string): StreamResult {
|
||||
const stream = UploadFileStream.streams.get(streamId);
|
||||
if (!stream) {
|
||||
return this.errorResult(streamId, new Error('Stream not found'));
|
||||
if (payload.chunk_data && payload.chunk_index !== undefined) {
|
||||
return await this.processChunk(stream, payload.chunk_data, payload.chunk_index);
|
||||
}
|
||||
|
||||
if (payload.is_complete || stream.receivedChunks === stream.totalChunks) {
|
||||
return await this.completeStream(stream);
|
||||
}
|
||||
|
||||
return this.getStreamStatus(stream);
|
||||
}
|
||||
|
||||
|
||||
|
||||
private getOrCreateStream(payload: Payload): StreamState {
|
||||
let stream = UploadFileStream.streams.get(payload.stream_id);
|
||||
|
||||
@ -194,7 +179,7 @@ export class UploadFileStream extends OneBotAction<Payload, StreamResult> {
|
||||
}, CONFIG.TIMEOUT);
|
||||
}
|
||||
|
||||
private async processChunk(stream: StreamState, chunkData: string, chunkIndex: number): Promise<StreamResult> {
|
||||
private async processChunk(stream: StreamState, chunkData: string, chunkIndex: number): Promise<StreamPacket<StreamResult>> {
|
||||
// 验证索引
|
||||
if (chunkIndex < 0 || chunkIndex >= stream.totalChunks) {
|
||||
throw new Error(`Invalid chunk index: ${chunkIndex}`);
|
||||
@ -202,30 +187,31 @@ export class UploadFileStream extends OneBotAction<Payload, StreamResult> {
|
||||
|
||||
// 检查重复
|
||||
if (!stream.missingChunks.has(chunkIndex)) {
|
||||
return this.getStreamStatus(stream, `Chunk ${chunkIndex} already received`);
|
||||
}
|
||||
|
||||
try {
|
||||
const buffer = Buffer.from(chunkData, 'base64');
|
||||
|
||||
// 存储分片
|
||||
if (stream.useMemory) {
|
||||
stream.memoryChunks!.set(chunkIndex, buffer);
|
||||
} else {
|
||||
const chunkPath = joinPath(stream.tempDir!, `${chunkIndex}.chunk`);
|
||||
await fs.promises.writeFile(chunkPath, buffer);
|
||||
}
|
||||
|
||||
// 更新状态
|
||||
stream.missingChunks.delete(chunkIndex);
|
||||
stream.receivedChunks++;
|
||||
this.refreshTimeout(stream);
|
||||
|
||||
return this.getStreamStatus(stream);
|
||||
|
||||
} catch (error) {
|
||||
throw new Error(`Chunk processing failed: ${error instanceof Error ? error.message : 'Unknown error'}`);
|
||||
}
|
||||
|
||||
const buffer = Buffer.from(chunkData, 'base64');
|
||||
|
||||
// 存储分片
|
||||
if (stream.useMemory) {
|
||||
stream.memoryChunks!.set(chunkIndex, buffer);
|
||||
} else {
|
||||
const chunkPath = joinPath(stream.tempDir!, `${chunkIndex}.chunk`);
|
||||
await fs.promises.writeFile(chunkPath, buffer);
|
||||
}
|
||||
|
||||
// 更新状态
|
||||
stream.missingChunks.delete(chunkIndex);
|
||||
stream.receivedChunks++;
|
||||
this.refreshTimeout(stream);
|
||||
|
||||
return {
|
||||
type: StreamStatus.Stream,
|
||||
stream_id: stream.id,
|
||||
status: 'chunk_received',
|
||||
received_chunks: stream.receivedChunks,
|
||||
total_chunks: stream.totalChunks
|
||||
};
|
||||
}
|
||||
|
||||
private refreshTimeout(stream: StreamState): void {
|
||||
@ -233,51 +219,42 @@ export class UploadFileStream extends OneBotAction<Payload, StreamResult> {
|
||||
stream.timeoutId = this.setupTimeout(stream.id);
|
||||
}
|
||||
|
||||
private getStreamStatus(stream: StreamState, message?: string): StreamResult {
|
||||
const missingChunks = Array.from(stream.missingChunks).sort();
|
||||
|
||||
private getStreamStatus(stream: StreamState): StreamPacket<StreamResult> {
|
||||
return {
|
||||
type: StreamStatus.Stream,
|
||||
stream_id: stream.id,
|
||||
status: 'receiving',
|
||||
status: 'file_created',
|
||||
received_chunks: stream.receivedChunks,
|
||||
total_chunks: stream.totalChunks,
|
||||
missing_chunks: missingChunks.length > 0 ? missingChunks : undefined,
|
||||
file_size: stream.fileSize,
|
||||
message
|
||||
total_chunks: stream.totalChunks
|
||||
};
|
||||
}
|
||||
|
||||
private async completeStream(stream: StreamState): Promise<StreamResult> {
|
||||
try {
|
||||
// 合并分片
|
||||
const finalBuffer = stream.useMemory ?
|
||||
await this.mergeMemoryChunks(stream) :
|
||||
await this.mergeDiskChunks(stream);
|
||||
private async completeStream(stream: StreamState): Promise<StreamPacket<StreamResult>> {
|
||||
// 合并分片
|
||||
const finalBuffer = stream.useMemory ?
|
||||
await this.mergeMemoryChunks(stream) :
|
||||
await this.mergeDiskChunks(stream);
|
||||
|
||||
// 验证SHA256
|
||||
const sha256 = this.validateSha256(stream, finalBuffer);
|
||||
// 验证SHA256
|
||||
const sha256 = this.validateSha256(stream, finalBuffer);
|
||||
|
||||
// 保存文件
|
||||
const finalPath = stream.finalPath || joinPath(this.core.NapCatTempPath, stream.filename);
|
||||
await fs.promises.writeFile(finalPath, finalBuffer);
|
||||
// 保存文件
|
||||
const finalPath = stream.finalPath || joinPath(this.core.NapCatTempPath, stream.filename);
|
||||
await fs.promises.writeFile(finalPath, finalBuffer);
|
||||
|
||||
// 清理资源但保留文件
|
||||
this.cleanupStream(stream.id, false);
|
||||
// 清理资源但保留文件
|
||||
this.cleanupStream(stream.id, false);
|
||||
|
||||
return {
|
||||
stream_id: stream.id,
|
||||
status: 'completed',
|
||||
received_chunks: stream.receivedChunks,
|
||||
total_chunks: stream.totalChunks,
|
||||
file_path: finalPath,
|
||||
file_size: finalBuffer.length,
|
||||
sha256,
|
||||
message: 'Upload completed'
|
||||
};
|
||||
|
||||
} catch (error) {
|
||||
throw new Error(`Stream completion failed: ${error instanceof Error ? error.message : 'Unknown error'}`);
|
||||
}
|
||||
return {
|
||||
type: StreamStatus.Response,
|
||||
stream_id: stream.id,
|
||||
status: 'file_complete',
|
||||
received_chunks: stream.receivedChunks,
|
||||
total_chunks: stream.totalChunks,
|
||||
file_path: finalPath,
|
||||
file_size: finalBuffer.length,
|
||||
sha256
|
||||
};
|
||||
}
|
||||
|
||||
private async mergeMemoryChunks(stream: StreamState): Promise<Buffer> {
|
||||
@ -357,28 +334,4 @@ export class UploadFileStream extends OneBotAction<Payload, StreamResult> {
|
||||
console.log(`Stream ${streamId} cleaned up`);
|
||||
}
|
||||
}
|
||||
|
||||
private errorResult(streamId: string, error: any): StreamResult {
|
||||
return {
|
||||
stream_id: streamId,
|
||||
status: 'error',
|
||||
received_chunks: 0,
|
||||
message: error instanceof Error ? error.message : 'Unknown error'
|
||||
};
|
||||
}
|
||||
|
||||
// 全局状态查询
|
||||
static getGlobalStatus() {
|
||||
return {
|
||||
activeStreams: this.streams.size,
|
||||
memoryUsageMB: Math.round(this.memoryUsage / 1024 / 1024 * 100) / 100,
|
||||
streams: Array.from(this.streams.values()).map(stream => ({
|
||||
streamId: stream.id,
|
||||
filename: stream.filename,
|
||||
progress: `${stream.receivedChunks}/${stream.totalChunks}`,
|
||||
useMemory: stream.useMemory,
|
||||
createdAt: new Date(stream.createdAt).toISOString()
|
||||
}))
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,201 +0,0 @@
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
class FileUploadTester:
|
||||
def __init__(self, ws_uri: str, file_path: str):
|
||||
self.ws_uri = ws_uri
|
||||
self.file_path = file_path
|
||||
self.chunk_size = 64 * 1024 # 64KB per chunk
|
||||
self.stream_id = None
|
||||
|
||||
async def connect_and_upload(self):
|
||||
"""连接到WebSocket并上传文件"""
|
||||
async with websockets.connect(self.ws_uri) as ws:
|
||||
print(f"已连接到 {self.ws_uri}")
|
||||
|
||||
# 准备文件数据
|
||||
file_info = self.prepare_file()
|
||||
if not file_info:
|
||||
return
|
||||
|
||||
print(f"文件信息: {file_info['filename']}, 大小: {file_info['file_size']} bytes, 块数: {file_info['total_chunks']}")
|
||||
|
||||
# 生成stream_id
|
||||
self.stream_id = f"upload_{hash(file_info['filename'] + str(file_info['file_size']))}"
|
||||
print(f"Stream ID: {self.stream_id}")
|
||||
|
||||
# 重置流(如果存在)
|
||||
await self.reset_stream(ws)
|
||||
|
||||
# 开始分块上传
|
||||
await self.upload_chunks(ws, file_info)
|
||||
|
||||
# 完成上传
|
||||
await self.complete_upload(ws)
|
||||
|
||||
# 等待一些响应
|
||||
await self.listen_for_responses(ws)
|
||||
|
||||
def prepare_file(self):
|
||||
"""准备文件信息"""
|
||||
if not os.path.exists(self.file_path):
|
||||
print(f"文件不存在: {self.file_path}")
|
||||
return None
|
||||
|
||||
file_size = os.path.getsize(self.file_path)
|
||||
filename = os.path.basename(self.file_path)
|
||||
total_chunks = (file_size + self.chunk_size - 1) // self.chunk_size
|
||||
|
||||
# 计算SHA256
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(self.file_path, 'rb') as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
sha256_hash.update(chunk)
|
||||
expected_sha256 = sha256_hash.hexdigest()
|
||||
|
||||
return {
|
||||
'filename': filename,
|
||||
'file_size': file_size,
|
||||
'total_chunks': total_chunks,
|
||||
'expected_sha256': expected_sha256
|
||||
}
|
||||
|
||||
async def reset_stream(self, ws):
|
||||
"""重置流"""
|
||||
req = {
|
||||
"action": "upload_file_stream",
|
||||
"params": {
|
||||
"stream_id": self.stream_id,
|
||||
"reset": True
|
||||
},
|
||||
"echo": "reset"
|
||||
}
|
||||
await ws.send(json.dumps(req))
|
||||
print("发送重置请求...")
|
||||
|
||||
async def upload_chunks(self, ws, file_info):
|
||||
"""上传文件块"""
|
||||
with open(self.file_path, 'rb') as f:
|
||||
for chunk_index in range(file_info['total_chunks']):
|
||||
# 读取块数据
|
||||
chunk_data = f.read(self.chunk_size)
|
||||
chunk_base64 = base64.b64encode(chunk_data).decode('utf-8')
|
||||
|
||||
# 准备请求
|
||||
req = {
|
||||
"action": "upload_file_stream",
|
||||
"params": {
|
||||
"stream_id": self.stream_id,
|
||||
"chunk_data": chunk_base64,
|
||||
"chunk_index": chunk_index,
|
||||
"total_chunks": file_info['total_chunks'],
|
||||
"file_size": file_info['file_size'],
|
||||
"filename": file_info['filename'],
|
||||
#"expected_sha256": file_info['expected_sha256']
|
||||
},
|
||||
"echo": f"chunk_{chunk_index}"
|
||||
}
|
||||
|
||||
await ws.send(json.dumps(req))
|
||||
print(f"发送块 {chunk_index + 1}/{file_info['total_chunks']} ({len(chunk_data)} bytes)")
|
||||
|
||||
# 等待响应
|
||||
try:
|
||||
response = await asyncio.wait_for(ws.recv(), timeout=5.0)
|
||||
resp_data = json.loads(response)
|
||||
if resp_data.get('echo') == f"chunk_{chunk_index}":
|
||||
if resp_data.get('status') == 'ok':
|
||||
data = resp_data.get('data', {})
|
||||
print(f" -> 状态: {data.get('status')}, 已接收: {data.get('received_chunks')}")
|
||||
else:
|
||||
print(f" -> 错误: {resp_data.get('message')}")
|
||||
except asyncio.TimeoutError:
|
||||
print(f" -> 块 {chunk_index} 响应超时")
|
||||
|
||||
# 小延迟避免过快发送
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def complete_upload(self, ws):
|
||||
"""完成上传"""
|
||||
req = {
|
||||
"action": "upload_file_stream",
|
||||
"params": {
|
||||
"stream_id": self.stream_id,
|
||||
"is_complete": True
|
||||
},
|
||||
"echo": "complete"
|
||||
}
|
||||
await ws.send(json.dumps(req))
|
||||
print("发送完成请求...")
|
||||
|
||||
async def verify_stream(self, ws):
|
||||
"""验证流状态"""
|
||||
req = {
|
||||
"action": "upload_file_stream",
|
||||
"params": {
|
||||
"stream_id": self.stream_id,
|
||||
"verify_only": True
|
||||
},
|
||||
"echo": "verify"
|
||||
}
|
||||
await ws.send(json.dumps(req))
|
||||
print("发送验证请求...")
|
||||
|
||||
async def listen_for_responses(self, ws, duration=10):
|
||||
"""监听响应"""
|
||||
print(f"监听响应 {duration} 秒...")
|
||||
try:
|
||||
end_time = asyncio.get_event_loop().time() + duration
|
||||
while asyncio.get_event_loop().time() < end_time:
|
||||
try:
|
||||
msg = await asyncio.wait_for(ws.recv(), timeout=1.0)
|
||||
resp_data = json.loads(msg)
|
||||
echo = resp_data.get('echo', 'unknown')
|
||||
|
||||
if echo == "complete":
|
||||
if resp_data.get('status') == 'ok':
|
||||
data = resp_data.get('data', {})
|
||||
print(f"✅ 上传完成!")
|
||||
print(f" 文件路径: {data.get('file_path')}")
|
||||
print(f" 文件大小: {data.get('file_size')} bytes")
|
||||
print(f" SHA256: {data.get('sha256')}")
|
||||
print(f" 状态: {data.get('status')}")
|
||||
else:
|
||||
print(f"❌ 上传失败: {resp_data.get('message')}")
|
||||
elif echo == "verify":
|
||||
if resp_data.get('status') == 'ok':
|
||||
data = resp_data.get('data', {})
|
||||
print(f"🔍 验证结果: {data}")
|
||||
elif echo == "reset":
|
||||
print(f"🔄 重置完成: {resp_data}")
|
||||
else:
|
||||
print(f"📨 收到响应 [{echo}]: {resp_data}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"监听出错: {e}")
|
||||
|
||||
async def main():
|
||||
# 配置
|
||||
WS_URI = "ws://localhost:3001" # 修改为你的WebSocket地址
|
||||
FILE_PATH = "C:\\Users\\nanaeo\\Pictures\\CatPicture.zip"
|
||||
|
||||
print("=== 文件流上传测试 ===")
|
||||
print(f"WebSocket URI: {WS_URI}")
|
||||
print(f"文件路径: {FILE_PATH}")
|
||||
|
||||
try:
|
||||
tester = FileUploadTester(WS_URI, FILE_PATH)
|
||||
await tester.connect_and_upload()
|
||||
except Exception as e:
|
||||
print(f"测试出错: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
238
src/onebot/action/stream/test_upload_stream.py
Normal file
238
src/onebot/action/stream/test_upload_stream.py
Normal file
@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
NapCat OneBot WebSocket 文件流上传测试脚本
|
||||
用于测试 UploadFileStream 接口的一次性分片上传功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
import websockets
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
class OneBotUploadTester:
|
||||
def __init__(self, ws_url: str = "ws://localhost:3001", access_token: Optional[str] = None):
|
||||
self.ws_url = ws_url
|
||||
self.access_token = access_token
|
||||
self.websocket = None
|
||||
|
||||
async def connect(self):
|
||||
"""连接到 OneBot WebSocket"""
|
||||
headers = {}
|
||||
if self.access_token:
|
||||
headers["Authorization"] = f"Bearer {self.access_token}"
|
||||
|
||||
print(f"连接到 {self.ws_url}")
|
||||
self.websocket = await websockets.connect(self.ws_url, extra_headers=headers)
|
||||
print("WebSocket 连接成功")
|
||||
|
||||
async def disconnect(self):
|
||||
"""断开 WebSocket 连接"""
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
print("WebSocket 连接已断开")
|
||||
|
||||
def calculate_file_chunks(self, file_path: str, chunk_size: int = 64 * 1024) -> tuple[List[bytes], str, int]:
|
||||
"""
|
||||
计算文件分片和 SHA256
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
chunk_size: 分片大小(默认64KB)
|
||||
|
||||
Returns:
|
||||
(chunks, sha256_hash, total_size)
|
||||
"""
|
||||
chunks = []
|
||||
hasher = hashlib.sha256()
|
||||
total_size = 0
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
while True:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
hasher.update(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
sha256_hash = hasher.hexdigest()
|
||||
print(f"文件分析完成:")
|
||||
print(f" - 文件大小: {total_size} 字节")
|
||||
print(f" - 分片数量: {len(chunks)}")
|
||||
print(f" - SHA256: {sha256_hash}")
|
||||
|
||||
return chunks, sha256_hash, total_size
|
||||
|
||||
async def send_action(self, action: str, params: dict, echo: str = None) -> dict:
|
||||
"""发送 OneBot 动作请求"""
|
||||
if not echo:
|
||||
echo = str(uuid.uuid4())
|
||||
|
||||
message = {
|
||||
"action": action,
|
||||
"params": params,
|
||||
"echo": echo
|
||||
}
|
||||
|
||||
print(f"发送请求: {action}")
|
||||
await self.websocket.send(json.dumps(message))
|
||||
|
||||
# 等待响应
|
||||
while True:
|
||||
response = await self.websocket.recv()
|
||||
data = json.loads(response)
|
||||
|
||||
# 检查是否是我们的响应
|
||||
if data.get("echo") == echo:
|
||||
return data
|
||||
else:
|
||||
# 可能是其他消息,继续等待
|
||||
print(f"收到其他消息: {data}")
|
||||
continue
|
||||
|
||||
async def upload_file_stream_batch(self, file_path: str, chunk_size: int = 64 * 1024) -> str:
|
||||
"""
|
||||
一次性批量上传文件流
|
||||
|
||||
Args:
|
||||
file_path: 要上传的文件路径
|
||||
chunk_size: 分片大小
|
||||
|
||||
Returns:
|
||||
上传完成后的文件路径
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
# 分析文件
|
||||
chunks, sha256_hash, total_size = self.calculate_file_chunks(str(file_path), chunk_size)
|
||||
stream_id = str(uuid.uuid4())
|
||||
|
||||
print(f"\n开始上传文件: {file_path.name}")
|
||||
print(f"流ID: {stream_id}")
|
||||
|
||||
# 一次性发送所有分片
|
||||
total_chunks = len(chunks)
|
||||
|
||||
for chunk_index, chunk_data in enumerate(chunks):
|
||||
# 将分片数据编码为 base64
|
||||
chunk_base64 = base64.b64encode(chunk_data).decode('utf-8')
|
||||
|
||||
# 构建参数
|
||||
params = {
|
||||
"stream_id": stream_id,
|
||||
"chunk_data": chunk_base64,
|
||||
"chunk_index": chunk_index,
|
||||
"total_chunks": total_chunks,
|
||||
"file_size": total_size,
|
||||
"expected_sha256": sha256_hash,
|
||||
"filename": file_path.name
|
||||
}
|
||||
|
||||
# 发送分片
|
||||
response = await self.send_action("upload_file_stream", params)
|
||||
|
||||
if response.get("status") != "ok":
|
||||
raise Exception(f"上传分片 {chunk_index} 失败: {response}")
|
||||
|
||||
# 解析流响应
|
||||
stream_data = response.get("data", {})
|
||||
print(f"分片 {chunk_index + 1}/{total_chunks} 上传成功 "
|
||||
f"(接收: {stream_data.get('received_chunks', 0)}/{stream_data.get('total_chunks', 0)})")
|
||||
|
||||
# 发送完成信号
|
||||
print(f"\n所有分片发送完成,请求文件合并...")
|
||||
complete_params = {
|
||||
"stream_id": stream_id,
|
||||
"is_complete": True
|
||||
}
|
||||
|
||||
response = await self.send_action("upload_file_stream", complete_params)
|
||||
|
||||
if response.get("status") != "ok":
|
||||
raise Exception(f"文件合并失败: {response}")
|
||||
|
||||
result = response.get("data", {})
|
||||
|
||||
if result.get("status") == "file_complete":
|
||||
print(f"✅ 文件上传成功!")
|
||||
print(f" - 文件路径: {result.get('file_path')}")
|
||||
print(f" - 文件大小: {result.get('file_size')} 字节")
|
||||
print(f" - SHA256: {result.get('sha256')}")
|
||||
return result.get('file_path')
|
||||
else:
|
||||
raise Exception(f"文件状态异常: {result}")
|
||||
|
||||
async def test_upload(self, file_path: str, chunk_size: int = 64 * 1024):
|
||||
"""测试文件上传"""
|
||||
try:
|
||||
await self.connect()
|
||||
|
||||
# 执行上传
|
||||
uploaded_path = await self.upload_file_stream_batch(file_path, chunk_size)
|
||||
|
||||
print(f"\n🎉 测试完成! 上传后的文件路径: {uploaded_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
await self.disconnect()
|
||||
|
||||
def create_test_file(file_path: str, size_mb: float = 1):
|
||||
"""创建测试文件"""
|
||||
size_bytes = int(size_mb * 1024 * 1024)
|
||||
|
||||
with open(file_path, 'wb') as f:
|
||||
# 写入一些有意义的测试数据
|
||||
test_data = b"NapCat Upload Test Data - " * 100
|
||||
written = 0
|
||||
while written < size_bytes:
|
||||
write_size = min(len(test_data), size_bytes - written)
|
||||
f.write(test_data[:write_size])
|
||||
written += write_size
|
||||
|
||||
print(f"创建测试文件: {file_path} ({size_mb}MB)")
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(description="NapCat OneBot 文件流上传测试")
|
||||
parser.add_argument("--url", default="ws://localhost:3001", help="WebSocket URL")
|
||||
parser.add_argument("--token", help="访问令牌")
|
||||
parser.add_argument("--file", help="要上传的文件路径")
|
||||
parser.add_argument("--chunk-size", type=int, default=64*1024, help="分片大小(字节)")
|
||||
parser.add_argument("--create-test", type=float, help="创建测试文件(MB)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建测试文件
|
||||
if args.create_test:
|
||||
test_file = "test_upload_file.bin"
|
||||
create_test_file(test_file, args.create_test)
|
||||
if not args.file:
|
||||
args.file = test_file
|
||||
|
||||
if not args.file:
|
||||
print("请指定要上传的文件路径,或使用 --create-test 创建测试文件")
|
||||
return
|
||||
|
||||
# 创建测试器并运行
|
||||
tester = OneBotUploadTester(args.url, args.token)
|
||||
await tester.test_upload(args.file, args.chunk_size)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 安装依赖提示
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("请先安装依赖: pip install websockets")
|
||||
exit(1)
|
||||
|
||||
asyncio.run(main())
|
||||
@ -121,19 +121,18 @@ export class OB11HttpServerAdapter extends IOB11NetworkAdapter<HttpServerConfig>
|
||||
const real_echo = payload_echo ?? Math.random().toString(36).substring(2, 15);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const action = this.actions.get(actionName as any);
|
||||
const useStream = action.useStream;
|
||||
if (action) {
|
||||
try {
|
||||
let stream = false;
|
||||
const result = await action.handle(payload, this.name, this.config, {
|
||||
send: request_sse ? async (data: object) => {
|
||||
this.onEvent({ ...OB11Response.ok(data, real_echo), type: 'sse-action' } as unknown as OB11EmitEventContent);
|
||||
this.onEvent({ ...OB11Response.ok(data, real_echo, true) } as unknown as OB11EmitEventContent);
|
||||
} : async (data: object) => {
|
||||
stream = true;
|
||||
res.write(JSON.stringify({ ...OB11Response.ok(data, real_echo), type: 'stream-action' }) + "\r\n\r\n");
|
||||
res.write(JSON.stringify({ ...OB11Response.ok(data, real_echo, true) }) + "\r\n\r\n");
|
||||
}
|
||||
}, real_echo);
|
||||
if (stream) {
|
||||
res.write(JSON.stringify({ ...result, type: 'stream-action' }) + "\r\n\r\n");
|
||||
if (useStream) {
|
||||
res.write(JSON.stringify({ ...result }) + "\r\n\r\n");
|
||||
return res.end();
|
||||
};
|
||||
return res.json(result);
|
||||
|
||||
@ -153,7 +153,7 @@ export class OB11WebSocketClientAdapter extends IOB11NetworkAdapter<WebsocketCli
|
||||
}
|
||||
const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config, {
|
||||
send: async (data: object) => {
|
||||
this.checkStateAndReply<unknown>({ ...OB11Response.ok(data, echo ?? '') });
|
||||
this.checkStateAndReply<unknown>({ ...OB11Response.ok(data, echo ?? '', true) });
|
||||
}
|
||||
});
|
||||
this.checkStateAndReply<unknown>({ ...retdata });
|
||||
|
||||
@ -188,7 +188,7 @@ export class OB11WebSocketServerAdapter extends IOB11NetworkAdapter<WebsocketSer
|
||||
}
|
||||
const retdata = await action.websocketHandle(receiveData.params, echo ?? '', this.name, this.config, {
|
||||
send: async (data: object) => {
|
||||
this.checkStateAndReply<unknown>({ ...OB11Response.ok(data, echo ?? '') }, wsClient);
|
||||
this.checkStateAndReply<unknown>({ ...OB11Response.ok(data, echo ?? '', true) }, wsClient);
|
||||
}
|
||||
});
|
||||
this.checkStateAndReply<unknown>({ ...retdata }, wsClient);
|
||||
|
||||
@ -46,6 +46,7 @@ export interface OB11Return<DataType> {
|
||||
message: string;
|
||||
echo?: unknown; // ws调用api才有此字段
|
||||
wording?: string; // go-cqhttp字段,错误信息
|
||||
stream?: 'stream-action' | 'normal-action' ; // 流式返回标记
|
||||
}
|
||||
|
||||
// 消息数据类型枚举
|
||||
|
||||
Loading…
Reference in New Issue
Block a user