From 22e23efbdcd12bffac50e3503d379a9d447fb001 Mon Sep 17 00:00:00 2001 From: fumiama Date: Tue, 26 Oct 2021 00:48:13 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=8F=EF=B8=8F=20=20=E4=BF=AE=E5=A4=8D=20sq?= =?UTF-8?q?l=20=E6=9F=A5=E8=AF=A2=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- control/rule.go | 35 ++++++++++++++++++----------------- utils/sql/sqlite.go | 28 +++++++++++++--------------- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/control/rule.go b/control/rule.go index 397ca2b8..591a2bc4 100644 --- a/control/rule.go +++ b/control/rule.go @@ -89,11 +89,13 @@ func (m *Control) Reset(groupID int64) { func (m *Control) IsEnabledIn(gid int64) bool { var c grpcfg var err error + logrus.Debugln("[control] IsEnabledIn recv gid =", gid) if gid != 0 { m.RLock() err = db.Find(m.service, &c, "WHERE gid = "+strconv.FormatInt(gid, 10)) m.RUnlock() - if err == nil { + logrus.Debugln("[control] db find gid =", c.GroupID) + if err == nil && gid == c.GroupID { logrus.Debugf("[control] plugin %s of grp %d : %d", m.service, c.GroupID, c.Disable) return c.Disable == 0 } @@ -102,7 +104,7 @@ func (m *Control) IsEnabledIn(gid int64) bool { err = db.Find(m.service, &c, "WHERE gid = 0") m.RUnlock() if err == nil { - logrus.Debugf("[control] plugin %s of all : %d", m.service, c.GroupID, c.Disable) + logrus.Debugf("[control] plugin %s of all : %d", m.service, c.Disable) return c.Disable == 0 } return !m.options.DisableOnDefault @@ -117,6 +119,7 @@ func (m *Control) Handler() zero.Rule { // 个人用户 grp = -ctx.Event.UserID } + logrus.Debugln("[control] handler get gid =", grp) return m.IsEnabledIn(grp) } } @@ -150,6 +153,13 @@ func copyMap(m map[string]*Control) map[string]*Control { return ret } +func userOrGrpAdmin(ctx *zero.Ctx) bool { + if zero.OnlyGroup(ctx) { + return zero.AdminPermission(ctx) + } + return zero.OnlyToMe(ctx) +} + func init() { if !hasinit { mu.Lock() @@ -162,12 +172,7 @@ func init() { zero.OnCommandGroup([]string{ "启用", "enable", "禁用", "disable", "全局启用", "enableall", "全局禁用", "disableall", - }, func(ctx *zero.Ctx) bool { - if zero.OnlyGroup(ctx) { - return zero.AdminPermission(ctx) - } - return zero.OnlyToMe(ctx) - }).Handle(func(ctx *zero.Ctx) { + }, userOrGrpAdmin).Handle(func(ctx *zero.Ctx) { model := extension.CommandModel{} _ = ctx.Parse(&model) service, ok := Lookup(model.Args) @@ -191,12 +196,7 @@ func init() { } }) - zero.OnCommandGroup([]string{"还原", "reset"}, func(ctx *zero.Ctx) bool { - if zero.OnlyGroup(ctx) { - return zero.AdminPermission(ctx) - } - return zero.OnlyToMe(ctx) - }).Handle(func(ctx *zero.Ctx) { + zero.OnCommandGroup([]string{"还原", "reset"}, userOrGrpAdmin).Handle(func(ctx *zero.Ctx) { model := extension.CommandModel{} _ = ctx.Parse(&model) service, ok := Lookup(model.Args) @@ -212,7 +212,7 @@ func init() { ctx.SendChain(message.Text("已还原服务的默认启用状态: " + model.Args)) }) - zero.OnCommandGroup([]string{"用法", "usage"}, zero.AdminPermission, zero.OnlyGroup). + zero.OnCommandGroup([]string{"用法", "usage"}, userOrGrpAdmin). Handle(func(ctx *zero.Ctx) { model := extension.CommandModel{} _ = ctx.Parse(&model) @@ -227,14 +227,15 @@ func init() { } }) - zero.OnCommandGroup([]string{"服务列表", "service_list"}, zero.AdminPermission, zero.OnlyGroup). + zero.OnCommandGroup([]string{"服务列表", "service_list"}, userOrGrpAdmin). Handle(func(ctx *zero.Ctx) { msg := `---服务列表---` i := 0 + gid := ctx.Event.GroupID ForEach(func(key string, manager *Control) bool { i++ msg += "\n" + strconv.Itoa(i) + `: ` - if manager.IsEnabledIn(ctx.Event.GroupID) { + if manager.IsEnabledIn(gid) { msg += "●" + key } else { msg += "○" + key diff --git a/utils/sql/sqlite.go b/utils/sql/sqlite.go index a8fcd2b3..9efd4d6a 100644 --- a/utils/sql/sqlite.go +++ b/utils/sql/sqlite.go @@ -48,17 +48,15 @@ func (db *Sqlite) Create(table string, objptr interface{}) (err error) { cmd = append(cmd, "NULL);") } } - if _, err := db.DB.Exec(strings.Join(cmd, " ")); err != nil { - return err - } - return nil + _, err = db.DB.Exec(strings.Join(cmd, " ") + ";") + return } // Insert 插入数据集 // 默认结构体的第一个元素为主键 // 返回错误 -func (db *Sqlite) Insert(table string, objptr interface{}) (err error) { - rows, err := db.DB.Query("SELECT * FROM " + table) +func (db *Sqlite) Insert(table string, objptr interface{}) error { + rows, err := db.DB.Query("SELECT * FROM " + table + ";") if err != nil { return err } @@ -102,7 +100,7 @@ func (db *Sqlite) Insert(table string, objptr interface{}) (err error) { cmd = append(cmd, ")") } } - stmt, err := db.DB.Prepare(strings.Join(cmd, " ")) + stmt, err := db.DB.Prepare(strings.Join(cmd, " ") + ";") if err != nil { return err } @@ -110,19 +108,19 @@ func (db *Sqlite) Insert(table string, objptr interface{}) (err error) { if err != nil { return err } - return nil + return stmt.Close() } // Find 查询数据库 // condition 可为"WHERE id = 0" // 默认字段与结构体元素顺序一致 // 返回错误 -func (db *Sqlite) Find(table string, objptr interface{}, condition string) (err error) { +func (db *Sqlite) Find(table string, objptr interface{}, condition string) error { var cmd = []string{} cmd = append(cmd, "SELECT * FROM ") cmd = append(cmd, table) cmd = append(cmd, condition) - rows, err := db.DB.Query(strings.Join(cmd, " ")) + rows, err := db.DB.Query(strings.Join(cmd, " ") + ";") if err != nil { return err } @@ -171,12 +169,12 @@ func (db *Sqlite) ListTables() (s []string, err error) { // Del 删除数据库 // condition 可为"WHERE id = 0" // 返回错误 -func (db *Sqlite) Del(table string, condition string) (err error) { +func (db *Sqlite) Del(table string, condition string) error { var cmd = []string{} cmd = append(cmd, "DELETE FROM") cmd = append(cmd, table) cmd = append(cmd, condition) - stmt, err := db.DB.Prepare(strings.Join(cmd, " ")) + stmt, err := db.DB.Prepare(strings.Join(cmd, " ") + ";") if err != nil { return err } @@ -184,7 +182,7 @@ func (db *Sqlite) Del(table string, condition string) (err error) { if err != nil { return err } - return nil + return stmt.Close() } // Count 查询数据库行数 @@ -193,17 +191,17 @@ func (db *Sqlite) Count(table string) (num int, err error) { var cmd = []string{} cmd = append(cmd, "SELECT * FROM") cmd = append(cmd, table) - rows, err := db.DB.Query(strings.Join(cmd, " ")) + rows, err := db.DB.Query(strings.Join(cmd, " ") + ";") if err != nil { return num, err } if rows.Err() != nil { return num, rows.Err() } - defer rows.Close() for rows.Next() { num++ } + rows.Close() return num, nil }