From bc9db11cb41d3dc1b78ec1457a2b9e5109d4b5b2 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sun, 14 Dec 2025 19:51:18 +0800 Subject: [PATCH] chore: hub/route module handle websocket itself --- common/net/websocket.go | 12 +++ hub/route/common.go | 157 +++++++++++++++++++++++++++++++++++ hub/route/connections.go | 6 +- hub/route/server.go | 14 ++-- transport/vmess/websocket.go | 15 +--- 5 files changed, 179 insertions(+), 25 deletions(-) diff --git a/common/net/websocket.go b/common/net/websocket.go index b002310a..c49e60fa 100644 --- a/common/net/websocket.go +++ b/common/net/websocket.go @@ -1,6 +1,8 @@ package net import ( + "crypto/sha1" + "encoding/base64" "encoding/binary" "math/bits" ) @@ -129,3 +131,13 @@ func MaskWebSocket(key uint32, b []byte) uint32 { return key } + +func GetWebSocketSecAccept(secKey string) string { + const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) + p := make([]byte, nonceSize+len(magic)) + copy(p[:nonceSize], secKey) + copy(p[nonceSize:], magic) + sum := sha1.Sum(p) + return base64.StdEncoding.EncodeToString(sum[:]) +} diff --git a/hub/route/common.go b/hub/route/common.go index d0053e67..6dd0b40c 100644 --- a/hub/route/common.go +++ b/hub/route/common.go @@ -1,10 +1,19 @@ package route import ( + "bufio" + "encoding/binary" + "errors" + "io" + "net" "net/http" "net/url" + "strconv" + "strings" + "time" "github.com/go-chi/chi/v5" + N "github.com/metacubex/mihomo/common/net" ) // When name is composed of a partial escape string, Golang does not unescape it @@ -15,3 +24,151 @@ func getEscapeParam(r *http.Request, paramName string) string { } return param } + +// wsUpgrade upgrades http connection to the websocket connection. +// +// It hijacks net.Conn from w and returns received net.Conn and +// bufio.ReadWriter. +func wsUpgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, err error) { + // See https://tools.ietf.org/html/rfc6455#section-4.1 + // The method of the request MUST be GET, and the HTTP version MUST be at least 1.1. + var nonce string + if r.Method != http.MethodGet { + err = errors.New("handshake error: bad HTTP request method") + body := err.Error() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(http.StatusMethodNotAllowed) + w.Write([]byte(body)) + return nil, nil, err + } else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) { + err = errors.New("handshake error: bad HTTP protocol version") + body := err.Error() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(http.StatusHTTPVersionNotSupported) + w.Write([]byte(body)) + return nil, nil, err + } else if r.Host == "" { + err = errors.New("handshake error: bad Host header") + body := err.Error() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(body)) + return nil, nil, err + } else if u := r.Header.Get("Upgrade"); u != "websocket" && !strings.EqualFold(u, "websocket") { + err = errors.New("handshake error: bad Upgrade header") + body := err.Error() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(body)) + return nil, nil, err + } else if c := r.Header.Get("Connection"); c != "Upgrade" && !strings.Contains(strings.ToLower(c), "upgrade") { + err = errors.New("handshake error: bad Connection header") + body := err.Error() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(body)) + return nil, nil, err + } else if nonce = r.Header.Get("Sec-WebSocket-Key"); len(nonce) != 24 { + err = errors.New("handshake error: bad Sec-WebSocket-Key header") + body := err.Error() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(body)) + return nil, nil, err + } else if v := r.Header.Get("Sec-WebSocket-Version"); v != "13" { + err = errors.New("handshake error: bad Sec-WebSocket-Version header") + body := err.Error() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + if v != "" { + // According to RFC6455: + // If this version does not match a version understood by the server, the + // server MUST abort the WebSocket handshake described in this section and + // instead send an appropriate HTTP error code (such as 426 Upgrade Required) + // and a |Sec-WebSocket-Version| header field indicating the version(s) the + // server is capable of understanding. + w.Header().Set("Sec-WebSocket-Version", "13") + w.WriteHeader(http.StatusUpgradeRequired) + } else { + w.WriteHeader(http.StatusBadRequest) + } + w.Write([]byte(body)) + return nil, nil, err + } + + conn, rw, err = http.NewResponseController(w).Hijack() + if err != nil { + body := err.Error() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(body)) + return nil, nil, err + } + + // Clear deadlines set by server. + conn.SetDeadline(time.Time{}) + + rw.Writer.WriteString("HTTP/1.1 101 Switching Protocols\r\n") + header := http.Header{} + header.Set("Upgrade", "websocket") + header.Set("Connection", "Upgrade") + header.Set("Sec-WebSocket-Accept", N.GetWebSocketSecAccept(nonce)) + header.Write(rw.Writer) + rw.Writer.WriteString("\r\n") + err = rw.Writer.Flush() + + return conn, rw, err +} + +// wsWriteServerMessage writes message to w, considering that caller represents server side. +func wsWriteServerMessage(w io.Writer, op byte, p []byte) error { + dataLen := len(p) + + // Make slice of bytes with capacity 14 that could hold any header. + bts := make([]byte, 14) + + bts[0] |= 0x80 //FIN + bts[0] |= 0 << 4 //RSV + bts[0] |= op //OPCODE + + var n int + switch { + case dataLen < 126: + bts[1] = byte(dataLen) + n = 2 + case dataLen < 65536: + bts[1] = 126 + binary.BigEndian.PutUint16(bts[2:4], uint16(dataLen)) + n = 4 + default: + bts[1] = 127 + binary.BigEndian.PutUint64(bts[2:10], uint64(dataLen)) + n = 10 + } + + _, err := w.Write(bts[:n]) + if err != nil { + return err + } + _, err = w.Write(p) + return err +} + +// wsWriteServerText is the same as wsWriteServerMessage with ws.OpText. +func wsWriteServerText(w io.Writer, p []byte) error { + const opText = 0x1 + return wsWriteServerMessage(w, opText, p) +} + +// wsWriteServerBinary is the same as wsWriteServerMessage with ws.OpBinary. +func wsWriteServerBinary(w io.Writer, p []byte) error { + const opBinary = 0x2 + return wsWriteServerMessage(w, opBinary, p) +} diff --git a/hub/route/connections.go b/hub/route/connections.go index e0ff2426..2ae1e885 100644 --- a/hub/route/connections.go +++ b/hub/route/connections.go @@ -11,8 +11,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/render" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" ) func connectionRouter() http.Handler { @@ -30,7 +28,7 @@ func getConnections(w http.ResponseWriter, r *http.Request) { return } - conn, _, _, err := ws.UpgradeHTTP(r, w) + conn, _, err := wsUpgrade(r, w) if err != nil { return } @@ -56,7 +54,7 @@ func getConnections(w http.ResponseWriter, r *http.Request) { return err } - return wsutil.WriteMessage(conn, ws.StateServerSide, ws.OpText, buf.Bytes()) + return wsWriteServerText(conn, buf.Bytes()) } if err := sendSnapshot(); err != nil { diff --git a/hub/route/server.go b/hub/route/server.go index fe827530..27c13f75 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -26,8 +26,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" "github.com/sagernet/cors" ) @@ -363,7 +361,7 @@ func traffic(w http.ResponseWriter, r *http.Request) { var wsConn net.Conn if r.Header.Get("Upgrade") == "websocket" { var err error - wsConn, _, _, err = ws.UpgradeHTTP(r, w) + wsConn, _, err = wsUpgrade(r, w) if err != nil { return } @@ -396,7 +394,7 @@ func traffic(w http.ResponseWriter, r *http.Request) { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() } else { - err = wsutil.WriteMessage(wsConn, ws.StateServerSide, ws.OpText, buf.Bytes()) + err = wsWriteServerText(wsConn, buf.Bytes()) } if err != nil { @@ -409,7 +407,7 @@ func memory(w http.ResponseWriter, r *http.Request) { var wsConn net.Conn if r.Header.Get("Upgrade") == "websocket" { var err error - wsConn, _, _, err = ws.UpgradeHTTP(r, w) + wsConn, _, err = wsUpgrade(r, w) if err != nil { return } @@ -446,7 +444,7 @@ func memory(w http.ResponseWriter, r *http.Request) { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() } else { - err = wsutil.WriteMessage(wsConn, ws.StateServerSide, ws.OpText, buf.Bytes()) + err = wsWriteServerText(wsConn, buf.Bytes()) } if err != nil { @@ -492,7 +490,7 @@ func getLogs(w http.ResponseWriter, r *http.Request) { var wsConn net.Conn if r.Header.Get("Upgrade") == "websocket" { var err error - wsConn, _, _, err = ws.UpgradeHTTP(r, w) + wsConn, _, err = wsUpgrade(r, w) if err != nil { return } @@ -551,7 +549,7 @@ func getLogs(w http.ResponseWriter, r *http.Request) { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() } else { - err = wsutil.WriteMessage(wsConn, ws.StateServerSide, ws.OpText, buf.Bytes()) + err = wsWriteServerText(wsConn, buf.Bytes()) } if err != nil { diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 07fb9d6a..73b743da 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "crypto/rand" - "crypto/sha1" "crypto/tls" "encoding/base64" "encoding/binary" @@ -478,7 +477,7 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, if lenSecAccept := len(secAccept); lenSecAccept != acceptSize { return nil, fmt.Errorf("unexpected Sec-Websocket-Accept length: %d", lenSecAccept) } - if getSecAccept(secKey) != secAccept { + if N.GetWebSocketSecAccept(secKey) != secAccept { return nil, errors.New("unexpected Sec-Websocket-Accept") } } @@ -489,16 +488,6 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, return N.NewDeadlineConn(conn), nil } -func getSecAccept(secKey string) string { - const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) - p := make([]byte, nonceSize+len(magic)) - copy(p[:nonceSize], secKey) - copy(p[nonceSize:], magic) - sum := sha1.Sum(p) - return base64.StdEncoding.EncodeToString(sum[:]) -} - func StreamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig) (net.Conn, error) { if u, err := url.Parse(c.Path); err == nil { if q := u.Query(); q.Get("ed") != "" { @@ -568,7 +557,7 @@ func StreamUpgradedWebsocketConn(w http.ResponseWriter, r *http.Request) (net.Co w.Header().Set("Connection", "upgrade") w.Header().Set("Upgrade", "websocket") if !isRaw { - w.Header().Set("Sec-Websocket-Accept", getSecAccept(r.Header.Get("Sec-WebSocket-Key"))) + w.Header().Set("Sec-Websocket-Accept", N.GetWebSocketSecAccept(r.Header.Get("Sec-WebSocket-Key"))) } w.WriteHeader(http.StatusSwitchingProtocols) if flusher, isFlusher := w.(interface{ FlushError() error }); isFlusher && writeHeaderShouldFlush {