diff --git a/README.md b/README.md index d529dc15..73e33c52 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ zerobot -h -t token -u url [-d|w] [-g 监听地址:端口] qq1 qq2 qq3 ... - [x] 删除[涩图/二次元/风景/车万][P站图片ID] - [x] > setu status - **本地涩图** `import _ "github.com/FloatTech/ZeroBot-Plugin/plugin_nativesetu"` - - [x] 来份本地[xxx] + - [x] 本地[xxx] - [x] 刷新本地[xxx] - [x] 设置本地setu绝对路径[xxx] - [x] 刷新所有本地setu diff --git a/go.mod b/go.mod index 0e0bce6d..665d3e44 100644 --- a/go.mod +++ b/go.mod @@ -23,4 +23,5 @@ require ( github.com/t-tomalak/logrus-easy-formatter v0.0.0-20190827215021-c074f06c5816 github.com/tidwall/gjson v1.11.0 github.com/wdvxdr1123/ZeroBot v1.4.1 + golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d ) diff --git a/plugin_nativesetu/data.go b/plugin_nativesetu/data.go index 5ae43e5a..5d6cd932 100644 --- a/plugin_nativesetu/data.go +++ b/plugin_nativesetu/data.go @@ -1,7 +1,9 @@ package nativesetu import ( + "bytes" "image" + "io" "io/fs" "os" "sync" @@ -9,6 +11,7 @@ import ( "github.com/corona10/goimagehash" "github.com/sirupsen/logrus" "github.com/wdvxdr1123/ZeroBot/utils/helper" + "golang.org/x/image/webp" "github.com/FloatTech/ZeroBot-Plugin/utils/file" "github.com/FloatTech/ZeroBot-Plugin/utils/process" @@ -17,8 +20,9 @@ import ( // setuclass holds setus in a folder, which is the class name. type setuclass struct { - ImgID uint64 `db:"imgid"` // ImgID 图片唯一 id (dhash) + ImgID int64 `db:"imgid"` // ImgID 图片唯一 id (dhash) Name string `db:"name"` // Name 图片名 + Path string `db:"path"` // Path 图片路径 } var ( @@ -41,45 +45,71 @@ func init() { logrus.Println("[nsetu] set setu dir to", setupath) } } + if file.IsExist(dbfile) { + err := db.Open() + if err == nil { + setuclasses, err = db.ListTables() + } + if err != nil { + logrus.Errorln("[nsetu]", err) + } + } }() } func scanall(path string) error { - setuclasses = setuclasses[:0] + setuclasses = nil model := &setuclass{} root := os.DirFS(path) - return fs.WalkDir(root, "./", func(path string, d fs.DirEntry, err error) error { + _ = db.Close() + _ = os.Remove(dbfile) + return fs.WalkDir(root, ".", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } if d.IsDir() { clsn := d.Name() - mu.Lock() - err = db.Create(clsn, model) - setuclasses = append(setuclasses, clsn) - mu.Unlock() - if err == nil { - err = scanclass(root, clsn) - if err != nil { - return err + if clsn != "." { + mu.Lock() + err = db.Create(clsn, model) + setuclasses = append(setuclasses, clsn) + mu.Unlock() + if err == nil { + err = scanclass(root, path, clsn) + if err != nil { + logrus.Errorln("[nsetu]", err) + return err + } } } } - return err + return nil }) } -func scanclass(root fs.FS, clsn string) error { - return fs.WalkDir(root, clsn, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } +func scanclass(root fs.FS, path, clsn string) error { + ds, err := fs.ReadDir(root, path) + if err != nil { + return err + } + mu.Lock() + _ = db.Truncate(clsn) + mu.Unlock() + for _, d := range ds { if !d.IsDir() { - f, e := os.Open(path) + relpath := path + "/" + d.Name() + fullpath := setupath + "/" + relpath + logrus.Debugln("[nsetu] read", fullpath) + f, e := os.ReadFile(fullpath) if e != nil { return e } - img, _, e := image.Decode(f) + b := bytes.NewReader(f) + img, _, e := image.Decode(b) + if e != nil { + b.Seek(0, io.SeekStart) + img, e = webp.Decode(b) + } if e != nil { return e } @@ -87,10 +117,15 @@ func scanclass(root fs.FS, clsn string) error { if e != nil { return e } + dhi := int64(dh.GetHash()) + logrus.Debugln("[nsetu] insert", d.Name(), "with id", dhi, "into", clsn) mu.Lock() - err = db.Insert(clsn, &setuclass{ImgID: dh.GetHash(), Name: d.Name()}) + err = db.Insert(clsn, &setuclass{ImgID: dhi, Name: d.Name(), Path: relpath}) mu.Unlock() + if err != nil { + return err + } } - return err - }) + } + return nil } diff --git a/plugin_nativesetu/main.go b/plugin_nativesetu/main.go index 04a10b8b..63a95f01 100644 --- a/plugin_nativesetu/main.go +++ b/plugin_nativesetu/main.go @@ -6,6 +6,7 @@ import ( "github.com/FloatTech/ZeroBot-Plugin/control" "github.com/FloatTech/ZeroBot-Plugin/utils/rule" + "github.com/sirupsen/logrus" zero "github.com/wdvxdr1123/ZeroBot" "github.com/wdvxdr1123/ZeroBot/message" "github.com/wdvxdr1123/ZeroBot/utils/helper" @@ -25,13 +26,13 @@ func init() { engine := control.Register("nativesetu", &control.Options{ DisableOnDefault: false, Help: "本地涩图\n" + - "- 来份本地[xxx]\n" + + "- 本地[xxx]\n" + "- 刷新本地[xxx]\n" + "- 设置本地setu绝对路径[xxx]\n" + "- 刷新所有本地setu\n" + "- 所有本地setu分类", }) - engine.OnRegex(`^来份本地(.*)$`, rule.FirstValueInList(setuclasses)).SetBlock(true).SetPriority(20). + engine.OnRegex(`^本地(.*)$`, func(ctx *zero.Ctx) bool { return rule.FirstValueInList(setuclasses)(ctx) }).SetBlock(true).SetPriority(36). Handle(func(ctx *zero.Ctx) { imgtype := ctx.State["regex_matched"].([]string)[1] sc := new(setuclass) @@ -41,36 +42,40 @@ func init() { if err != nil { ctx.SendChain(message.Text("ERROR: ", err)) } else { - p := "file:///" + setupath + "/" + imgtype + "/" + sc.Name + p := "file:///" + setupath + "/" + sc.Path ctx.SendChain(message.Text(imgtype, ": ", sc.Name, "\n"), message.Image(p)) } }) - engine.OnRegex(`^刷新本地(.*)$`, rule.FirstValueInList(setuclasses), zero.SuperUserPermission).SetBlock(true).SetPriority(20). + engine.OnRegex(`^刷新本地(.*)$`, func(ctx *zero.Ctx) bool { return rule.FirstValueInList(setuclasses)(ctx) }, zero.SuperUserPermission).SetBlock(true).SetPriority(36). Handle(func(ctx *zero.Ctx) { imgtype := ctx.State["regex_matched"].([]string)[1] - err := scanclass(os.DirFS(setupath), imgtype) + err := scanclass(os.DirFS(setupath), imgtype, imgtype) if err == nil { ctx.SendChain(message.Text("成功!")) } else { ctx.SendChain(message.Text("ERROR: ", err)) } }) - engine.OnRegex(`^设置本地setu绝对路径(.*)$`, zero.SuperUserPermission).SetBlock(true).SetPriority(20). + engine.OnRegex(`^设置本地setu绝对路径(.*)$`, zero.SuperUserPermission).SetBlock(true).SetPriority(36). Handle(func(ctx *zero.Ctx) { setupath = ctx.State["regex_matched"].([]string)[1] err := os.WriteFile(cfgfile, helper.StringToBytes(setupath), 0644) - if err != nil { + if err == nil { + ctx.SendChain(message.Text("成功!")) + } else { ctx.SendChain(message.Text("ERROR: ", err)) } }) - engine.OnFullMatch("刷新所有本地setu", zero.SuperUserPermission).SetBlock(true).SetPriority(20). + engine.OnFullMatch("刷新所有本地setu", zero.SuperUserPermission).SetBlock(true).SetPriority(36). Handle(func(ctx *zero.Ctx) { err := scanall(setupath) - if err != nil { + if err == nil { + ctx.SendChain(message.Text("成功!")) + } else { ctx.SendChain(message.Text("ERROR: ", err)) } }) - engine.OnFullMatch("所有本地setu分类").SetBlock(true).SetPriority(20). + engine.OnFullMatch("所有本地setu分类").SetBlock(true).SetPriority(36). Handle(func(ctx *zero.Ctx) { msg := "所有本地setu分类" mu.RLock() @@ -80,8 +85,10 @@ func init() { msg += fmt.Sprintf("\n%02d. %s(%d)", i, c, n) } else { msg += fmt.Sprintf("\n%02d. %s(error)", i, c) + logrus.Errorln("[nsetu]", err) } } mu.RUnlock() + ctx.SendChain(message.Text(msg)) }) } diff --git a/plugin_qingyunke/qingyunke.go b/plugin_qingyunke/qingyunke.go index a438bcef..7d96d3d9 100644 --- a/plugin_qingyunke/qingyunke.go +++ b/plugin_qingyunke/qingyunke.go @@ -21,7 +21,7 @@ import ( ) var ( - prio = 100 + prio = 256 bucket = rate.NewManager(time.Minute, 20) // 青云客接口回复 engine *zero.Engine ) diff --git a/utils/rule/extension.go b/utils/rule/extension.go index 5580d649..20958dd2 100644 --- a/utils/rule/extension.go +++ b/utils/rule/extension.go @@ -6,8 +6,8 @@ import zero "github.com/wdvxdr1123/ZeroBot" func FirstValueInList(list []string) zero.Rule { return func(ctx *zero.Ctx) bool { first := ctx.State["regex_matched"].([]string)[1] - for i := range list { - if first == list[i] { + for _, v := range list { + if first == v { return true } } diff --git a/utils/sql/sqlite.go b/utils/sql/sqlite.go index ab44e059..c237d309 100644 --- a/utils/sql/sqlite.go +++ b/utils/sql/sqlite.go @@ -16,6 +16,27 @@ type Sqlite struct { DBPath string } +// Open 打开数据库 +func (db *Sqlite) Open() (err error) { + if db.DB == nil { + database, err := sql.Open("sqlite3", db.DBPath) + if err != nil { + return err + } + db.DB = database + } + return +} + +// Close 关闭数据库 +func (db *Sqlite) Close() (err error) { + if db.DB != nil { + err = db.DB.Close() + db.DB = nil + } + return +} + // Create 生成数据库 // 默认结构体的第一个元素为主键 // 返回错误 @@ -57,7 +78,7 @@ func (db *Sqlite) Create(table string, objptr interface{}) (err error) { // 默认结构体的第一个元素为主键 // 返回错误 func (db *Sqlite) Insert(table string, objptr interface{}) error { - rows, err := db.DB.Query("SELECT * FROM " + table + ";") + rows, err := db.DB.Query("SELECT * FROM " + table + " limit 1;") if err != nil { return err } @@ -67,9 +88,9 @@ func (db *Sqlite) Insert(table string, objptr interface{}) error { tags, _ := rows.Columns() rows.Close() var ( - values = values(objptr) - top = len(tags) - 1 - cmd = []string{} + vals = values(objptr) + top = len(tags) - 1 + cmd = []string{} ) cmd = append(cmd, "REPLACE INTO") cmd = append(cmd, table) @@ -105,7 +126,7 @@ func (db *Sqlite) Insert(table string, objptr interface{}) error { if err != nil { return err } - _, err = stmt.Exec(values...) + _, err = stmt.Exec(vals...) if err != nil { return err } @@ -173,7 +194,7 @@ func (db *Sqlite) ListTables() (s []string, err error) { return } -// Del 删除数据库 +// Del 删除数据库表项 // condition 可为"WHERE id = 0" // 返回错误 func (db *Sqlite) Del(table string, condition string) error { @@ -192,11 +213,27 @@ func (db *Sqlite) Del(table string, condition string) error { return stmt.Close() } +// Truncate 清空数据库表 +func (db *Sqlite) Truncate(table string) error { + var cmd = []string{} + cmd = append(cmd, "TRUNCATE TABLE") + cmd = append(cmd, table) + stmt, err := db.DB.Prepare(strings.Join(cmd, " ") + ";") + if err != nil { + return err + } + _, err = stmt.Exec() + if err != nil { + return err + } + return stmt.Close() +} + // Count 查询数据库行数 // 返回行数以及错误 func (db *Sqlite) Count(table string) (num int, err error) { var cmd = []string{} - cmd = append(cmd, "SELECT * FROM") + cmd = append(cmd, "SELECT COUNT(1) FROM") cmd = append(cmd, table) rows, err := db.DB.Query(strings.Join(cmd, " ") + ";") if err != nil { @@ -205,8 +242,8 @@ func (db *Sqlite) Count(table string) (num int, err error) { if rows.Err() != nil { return num, rows.Err() } - for rows.Next() { - num++ + if rows.Next() { + rows.Scan(&num) } rows.Close() return num, nil @@ -236,8 +273,22 @@ func kinds(objptr interface{}) []string { } for i, flen := 0, elem.Type().NumField(); i < flen; i++ { switch elem.Field(i).Type().String() { - case "int64": + case "int8": + kinds = append(kinds, "TINYINT") + case "uint8", "byte": + kinds = append(kinds, "UNSIGNED TINYINT") + case "int16": + kinds = append(kinds, "SMALLINT") + case "uint16": + kinds = append(kinds, "UNSIGNED SMALLINT") + case "int32": kinds = append(kinds, "INT") + case "uint32": + kinds = append(kinds, "UNSIGNED INT") + case "int64": + kinds = append(kinds, "BIGINT") + case "uint64": + kinds = append(kinds, "UNSIGNED BIGINT") case "string": kinds = append(kinds, "TEXT") default: @@ -256,14 +307,7 @@ func values(objptr interface{}) []interface{} { elem = elem.Field(0) } for i, flen := 0, elem.Type().NumField(); i < flen; i++ { - switch elem.Field(i).Type().String() { - case "int64": - values = append(values, elem.Field(i).Int()) - case "string": - values = append(values, elem.Field(i).String()) - default: - values = append(values, elem.Field(i).String()) - } + values = append(values, elem.Field(i).Interface()) } return values }