From f73bf5a2709623ae2f15e63e8dfaa03a37f1df30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 15 Feb 2025 02:02:53 +0900 Subject: [PATCH] feat(aichat): add temp setting --- README.md | 1 + plugin/aichat/list.go | 4 ++-- plugin/aichat/main.go | 53 ++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 53 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 0777ad20..edef9d0a 100644 --- a/README.md +++ b/README.md @@ -1543,6 +1543,7 @@ print("run[CQ:image,file="+j["img"]+"]") `import _ "github.com/FloatTech/ZeroBot-Plugin/plugin/aichat"` - [x] 设置AI聊天触发概率10 + - [x] 设置AI聊天温度80 - [x] 设置AI聊天密钥xxx - [x] 设置AI聊天模型名xxx - [x] 设置AI聊天系统提示词xxx diff --git a/plugin/aichat/list.go b/plugin/aichat/list.go index 967d4929..e8dc3693 100644 --- a/plugin/aichat/list.go +++ b/plugin/aichat/list.go @@ -40,8 +40,8 @@ func (l *list) add(grp int64, txt string) { l.m[grp] = msgs } -func (l *list) body(mn, sysp string, grp int64) deepinfra.Model { - m := model.NewCustom(mn, "", 0.7, 0.9, 1024).System(sysp) +func (l *list) body(mn, sysp string, temp float32, grp int64) deepinfra.Model { + m := model.NewCustom(mn, "", temp, 0.9, 1024).System(sysp) l.mu.RLock() defer l.mu.RUnlock() for _, msg := range l.m[grp] { diff --git a/plugin/aichat/main.go b/plugin/aichat/main.go index a393b7f7..9bdd2603 100644 --- a/plugin/aichat/main.go +++ b/plugin/aichat/main.go @@ -27,7 +27,7 @@ var ( DisableOnDefault: false, Extra: control.ExtraFromString("aichat"), Brief: "OpenAI聊天", - Help: "- 设置AI聊天触发概率10\n- 设置AI聊天密钥xxx\n- 设置AI聊天模型名xxx\n- 设置AI聊天系统提示词xxx", + Help: "- 设置AI聊天触发概率10\n- 设置AI聊天温度80\n- 设置AI聊天密钥xxx\n- 设置AI聊天模型名xxx\n- 设置AI聊天系统提示词xxx", PrivateDataFolder: "aichat", }) lst = newlist() @@ -73,6 +73,8 @@ func init() { return } rate := c.GetData(gid) + temp := (rate >> 8) & 0xff + rate &= 0xff if !ctx.Event.IsToMe && rand.Intn(100) >= int(rate) { return } @@ -94,7 +96,13 @@ func init() { } else { y = api } - data, err := y.Request(lst.body(modelname, systemprompt, gid)) + if temp <= 0 { + temp = 80 // default setting + } + if temp > 100 { + temp = 100 + } + data, err := y.Request(lst.body(modelname, systemprompt, float32(temp)/100, gid)) if err != nil { logrus.Warnln("[niniqun] post err:", err) return @@ -138,11 +146,50 @@ func init() { ctx.SendChain(message.Text("ERROR: parse rate err: ", err)) return } + if r > 100 { + r = 100 + } else if r < 0 { + r = 0 + } gid := ctx.Event.GroupID if gid == 0 { gid = -ctx.Event.UserID } - err = c.SetData(gid, int64(r&0xff)) + val := c.GetData(gid) & (^0xff) + err = c.SetData(gid, val|int64(r&0xff)) + if err != nil { + ctx.SendChain(message.Text("ERROR: set data err: ", err)) + return + } + ctx.SendChain(message.Text("成功")) + }) + en.OnPrefix("设置AI聊天温度", zero.AdminPermission).SetBlock(true).Handle(func(ctx *zero.Ctx) { + args := strings.TrimSpace(ctx.State["args"].(string)) + if args == "" { + ctx.SendChain(message.Text("ERROR: empty args")) + return + } + c, ok := ctx.State["manager"].(*ctrl.Control[*zero.Ctx]) + if !ok { + ctx.SendChain(message.Text("ERROR: no such plugin")) + return + } + r, err := strconv.Atoi(args) + if err != nil { + ctx.SendChain(message.Text("ERROR: parse rate err: ", err)) + return + } + if r > 100 { + r = 100 + } else if r < 0 { + r = 0 + } + gid := ctx.Event.GroupID + if gid == 0 { + gid = -ctx.Event.UserID + } + val := c.GetData(gid) & (^0xff00) + err = c.SetData(gid, val|(int64(r&0xff)<<8)) if err != nil { ctx.SendChain(message.Text("ERROR: set data err: ", err)) return