From 9f7a2a36c1f8990bd2ae8d5bd4b354f6e5888cc8 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Tue, 20 May 2025 01:57:57 +0800 Subject: [PATCH] chore: unpack externalUI in a separate temporary directory to avoid malicious compressed packages from polluting workdir --- component/updater/update_ui.go | 144 ++++++++++++++------------------- 1 file changed, 62 insertions(+), 82 deletions(-) diff --git a/component/updater/update_ui.go b/component/updater/update_ui.go index 4d1e98ed..3cb60381 100644 --- a/component/updater/update_ui.go +++ b/component/updater/update_ui.go @@ -3,6 +3,7 @@ package updater import ( "archive/tar" "archive/zip" + "bytes" "compress/gzip" "fmt" "io" @@ -32,6 +33,17 @@ const ( typeTarGzip ) +func (t compressionType) String() string { + switch t { + case typeZip: + return "zip" + case typeTarGzip: + return "tar.gz" + default: + return "unknown" + } +} + var DefaultUiUpdater = &UIUpdater{} func NewUiUpdater(externalUI, externalUIURL, externalUIName string) *UIUpdater { @@ -99,48 +111,35 @@ func detectFileType(data []byte) compressionType { } func (u *UIUpdater) downloadUI() error { - err := u.prepareUIPath() - if err != nil { - return fmt.Errorf("prepare UI path failed: %w", err) - } - data, err := downloadForBytes(u.externalUIURL) if err != nil { return fmt.Errorf("can't download file: %w", err) } - fileType := detectFileType(data) - if fileType == typeUnknown { - return fmt.Errorf("unknown or unsupported file type") + tmpDir := C.Path.Resolve("downloadUI.tmp") + defer os.RemoveAll(tmpDir) + extractedFolder, err := extract(data, tmpDir) + if err != nil { + return fmt.Errorf("can't extract compressed file: %w", err) } - ext := ".zip" - if fileType == typeTarGzip { - ext = ".tgz" - } - - saved := path.Join(C.Path.HomeDir(), "download"+ext) - log.Debugln("compression Type: %s", ext) - if err = saveFile(data, saved); err != nil { - return fmt.Errorf("can't save compressed file: %w", err) - } - defer os.Remove(saved) - - err = cleanup(u.externalUIPath) + log.Debugln("cleanupFolder: %s", u.externalUIPath) + err = cleanup(u.externalUIPath) // cleanup files in dir don't remove dir itself if err != nil { if !os.IsNotExist(err) { return fmt.Errorf("cleanup exist file error: %w", err) } } - extractedFolder, err := extract(saved, C.Path.HomeDir()) + err = u.prepareUIPath() if err != nil { - return fmt.Errorf("can't extract compressed file: %w", err) + return fmt.Errorf("prepare UI path failed: %w", err) } - err = os.Rename(extractedFolder, u.externalUIPath) + log.Debugln("moveFolder from %s to %s", extractedFolder, u.externalUIPath) + err = moveDir(extractedFolder, u.externalUIPath) // move files from tmp to target if err != nil { - return fmt.Errorf("rename UI folder failed: %w", err) + return fmt.Errorf("move UI folder failed: %w", err) } return nil } @@ -155,12 +154,11 @@ func (u *UIUpdater) prepareUIPath() error { return nil } -func unzip(src, dest string) (string, error) { - r, err := zip.OpenReader(src) +func unzip(data []byte, dest string) (string, error) { + r, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) if err != nil { return "", err } - defer r.Close() // check whether or not only exists singleRoot dir rootDir := "" @@ -199,17 +197,7 @@ func unzip(src, dest string) (string, error) { log.Debugln("extractedFolder: %s", extractedFolder) } else { log.Debugln("Match the multiRoot") - // or put the files/dirs into new dir - baseName := filepath.Base(src) - baseName = strings.TrimSuffix(baseName, filepath.Ext(baseName)) - extractedFolder = filepath.Join(dest, baseName) - - for i := 1; ; i++ { - if _, err := os.Stat(extractedFolder); os.IsNotExist(err) { - break - } - extractedFolder = filepath.Join(dest, fmt.Sprintf("%s_%d", baseName, i)) - } + extractedFolder = dest log.Debugln("extractedFolder: %s", extractedFolder) } @@ -253,14 +241,8 @@ func unzip(src, dest string) (string, error) { return extractedFolder, nil } -func untgz(src, dest string) (string, error) { - file, err := os.Open(src) - if err != nil { - return "", err - } - defer file.Close() - - gzr, err := gzip.NewReader(file) +func untgz(data []byte, dest string) (string, error) { + gzr, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { return "", err } @@ -303,8 +285,7 @@ func untgz(src, dest string) (string, error) { isSingleRoot = false } - file.Seek(0, 0) - gzr, _ = gzip.NewReader(file) + _ = gzr.Reset(bytes.NewReader(data)) tr = tar.NewReader(gzr) var extractedFolder string @@ -314,17 +295,7 @@ func untgz(src, dest string) (string, error) { log.Debugln("extractedFolder: %s", extractedFolder) } else { log.Debugln("Match the multiRoot") - baseName := filepath.Base(src) - baseName = strings.TrimSuffix(baseName, filepath.Ext(baseName)) - baseName = strings.TrimSuffix(baseName, ".tar") - extractedFolder = filepath.Join(dest, baseName) - - for i := 1; ; i++ { - if _, err := os.Stat(extractedFolder); os.IsNotExist(err) { - break - } - extractedFolder = filepath.Join(dest, fmt.Sprintf("%s_%d", baseName, i)) - } + extractedFolder = dest log.Debugln("extractedFolder: %s", extractedFolder) } @@ -371,16 +342,16 @@ func untgz(src, dest string) (string, error) { return extractedFolder, nil } -func extract(src, dest string) (string, error) { - srcLower := strings.ToLower(src) - switch { - case strings.HasSuffix(srcLower, ".tar.gz") || - strings.HasSuffix(srcLower, ".tgz"): - return untgz(src, dest) - case strings.HasSuffix(srcLower, ".zip"): - return unzip(src, dest) +func extract(data []byte, dest string) (string, error) { + fileType := detectFileType(data) + log.Debugln("compression Type: %s", fileType) + switch fileType { + case typeZip: + return unzip(data, dest) + case typeTarGzip: + return untgz(data, dest) default: - return "", fmt.Errorf("unsupported file format: %s", src) + return "", fmt.Errorf("unknown or unsupported file type") } } @@ -402,24 +373,33 @@ func cleanTarPath(path string) string { } func cleanup(root string) error { - if _, err := os.Stat(root); os.IsNotExist(err) { - return nil + dirEntryList, err := os.ReadDir(root) + if err != nil { + return err } - return filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + + for _, dirEntry := range dirEntryList { + err = os.RemoveAll(filepath.Join(root, dirEntry.Name())) if err != nil { return err } - if info.IsDir() { - if err := os.RemoveAll(path); err != nil { - return err - } - } else { - if err := os.Remove(path); err != nil { - return err - } + } + return nil +} + +func moveDir(src string, dst string) error { + dirEntryList, err := os.ReadDir(src) + if err != nil { + return err + } + + for _, dirEntry := range dirEntryList { + err = os.Rename(filepath.Join(src, dirEntry.Name()), filepath.Join(dst, dirEntry.Name())) + if err != nil { + return err } - return nil - }) + } + return nil } func inDest(fpath, dest string) bool {