diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index c33e0b32..64539dc7 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -5,23 +5,26 @@ import ( "crypto/md5" "encoding/hex" "errors" + "fmt" "io" "net/url" "os" stdpath "path" "strconv" + "strings" + "sync" "time" - "golang.org/x/sync/semaphore" - "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/pkg/errgroup" + "github.com/alist-org/alist/v3/pkg/singleflight" "github.com/alist-org/alist/v3/pkg/utils" "github.com/avast/retry-go" + "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" ) @@ -31,8 +34,16 @@ type BaiduNetdisk struct { uploadThread int vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M) + + upClient *resty.Client // 上传文件使用的http客户端 + uploadUrlG singleflight.Group[string] + uploadUrlMu sync.RWMutex + uploadUrl string // 上传域名 + uploadUrlUpdateTime time.Time // 上传域名上次更新时间 } +var ErrUploadIDExpired = errors.New("uploadid expired") + func (d *BaiduNetdisk) Config() driver.Config { return config } @@ -42,19 +53,26 @@ func (d *BaiduNetdisk) GetAddition() driver.Additional { } func (d *BaiduNetdisk) Init(ctx context.Context) error { + d.upClient = base.NewRestyClient(). + SetTimeout(UPLOAD_TIMEOUT). + SetRetryCount(UPLOAD_RETRY_COUNT). + SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME). + SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME) d.uploadThread, _ = strconv.Atoi(d.UploadThread) - if d.uploadThread < 1 || d.uploadThread > 32 { - d.uploadThread, d.UploadThread = 3, "3" + if d.uploadThread < 1 { + d.uploadThread, d.UploadThread = 1, "1" + } else if d.uploadThread > 32 { + d.uploadThread, d.UploadThread = 32, "32" } if _, err := url.Parse(d.UploadAPI); d.UploadAPI == "" || err != nil { - d.UploadAPI = "https://d.pcs.baidu.com" + d.UploadAPI = UPLOAD_FALLBACK_API } res, err := d.get("/xpan/nas", map[string]string{ "method": "uinfo", }, nil) - log.Debugf("[baidu] get uinfo: %s", string(res)) + log.Debugf("[baidu_netdisk] get uinfo: %s", string(res)) if err != nil { return err } @@ -181,6 +199,11 @@ func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream mo // **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。 // 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致 func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + // 百度网盘不允许上传空文件 + if stream.GetSize() < 1 { + return nil, ErrBaiduEmptyFilesNotAllowed + } + // rapid upload if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil { return newObj, nil @@ -245,7 +268,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F } if tmpF != nil { if written != streamSize { - return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize) + return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize) } _, err = tmpF.Seek(0, io.SeekStart) if err != nil { @@ -259,82 +282,97 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F mtime := stream.ModTime().Unix() ctime := stream.CreateTime().Unix() - // step.1 预上传 - // 尝试获取之前的进度 + // step.1 尝试读取已保存进度 precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5) if !ok { - params := map[string]string{ - "method": "precreate", - } - form := map[string]string{ - "path": path, - "size": strconv.FormatInt(streamSize, 10), - "isdir": "0", - "autoinit": "1", - "rtype": "3", - "block_list": blockListStr, - "content-md5": contentMd5, - "slice-md5": sliceMd5, - } - joinTime(form, ctime, mtime) - - log.Debugf("[baidu_netdisk] precreate data: %s", form) - _, err = d.postForm("/xpan/file", params, form, &precreateResp) + // 没有进度,走预上传 + precreateResp, err = d.precreate(ctx, path, streamSize, blockListStr, contentMd5, sliceMd5, ctime, mtime) if err != nil { return nil, err } - log.Debugf("%+v", precreateResp) if precreateResp.ReturnType == 2 { //rapid upload, since got md5 match from baidu server // 修复时间,具体原因见 Put 方法注释的 **注意** - precreateResp.File.Ctime = ctime - precreateResp.File.Mtime = mtime return fileToObj(precreateResp.File), nil } } + // step.2 上传分片 - threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, - retry.Attempts(1), - retry.Delay(time.Second), - retry.DelayType(retry.BackOffDelay)) - sem := semaphore.NewWeighted(3) - for i, partseq := range precreateResp.BlockList { - if utils.IsCanceled(upCtx) { - break +uploadLoop: + for attempt := 0; attempt < 2; attempt++ { + // 获取上传域名 + uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid) + // 并发上传 + threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, + retry.Attempts(1), + retry.Delay(time.Second), + retry.DelayType(retry.BackOffDelay)) + + cacheReaderAt, okReaderAt := cache.(io.ReaderAt) + if !okReaderAt { + return nil, fmt.Errorf("cache object must implement io.ReaderAt interface for upload operations") } - i, partseq, offset, byteSize := i, partseq, int64(partseq)*sliceSize, sliceSize - if partseq+1 == count { - byteSize = lastBlockSize + totalParts := len(precreateResp.BlockList) + for i, partseq := range precreateResp.BlockList { + if utils.IsCanceled(upCtx) || partseq < 0 { + continue + } + + i, partseq := i, partseq + offset, size := int64(partseq)*sliceSize, sliceSize + if partseq+1 == count { + size = lastBlockSize + } + threadG.Go(func(ctx context.Context) error { + params := map[string]string{ + "method": "upload", + "access_token": d.AccessToken, + "type": "tmpfile", + "path": path, + "uploadid": precreateResp.Uploadid, + "partseq": strconv.Itoa(partseq), + } + section := io.NewSectionReader(cacheReaderAt, offset, size) + err := d.uploadSlice(ctx, uploadUrl, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section)) + if err != nil { + return err + } + precreateResp.BlockList[i] = -1 + // 当前goroutine还没退出,+1才是真正成功的数量 + success := threadG.Success() + 1 + progress := float64(success) * 100 / float64(totalParts) + up(progress) + return nil + }) } - threadG.Go(func(ctx context.Context) error { - if err = sem.Acquire(ctx, 1); err != nil { - return err - } - defer sem.Release(1) - params := map[string]string{ - "method": "upload", - "access_token": d.AccessToken, - "type": "tmpfile", - "path": path, - "uploadid": precreateResp.Uploadid, - "partseq": strconv.Itoa(partseq), - } - err := d.uploadSlice(ctx, params, stream.GetName(), - driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize))) - if err != nil { - return err - } - up(float64(threadG.Success()) * 100 / float64(len(precreateResp.BlockList))) - precreateResp.BlockList[i] = -1 - return nil - }) - } - if err = threadG.Wait(); err != nil { - // 如果属于用户主动取消,则保存上传进度 + + err = threadG.Wait() + if err == nil { + break uploadLoop + } + + // 保存进度(所有错误都会保存) + precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 }) + base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5) + if errors.Is(err, context.Canceled) { - precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 }) + return nil, err + } + if errors.Is(err, ErrUploadIDExpired) { + log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch") + // 重新 precreate(所有分片都要重传) + newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime) + if err2 != nil { + return nil, err2 + } + if newPre.ReturnType == 2 { + return fileToObj(newPre.File), nil + } + precreateResp = newPre + // 覆盖掉旧的进度 base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5) + continue uploadLoop } return nil, err } @@ -348,23 +386,67 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F // 修复时间,具体原因见 Put 方法注释的 **注意** newFile.Ctime = ctime newFile.Mtime = mtime + // 上传成功清理进度 + base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5) return fileToObj(newFile), nil } -func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string, fileName string, file io.Reader) error { - res, err := base.RestyClient.R(). +// precreate 执行预上传操作,支持首次上传和 uploadid 过期重试 +func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize int64, blockListStr, contentMd5, sliceMd5 string, ctime, mtime int64) (*PrecreateResp, error) { + params := map[string]string{"method": "precreate"} + form := map[string]string{ + "path": path, + "size": strconv.FormatInt(streamSize, 10), + "isdir": "0", + "autoinit": "1", + "rtype": "3", + "block_list": blockListStr, + } + + // 只有在首次上传时才包含 content-md5 和 slice-md5 + if contentMd5 != "" && sliceMd5 != "" { + form["content-md5"] = contentMd5 + form["slice-md5"] = sliceMd5 + } + + joinTime(form, ctime, mtime) + + var precreateResp PrecreateResp + _, err := d.postForm("/xpan/file", params, form, &precreateResp) + if err != nil { + return nil, err + } + + // 修复时间,具体原因见 Put 方法注释的 **注意** + if precreateResp.ReturnType == 2 { + precreateResp.File.Ctime = ctime + precreateResp.File.Mtime = mtime + } + + return &precreateResp, nil +} + +func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file io.Reader) error { + res, err := d.upClient.R(). SetContext(ctx). SetQueryParams(params). SetFileReader("file", fileName, file). - Post(d.UploadAPI + "/rest/2.0/pcs/superfile2") + Post(uploadUrl + "/rest/2.0/pcs/superfile2") if err != nil { return err } log.Debugln(res.RawResponse.Status + res.String()) errCode := utils.Json.Get(res.Body(), "error_code").ToInt() errNo := utils.Json.Get(res.Body(), "errno").ToInt() + respStr := res.String() + lower := strings.ToLower(respStr) + if strings.Contains(lower, "uploadid") && + (strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) { + return ErrUploadIDExpired + } + if errCode != 0 || errNo != 0 { - return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry. response=%s", res.String()) + return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", res.String()) } return nil } diff --git a/drivers/baidu_netdisk/meta.go b/drivers/baidu_netdisk/meta.go index 7577c747..b75650ef 100644 --- a/drivers/baidu_netdisk/meta.go +++ b/drivers/baidu_netdisk/meta.go @@ -1,6 +1,8 @@ package baidu_netdisk import ( + "time" + "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/op" ) @@ -17,11 +19,21 @@ type Addition struct { AccessToken string UploadThread string `json:"upload_thread" default:"3" help:"1<=thread<=32"` UploadAPI string `json:"upload_api" default:"https://d.pcs.baidu.com"` + UseDynamicUploadAPI bool `json:"use_dynamic_upload_api" default:"true" help:"dynamically get upload api domain, when enabled, the 'Upload API' setting will be used as a fallback if failed to get"` CustomUploadPartSize int64 `json:"custom_upload_part_size" type:"number" default:"0" help:"0 for auto"` LowBandwithUploadMode bool `json:"low_bandwith_upload_mode" default:"false"` OnlyListVideoFile bool `json:"only_list_video_file" default:"false"` } +const ( + UPLOAD_FALLBACK_API = "https://d.pcs.baidu.com" // 备用上传地址 + UPLOAD_URL_EXPIRE_TIME = time.Minute * 60 // 上传地址有效期(分钟) + UPLOAD_TIMEOUT = time.Minute * 30 // 上传请求超时时间 + UPLOAD_RETRY_COUNT = 3 + UPLOAD_RETRY_WAIT_TIME = time.Second * 1 + UPLOAD_RETRY_MAX_WAIT_TIME = time.Second * 5 +) + var config = driver.Config{ Name: "BaiduNetdisk", DefaultRoot: "/", diff --git a/drivers/baidu_netdisk/types.go b/drivers/baidu_netdisk/types.go index ed9b09df..a158956d 100644 --- a/drivers/baidu_netdisk/types.go +++ b/drivers/baidu_netdisk/types.go @@ -1,6 +1,7 @@ package baidu_netdisk import ( + "errors" "path" "strconv" "time" @@ -9,6 +10,10 @@ import ( "github.com/alist-org/alist/v3/pkg/utils" ) +var ( + ErrBaiduEmptyFilesNotAllowed = errors.New("empty files are not allowed by baidu netdisk") +) + type TokenErrResp struct { ErrorDescription string `json:"error_description"` Error string `json:"error"` @@ -189,3 +194,27 @@ type PrecreateResp struct { // return_type=2 File File `json:"info"` } + +type UploadServerResp struct { + BakServer []any `json:"bak_server"` + BakServers []struct { + Server string `json:"server"` + } `json:"bak_servers"` + ClientIP string `json:"client_ip"` + ErrorCode int `json:"error_code"` + ErrorMsg string `json:"error_msg"` + Expire int `json:"expire"` + Host string `json:"host"` + Newno string `json:"newno"` + QuicServer []any `json:"quic_server"` + QuicServers []struct { + Server string `json:"server"` + } `json:"quic_servers"` + RequestID int64 `json:"request_id"` + Server []any `json:"server"` + ServerTime int `json:"server_time"` + Servers []struct { + Server string `json:"server"` + } `json:"servers"` + Sl int `json:"sl"` +} diff --git a/drivers/baidu_netdisk/util.go b/drivers/baidu_netdisk/util.go index 1249b3f4..c5a73343 100644 --- a/drivers/baidu_netdisk/util.go +++ b/drivers/baidu_netdisk/util.go @@ -73,7 +73,7 @@ func (d *BaiduNetdisk) request(furl string, method string, callback base.ReqCall errno := utils.Json.Get(res.Body(), "errno").ToInt() if errno != 0 { if utils.SliceContains([]int{111, -6}, errno) { - log.Info("refreshing baidu_netdisk token.") + log.Info("[baidu_netdisk] refreshing baidu_netdisk token.") err2 := d.refreshToken() if err2 != nil { return retry.Unrecoverable(err2) @@ -284,10 +284,10 @@ func (d *BaiduNetdisk) getSliceSize(filesize int64) int64 { // 非会员固定为 4MB if d.vipType == 0 { if d.CustomUploadPartSize != 0 { - log.Warnf("CustomUploadPartSize is not supported for non-vip user, use DefaultSliceSize") + log.Warnf("[baidu_netdisk] CustomUploadPartSize is not supported for non-vip user, use DefaultSliceSize") } if filesize > MaxSliceNum*DefaultSliceSize { - log.Warnf("File size(%d) is too large, may cause upload failure", filesize) + log.Warnf("[baidu_netdisk] File size(%d) is too large, may cause upload failure", filesize) } return DefaultSliceSize @@ -295,17 +295,17 @@ func (d *BaiduNetdisk) getSliceSize(filesize int64) int64 { if d.CustomUploadPartSize != 0 { if d.CustomUploadPartSize < DefaultSliceSize { - log.Warnf("CustomUploadPartSize(%d) is less than DefaultSliceSize(%d), use DefaultSliceSize", d.CustomUploadPartSize, DefaultSliceSize) + log.Warnf("[baidu_netdisk] CustomUploadPartSize(%d) is less than DefaultSliceSize(%d), use DefaultSliceSize", d.CustomUploadPartSize, DefaultSliceSize) return DefaultSliceSize } if d.vipType == 1 && d.CustomUploadPartSize > VipSliceSize { - log.Warnf("CustomUploadPartSize(%d) is greater than VipSliceSize(%d), use VipSliceSize", d.CustomUploadPartSize, VipSliceSize) + log.Warnf("[baidu_netdisk] CustomUploadPartSize(%d) is greater than VipSliceSize(%d), use VipSliceSize", d.CustomUploadPartSize, VipSliceSize) return VipSliceSize } if d.vipType == 2 && d.CustomUploadPartSize > SVipSliceSize { - log.Warnf("CustomUploadPartSize(%d) is greater than SVipSliceSize(%d), use SVipSliceSize", d.CustomUploadPartSize, SVipSliceSize) + log.Warnf("[baidu_netdisk] CustomUploadPartSize(%d) is greater than SVipSliceSize(%d), use SVipSliceSize", d.CustomUploadPartSize, SVipSliceSize) return SVipSliceSize } @@ -335,12 +335,89 @@ func (d *BaiduNetdisk) getSliceSize(filesize int64) int64 { } if filesize > MaxSliceNum*maxSliceSize { - log.Warnf("File size(%d) is too large, may cause upload failure", filesize) + log.Warnf("[baidu_netdisk] File size(%d) is too large, may cause upload failure", filesize) } return maxSliceSize } +// getUploadUrl 从开放平台获取上传域名/地址,并发请求会被合并,结果会被缓存1h。 +// 如果获取失败,则返回 Upload API设置项。 +func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string { + if !d.UseDynamicUploadAPI { + return d.UploadAPI + } + getCachedUrlFunc := func() string { + d.uploadUrlMu.RLock() + defer d.uploadUrlMu.RUnlock() + if d.uploadUrl != "" && time.Since(d.uploadUrlUpdateTime) < UPLOAD_URL_EXPIRE_TIME { + return d.uploadUrl + } + return "" + } + // 检查地址缓存 + if uploadUrl := getCachedUrlFunc(); uploadUrl != "" { + return uploadUrl + } + + uploadUrlGetFunc := func() (string, error) { + // 双重检查缓存 + if uploadUrl := getCachedUrlFunc(); uploadUrl != "" { + return uploadUrl, nil + } + + uploadUrl, err := d.requestForUploadUrl(path, uploadId) + if err != nil { + return "", err + } + + d.uploadUrlMu.Lock() + defer d.uploadUrlMu.Unlock() + d.uploadUrl = uploadUrl + d.uploadUrlUpdateTime = time.Now() + return uploadUrl, nil + } + + uploadUrl, err, _ := d.uploadUrlG.Do("", uploadUrlGetFunc) + if err != nil { + fallback := d.UploadAPI + log.Warnf("[baidu_netdisk] get upload URL failed (%v), will use fallback URL: %s", err, fallback) + return fallback + } + return uploadUrl +} + +// requestForUploadUrl 请求获取上传地址。 +// 实测此接口不需要认证,传method和upload_version就行,不过还是按文档规范调用。 +// https://pan.baidu.com/union/doc/Mlvw5hfnr +func (d *BaiduNetdisk) requestForUploadUrl(path, uploadId string) (string, error) { + params := map[string]string{ + "method": "locateupload", + "appid": "250528", + "path": path, + "uploadid": uploadId, + "upload_version": "2.0", + } + apiUrl := "https://d.pcs.baidu.com/rest/2.0/pcs/file" + var resp UploadServerResp + _, err := d.request(apiUrl, http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(params) + }, &resp) + if err != nil { + return "", err + } + var uploadUrl string + if len(resp.Servers) > 0 { + uploadUrl = resp.Servers[0].Server + } else if len(resp.BakServers) > 0 { + uploadUrl = resp.BakServers[0].Server + } + if uploadUrl == "" { + return "", errors.New("upload URL is empty") + } + return uploadUrl, nil +} + // func encodeURIComponent(str string) string { // r := url.QueryEscape(str) // r = strings.ReplaceAll(r, "+", "%20")