diff --git a/plugin_b14/main.go b/plugin_b14/main.go index 0c6a1003..577cd293 100644 --- a/plugin_b14/main.go +++ b/plugin_b14/main.go @@ -1,3 +1,4 @@ +// Package b14coder base16384 与 tea 加解密 package b14coder import ( diff --git a/plugin_diana/data/migrate/text.go b/plugin_diana/data/migrate/text.go index 81da71bc..87aa3073 100644 --- a/plugin_diana/data/migrate/text.go +++ b/plugin_diana/data/migrate/text.go @@ -2,9 +2,9 @@ package main import ( "crypto/md5" + "encoding/binary" "fmt" "os" - "unsafe" "github.com/RomiChan/protobuf/proto" "github.com/wdvxdr1123/ZeroBot/utils/helper" @@ -37,7 +37,7 @@ func main() { } for _, d := range compo.Array { s := md5.Sum(helper.StringToBytes(d)) - i := *(*int64)(unsafe.Pointer(&s)) + i := int64(binary.LittleEndian.Uint64(s[:8])) fmt.Printf("[Diana]id: %d\n", i) err = db.Insert("text", &Text{ Id: i, diff --git a/plugin_diana/data/text.go b/plugin_diana/data/text.go index 4e1b377c..0a1f2428 100644 --- a/plugin_diana/data/text.go +++ b/plugin_diana/data/text.go @@ -3,8 +3,8 @@ package data import ( "crypto/md5" + "encoding/binary" "os" - "unsafe" log "github.com/sirupsen/logrus" "github.com/wdvxdr1123/ZeroBot/utils/helper" @@ -56,8 +56,8 @@ func LoadText() error { // AddText 添加小作文 func AddText(txt string) error { s := md5.Sum(helper.StringToBytes(txt)) - i := *(*int64)(unsafe.Pointer(&s)) - return db.Insert("text", &Text{Id: i, Data: txt}) + i := binary.LittleEndian.Uint64(s[:8]) + return db.Insert("text", &Text{Id: int64(i), Data: txt}) } // RandText 随机小作文 diff --git a/plugin_manager/manager.go b/plugin_manager/manager.go index 9359a90b..f9775ab4 100644 --- a/plugin_manager/manager.go +++ b/plugin_manager/manager.go @@ -259,9 +259,9 @@ func init() { // 插件主体 engine.OnRegex(`^在(.{1,2})月(.{1,3}日|每?周.?)的(.{1,3})点(.{1,3})分时(用.+)?提醒大家(.*)`, zero.AdminPermission, zero.OnlyGroup).SetBlock(true).SetPriority(40). Handle(func(ctx *zero.Ctx) { dateStrs := ctx.State["regex_matched"].([]string) - ts := timer.GetFilledTimer(dateStrs, ctx.Event.SelfID, false) + ts := timer.GetFilledTimer(dateStrs, ctx.Event.SelfID, ctx.Event.GroupID, false) if ts.En() { - go clock.RegisterTimer(ts, ctx.Event.GroupID, true) + go clock.RegisterTimer(ts, true) ctx.SendChain(message.Text("记住了~")) } else { ctx.SendChain(message.Text("参数非法:" + ts.Alert)) @@ -283,8 +283,8 @@ func init() { // 插件主体 return } logrus.Debugln("[manager] cron:", dateStrs[1]) - ts := timer.GetFilledCronTimer(dateStrs[1], alert, url, ctx.Event.SelfID) - if clock.RegisterTimer(ts, ctx.Event.GroupID, true) { + ts := timer.GetFilledCronTimer(dateStrs[1], alert, url, ctx.Event.SelfID, ctx.Event.GroupID) + if clock.RegisterTimer(ts, true) { ctx.SendChain(message.Text("记住了~")) } else { ctx.SendChain(message.Text("参数非法:" + ts.Alert)) @@ -294,8 +294,8 @@ func init() { // 插件主体 engine.OnRegex(`^取消在(.{1,2})月(.{1,3}日|每?周.?)的(.{1,3})点(.{1,3})分的提醒`, zero.AdminPermission, zero.OnlyGroup).SetBlock(true).SetPriority(40). Handle(func(ctx *zero.Ctx) { dateStrs := ctx.State["regex_matched"].([]string) - ts := timer.GetFilledTimer(dateStrs, ctx.Event.SelfID, true) - ti := ts.GetTimerInfo(ctx.Event.GroupID) + ts := timer.GetFilledTimer(dateStrs, ctx.Event.SelfID, ctx.Event.GroupID, true) + ti := ts.GetTimerID() ok := clock.CancelTimer(ti) if ok { ctx.SendChain(message.Text("取消成功~")) @@ -307,8 +307,8 @@ func init() { // 插件主体 engine.OnRegex(`^取消在"(.*)"的提醒`, zero.AdminPermission, zero.OnlyGroup).SetBlock(true).SetPriority(40). Handle(func(ctx *zero.Ctx) { dateStrs := ctx.State["regex_matched"].([]string) - ts := timer.Timer{Cron: dateStrs[1]} - ti := ts.GetTimerInfo(ctx.Event.GroupID) + ts := timer.Timer{Cron: dateStrs[1], GrpId: ctx.Event.GroupID} + ti := ts.GetTimerID() ok := clock.CancelTimer(ti) if ok { ctx.SendChain(message.Text("取消成功~")) diff --git a/plugin_manager/timer/parse.go b/plugin_manager/timer/parse.go index 83986817..7c24ef2b 100644 --- a/plugin_manager/timer/parse.go +++ b/plugin_manager/timer/parse.go @@ -1,6 +1,8 @@ package timer import ( + "crypto/md5" + "encoding/binary" "fmt" "strconv" "strings" @@ -8,28 +10,37 @@ import ( "unicode" "github.com/sirupsen/logrus" + "github.com/wdvxdr1123/ZeroBot/utils/helper" ) // GetTimerInfo 获得标准化定时字符串 -func (ts *Timer) GetTimerInfo(grp int64) string { +func (ts *Timer) GetTimerInfo() string { if ts.Cron != "" { - return fmt.Sprintf("[%d]%s", grp, ts.Cron) + return fmt.Sprintf("[%d]%s", ts.GrpId, ts.Cron) } - return fmt.Sprintf("[%d]%d月%d日%d周%d:%d", grp, ts.Month(), ts.Day(), ts.Week(), ts.Hour(), ts.Minute()) + return fmt.Sprintf("[%d]%d月%d日%d周%d:%d", ts.GrpId, ts.Month(), ts.Day(), ts.Week(), ts.Hour(), ts.Minute()) +} + +// GetTimerInfo 获得标准化 ID +func (ts *Timer) GetTimerID() uint32 { + key := ts.GetTimerInfo() + m := md5.Sum(helper.StringToBytes(key)) + return binary.LittleEndian.Uint32(m[:4]) } // GetFilledCronTimer 获得以cron填充好的ts -func GetFilledCronTimer(croncmd string, alert string, img string, botqq int64) *Timer { +func GetFilledCronTimer(croncmd string, alert string, img string, botqq, gid int64) *Timer { var ts Timer ts.Alert = alert ts.Cron = croncmd ts.Url = img ts.Selfid = botqq + ts.GrpId = gid return &ts } // GetFilledTimer 获得填充好的ts -func GetFilledTimer(dateStrs []string, botqq int64, matchDateOnly bool) *Timer { +func GetFilledTimer(dateStrs []string, botqq, grp int64, matchDateOnly bool) *Timer { monthStr := []rune(dateStrs[1]) dayWeekStr := []rune(dateStrs[2]) hourStr := []rune(dateStrs[3]) @@ -43,7 +54,8 @@ func GetFilledTimer(dateStrs []string, botqq int64, matchDateOnly bool) *Timer { } ts.SetMonth(mon) lenOfDW := len(dayWeekStr) - if lenOfDW == 4 { // 包括末尾的"日" + switch { + case lenOfDW == 4: // 包括末尾的"日" dayWeekStr = []rune{dayWeekStr[0], dayWeekStr[2]} // 去除中间的十 d := chineseNum2Int(dayWeekStr) if (d != -1 && d <= 0) || d > 31 { // 日期非法 @@ -51,7 +63,7 @@ func GetFilledTimer(dateStrs []string, botqq int64, matchDateOnly bool) *Timer { return &ts } ts.SetDay(d) - } else if dayWeekStr[lenOfDW-1] == rune('日') { // xx日 + case dayWeekStr[lenOfDW-1] == rune('日'): // xx日 dayWeekStr = dayWeekStr[:lenOfDW-1] d := chineseNum2Int(dayWeekStr) if (d != -1 && d <= 0) || d > 31 { // 日期非法 @@ -59,9 +71,9 @@ func GetFilledTimer(dateStrs []string, botqq int64, matchDateOnly bool) *Timer { return &ts } ts.SetDay(d) - } else if dayWeekStr[0] == rune('每') { // 每周 + case dayWeekStr[0] == rune('每'): // 每周 ts.SetWeek(-1) - } else { // 周x + default: // 周x w := chineseNum2Int(dayWeekStr[1:]) if w == 7 { // 周天是0 w = 0 @@ -105,6 +117,7 @@ func GetFilledTimer(dateStrs []string, botqq int64, matchDateOnly bool) *Timer { ts.SetEn(true) } ts.Selfid = botqq + ts.GrpId = grp return &ts } @@ -116,13 +129,14 @@ func chineseNum2Int(rs []rune) int { if unicode.IsDigit(rs[0]) { // 默认可能存在的第二位也为int r, _ = strconv.Atoi(string(rs)) } else { - if rs[0] == mai { + switch { + case rs[0] == mai: if l == 2 { r = -chineseChar2Int(rs[1]) } - } else if l == 1 { + case l == 1: r = chineseChar2Int(rs[0]) - } else { + default: ten := chineseChar2Int(rs[0]) if ten != 10 { ten *= 10 diff --git a/plugin_manager/timer/sleep.go b/plugin_manager/timer/sleep.go index ddcc51a0..73a11ba6 100644 --- a/plugin_manager/timer/sleep.go +++ b/plugin_manager/timer/sleep.go @@ -56,11 +56,12 @@ func (ts *Timer) nextWakeTime() (date time.Time) { } else { stable |= 0x8 } - if d < 0 { + switch { + case d < 0: d = date.Day() - } else if d > 0 { + case d > 0: stable |= 0x4 - } else { + default: d = date.Day() if w >= 0 { stable |= 0x2 @@ -148,14 +149,14 @@ func (ts *Timer) nextWakeTime() (date time.Time) { return date } -func (ts *Timer) judgeHM(grp int64) { +func (ts *Timer) judgeHM() { if ts.Hour() < 0 || ts.Hour() == time.Now().Hour() { if ts.Minute() < 0 || ts.Minute() == time.Now().Minute() { if ts.Selfid != 0 { - ts.sendmsg(grp, zero.GetBot(ts.Selfid)) + ts.sendmsg(ts.GrpId, zero.GetBot(ts.Selfid)) } else { zero.RangeBot(func(id int64, ctx *zero.Ctx) (_ bool) { - ts.sendmsg(grp, ctx) + ts.sendmsg(ts.GrpId, ctx) return }) } diff --git a/plugin_manager/timer/timer.db.go b/plugin_manager/timer/timer.db.go new file mode 100644 index 00000000..ea05e93f --- /dev/null +++ b/plugin_manager/timer/timer.db.go @@ -0,0 +1,26 @@ +package timer + +import ( + "strconv" + + "github.com/FloatTech/ZeroBot-Plugin/utils/sql" +) + +type Timer struct { + Id uint32 `db:"id"` + En1Month4Day5Week3Hour5Min6 int32 `db:"emdwhm"` + Selfid int64 `db:"sid"` + GrpId int64 `db:"gid"` + Alert string `db:"alert"` + Cron string `db:"cron"` + Url string `db:"url"` +} + +func (t *Timer) InsertInto(db *sql.Sqlite) error { + return db.Insert("timer", t) +} + +func getTimerFrom(db *sql.Sqlite, id uint32) (t Timer, err error) { + err = db.Find("timer", &t, "where id = "+strconv.Itoa(int(id))) + return +} diff --git a/plugin_manager/timer/timer.go b/plugin_manager/timer/timer.go index c8f658a3..9d6cf377 100644 --- a/plugin_manager/timer/timer.go +++ b/plugin_manager/timer/timer.go @@ -2,34 +2,28 @@ package timer import ( - "io" - "os" "strconv" "strings" "sync" "time" - "github.com/RomiChan/protobuf/proto" "github.com/fumiama/cron" "github.com/sirupsen/logrus" zero "github.com/wdvxdr1123/ZeroBot" "github.com/wdvxdr1123/ZeroBot/message" "github.com/FloatTech/ZeroBot-Plugin/utils/file" + "github.com/FloatTech/ZeroBot-Plugin/utils/sql" ) type Clock struct { - // 记录每个定时器以便取消 - timersmap TimersMap - // 定时器map - timers *(map[string]*Timer) + db *sql.Sqlite + timers *(map[uint32]*Timer) timersmu sync.RWMutex - // 定时器存储位置 - pbfile *string // cron 定时器 cron *cron.Cron // entries key <-> cron - entries map[string]cron.EntryID + entries map[uint32]cron.EntryID entmu sync.Mutex } @@ -43,26 +37,27 @@ var ( } ) -func NewClock(pbfile string) (c Clock) { - c.loadTimers(pbfile) - c.timers = &c.timersmap.Timers - c.pbfile = &pbfile +func NewClock(dbfile string) (c Clock) { + c.loadTimers(dbfile) c.cron = cron.New() - c.entries = make(map[string]cron.EntryID) + c.entries = make(map[uint32]cron.EntryID) c.cron.Start() return } // RegisterTimer 注册计时器 -func (c *Clock) RegisterTimer(ts *Timer, grp int64, save bool) bool { - key := ts.GetTimerInfo(grp) +func (c *Clock) RegisterTimer(ts *Timer, save bool) bool { + var key uint32 + if save { + key = ts.GetTimerID() + ts.Id = key + } else { + key = ts.Id + } t, ok := c.GetTimer(key) if t != ts && ok { // 避免重复注册定时器 t.SetEn(false) } - c.timersmu.Lock() - (*c.timers)[key] = ts - c.timersmu.Unlock() logrus.Println("[群管]注册计时器", key) if ts.Cron != "" { var ctx *zero.Ctx @@ -75,33 +70,33 @@ func (c *Clock) RegisterTimer(ts *Timer, grp int64, save bool) bool { return false }) } - eid, err := c.cron.AddFunc(ts.Cron, func() { ts.sendmsg(grp, ctx) }) + eid, err := c.cron.AddFunc(ts.Cron, func() { ts.sendmsg(ts.GrpId, ctx) }) if err == nil { c.entmu.Lock() c.entries[key] = eid c.entmu.Unlock() if save { - c.SaveTimers() + err = c.AddTimer(ts) } - return true + return err == nil } ts.Alert = err.Error() } else { if save { - c.SaveTimers() + _ = c.AddTimer(ts) } for ts.En() { nextdate := ts.nextWakeTime() sleepsec := time.Until(nextdate) - logrus.Printf("[群管]计时器%s将睡眠%ds", key, sleepsec/time.Second) + logrus.Printf("[群管]计时器%08x将睡眠%ds", key, sleepsec/time.Second) time.Sleep(sleepsec) if ts.En() { if ts.Month() < 0 || ts.Month() == time.Now().Month() { if ts.Day() < 0 || ts.Day() == time.Now().Day() { - ts.judgeHM(grp) + ts.judgeHM() } else if ts.Day() == 0 { if ts.Week() < 0 || ts.Week() == time.Now().Weekday() { - ts.judgeHM(grp) + ts.judgeHM() } } } @@ -112,8 +107,8 @@ func (c *Clock) RegisterTimer(ts *Timer, grp int64, save bool) bool { } // CancelTimer 取消计时器 -func (c *Clock) CancelTimer(key string) bool { - t, ok := (*c.timers)[key] +func (c *Clock) CancelTimer(key uint32) bool { + t, ok := c.GetTimer(key) if ok { if t.Cron != "" { c.entmu.Lock() @@ -126,41 +121,22 @@ func (c *Clock) CancelTimer(key string) bool { } c.timersmu.Lock() delete(*c.timers, key) // 避免重复取消 + e := c.db.Del("timer", "where id = "+strconv.Itoa(int(key))) c.timersmu.Unlock() - _ = c.SaveTimers() + return e == nil } - return ok -} - -// SaveTimers 保存当前计时器 -func (c *Clock) SaveTimers() error { - c.timersmu.RLock() - data, err := proto.Marshal(&c.timersmap) - c.timersmu.RUnlock() - if err == nil { - c.timersmu.Lock() - defer c.timersmu.Unlock() - f, err1 := os.OpenFile(*c.pbfile, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644) - if err1 != nil { - return err1 - } else { - _, err2 := f.Write(data) - f.Close() - return err2 - } - } - return err + return false } // ListTimers 列出本群所有计时器 func (c *Clock) ListTimers(grpID int64) []string { // 数组默认长度为map长度,后面append时,不需要重新申请内存和拷贝,效率很高 if c.timers != nil { - g := strconv.FormatInt(grpID, 10) c.timersmu.RLock() keys := make([]string, 0, len(*c.timers)) - for k := range *c.timers { - if strings.Contains(k, g) { + for _, v := range *c.timers { + if v.GrpId == grpID { + k := v.GetTimerInfo() start := strings.Index(k, "]") msg := strings.ReplaceAll(k[start+1:]+"\n", "-1", "每") msg = strings.ReplaceAll(msg, "月0日0周", "月周天") @@ -176,35 +152,32 @@ func (c *Clock) ListTimers(grpID int64) []string { } } -func (c *Clock) GetTimer(key string) (t *Timer, ok bool) { +func (c *Clock) GetTimer(key uint32) (t *Timer, ok bool) { c.timersmu.RLock() t, ok = (*c.timers)[key] c.timersmu.RUnlock() return } -func (c *Clock) loadTimers(pbfile string) { - if file.IsExist(pbfile) { - f, err := os.Open(pbfile) +func (c *Clock) AddTimer(t *Timer) (err error) { + c.timersmu.Lock() + (*c.timers)[t.Id] = t + err = c.db.Insert("timer", t) + c.timersmu.Unlock() + return +} + +func (c *Clock) loadTimers(dbfile string) { + if file.IsExist(dbfile) { + c.db.DBPath = dbfile + err := c.db.Create("timer", &Timer{}) if err == nil { - data, err := io.ReadAll(f) - if err == nil { - if len(data) > 0 { - err = proto.Unmarshal(data, &c.timersmap) - if err == nil { - for str, t := range c.timersmap.Timers { - grp, err := strconv.ParseInt(str[1:strings.Index(str, "]")], 10, 64) - if err == nil { - go c.RegisterTimer(t, grp, false) - } - } - return - } - logrus.Errorln("[群管]读取定时器文件失败,将在下一次保存时覆盖原文件。err:", err) - logrus.Errorln("[群管]如不希望被覆盖,请运行源码plugin_manager/timers/migrate下的程序将timers.pb刷新为新版") - } - } + var t Timer + c.db.FindFor("timer", &t, "", func() error { + tescape := t + go c.RegisterTimer(&tescape, false) + return nil + }) } } - c.timersmap.Timers = make(map[string]*Timer) } diff --git a/plugin_manager/timer/timer.pb.go b/plugin_manager/timer/timer.pb.go deleted file mode 100644 index c2a1c2cc..00000000 --- a/plugin_manager/timer/timer.pb.go +++ /dev/null @@ -1,13 +0,0 @@ -package timer - -type Timer struct { - Alert string `protobuf:"bytes,1,opt"` - Cron string `protobuf:"bytes,2,opt"` - En1Month4Day5Week3Hour5Min6 int32 `protobuf:"varint,4,opt"` - Selfid int64 `protobuf:"varint,8,opt"` - Url string `protobuf:"bytes,16,opt"` -} - -type TimersMap struct { - Timers map[string]*Timer `protobuf:"bytes,1,rep" protobuf_key:"bytes,1,opt" protobuf_val:"bytes,2,opt"` -} diff --git a/plugin_manager/timer/wrap.go b/plugin_manager/timer/wrap.go index 7675b689..e79c9760 100644 --- a/plugin_manager/timer/wrap.go +++ b/plugin_manager/timer/wrap.go @@ -2,10 +2,12 @@ package timer import "time" +// En isEnabled 1bit func (m *Timer) En() (en bool) { return m.En1Month4Day5Week3Hour5Min6&0x800000 != 0 } +// Month 4bits func (m *Timer) Month() (mon time.Month) { mon = time.Month((m.En1Month4Day5Week3Hour5Min6 & 0x780000) >> 19) if mon == 0b1111 { @@ -14,6 +16,7 @@ func (m *Timer) Month() (mon time.Month) { return } +// Day 5bits func (m *Timer) Day() (d int) { d = int((m.En1Month4Day5Week3Hour5Min6 & 0x07c000) >> 14) if d == 0b11111 { @@ -22,6 +25,7 @@ func (m *Timer) Day() (d int) { return } +// Week 3bits func (m *Timer) Week() (w time.Weekday) { w = time.Weekday((m.En1Month4Day5Week3Hour5Min6 & 0x003800) >> 11) if w == 0b111 { @@ -30,6 +34,7 @@ func (m *Timer) Week() (w time.Weekday) { return } +// Hour 5bits func (m *Timer) Hour() (h int) { h = int((m.En1Month4Day5Week3Hour5Min6 & 0x0007c0) >> 6) if h == 0b11111 { @@ -38,6 +43,7 @@ func (m *Timer) Hour() (h int) { return } +// Minute 6bits func (m *Timer) Minute() (min int) { min = int(m.En1Month4Day5Week3Hour5Min6 & 0x00003f) if min == 0b111111 { @@ -46,6 +52,7 @@ func (m *Timer) Minute() (min int) { return } +// SetEn ... func (m *Timer) SetEn(en bool) { if en { m.En1Month4Day5Week3Hour5Min6 |= 0x800000 @@ -54,22 +61,27 @@ func (m *Timer) SetEn(en bool) { } } +// SetMonth ... func (m *Timer) SetMonth(mon time.Month) { m.En1Month4Day5Week3Hour5Min6 = ((int32(mon) << 19) & 0x780000) | (m.En1Month4Day5Week3Hour5Min6 & 0x87ffff) } +// SetDay ... func (m *Timer) SetDay(d int) { m.En1Month4Day5Week3Hour5Min6 = ((int32(d) << 14) & 0x07c000) | (m.En1Month4Day5Week3Hour5Min6 & 0xf83fff) } +// SetWeek ... func (m *Timer) SetWeek(w time.Weekday) { m.En1Month4Day5Week3Hour5Min6 = ((int32(w) << 11) & 0x003800) | (m.En1Month4Day5Week3Hour5Min6 & 0xffc7ff) } +// SetHour ... func (m *Timer) SetHour(h int) { m.En1Month4Day5Week3Hour5Min6 = ((int32(h) << 6) & 0x0007c0) | (m.En1Month4Day5Week3Hour5Min6 & 0xfff83f) } +// SetMinute ... func (m *Timer) SetMinute(min int) { m.En1Month4Day5Week3Hour5Min6 = (int32(min) & 0x00003f) | (m.En1Month4Day5Week3Hour5Min6 & 0xffffc0) } diff --git a/utils/sql/sqlite.go b/utils/sql/sqlite.go index c237d309..99cd9de2 100644 --- a/utils/sql/sqlite.go +++ b/utils/sql/sqlite.go @@ -75,6 +75,7 @@ func (db *Sqlite) Create(table string, objptr interface{}) (err error) { } // Insert 插入数据集 +// 如果 PK 存在会覆盖 // 默认结构体的第一个元素为主键 // 返回错误 func (db *Sqlite) Insert(table string, objptr interface{}) error { @@ -133,7 +134,67 @@ func (db *Sqlite) Insert(table string, objptr interface{}) error { return stmt.Close() } -// Find 查询数据库 +// InsertUnique 插入数据集 +// 如果 PK 存在会报错 +// 默认结构体的第一个元素为主键 +// 返回错误 +func (db *Sqlite) InsertUnique(table string, objptr interface{}) error { + rows, err := db.DB.Query("SELECT * FROM " + table + " limit 1;") + if err != nil { + return err + } + if rows.Err() != nil { + return rows.Err() + } + tags, _ := rows.Columns() + rows.Close() + var ( + vals = values(objptr) + top = len(tags) - 1 + cmd = []string{} + ) + cmd = append(cmd, "INSERT INTO") + cmd = append(cmd, table) + for i := range tags { + switch i { + default: + cmd = append(cmd, tags[i]) + cmd = append(cmd, ",") + case 0: + cmd = append(cmd, "(") + cmd = append(cmd, tags[i]) + cmd = append(cmd, ",") + case top: + cmd = append(cmd, tags[i]) + cmd = append(cmd, ")") + } + } + for i := range tags { + switch i { + default: + cmd = append(cmd, "?") + cmd = append(cmd, ",") + case 0: + cmd = append(cmd, "VALUES (") + cmd = append(cmd, "?") + cmd = append(cmd, ",") + case top: + cmd = append(cmd, "?") + cmd = append(cmd, ")") + } + } + stmt, err := db.DB.Prepare(strings.Join(cmd, " ") + ";") + if err != nil { + return err + } + _, err = stmt.Exec(vals...) + if err != nil { + return err + } + return stmt.Close() +} + +// Find 查询数据库,写入最后一条结果到 objptr // condition 可为"WHERE id = 0" // 默认字段与结构体元素顺序一致 // 返回错误 @@ -164,6 +225,68 @@ func (db *Sqlite) Find(table string, objptr interface{}, condition string) error return err } +// CanFind 查询数据库是否有 condition +// condition 可为"WHERE id = 0" +// 默认字段与结构体元素顺序一致 +// 返回错误 +func (db *Sqlite) CanFind(table string, condition string) bool { + var cmd = []string{} + cmd = append(cmd, "SELECT * FROM") + cmd = append(cmd, table) + cmd = append(cmd, condition) + rows, err := db.DB.Query(strings.Join(cmd, " ") + ";") + if err != nil { + return false + } + if rows.Err() != nil { + return false + } + defer rows.Close() + + if !rows.Next() { + return false + } + _ = rows.Close() + return true +} + +// FindFor 查询数据库,用函数 f 遍历结果 +// condition 可为"WHERE id = 0" +// 默认字段与结构体元素顺序一致 +// 返回错误 +func (db *Sqlite) FindFor(table string, objptr interface{}, condition string, f func() error) error { + var cmd = []string{} + cmd = append(cmd, "SELECT * FROM") + cmd = append(cmd, table) + cmd = append(cmd, condition) + rows, err := db.DB.Query(strings.Join(cmd, " ") + ";") + if err != nil { + return err + } + if rows.Err() != nil { + return rows.Err() + } + defer rows.Close() + + if !rows.Next() { + return errors.New("sql.FindFor: null result") + } + err = rows.Scan(addrs(objptr)...) + if err == nil { + err = f() + } + for rows.Next() { + if err != nil { + return err + } + err = rows.Scan(addrs(objptr)...) + if err == nil { + err = f() + } + } + return err +} + // Pick 从 table 随机一行 func (db *Sqlite) Pick(table string, objptr interface{}) error { return db.Find(table, objptr, "ORDER BY RANDOM() limit 1")