From 50e1afd96375a936d8830c0df02ba35c67792746 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 8 Sep 2025 15:58:54 +0800 Subject: [PATCH] chore: cleanup vless code --- transport/vless/conn.go | 106 +++++++++++++------------------ transport/vless/vision/conn.go | 2 +- transport/vless/vision/vision.go | 2 +- transport/vless/vless.go | 4 +- 4 files changed, 47 insertions(+), 67 deletions(-) diff --git a/transport/vless/conn.go b/transport/vless/conn.go index 94ae71ee..f43d77e1 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -5,7 +5,6 @@ import ( "errors" "io" "net" - "sync" "github.com/metacubex/mihomo/common/buf" N "github.com/metacubex/mihomo/common/net" @@ -16,17 +15,12 @@ import ( ) type Conn struct { - N.ExtendedWriter - N.ExtendedReader - net.Conn + N.ExtendedConn dst *DstAddr - id *uuid.UUID + id uuid.UUID addons *Addons received bool - - handshakeMutex sync.Mutex - needHandshake bool - err error + sent bool } func (vc *Conn) Read(b []byte) (int, error) { @@ -36,7 +30,7 @@ func (vc *Conn) Read(b []byte) (int, error) { } vc.received = true } - return vc.ExtendedReader.Read(b) + return vc.ExtendedConn.Read(b) } func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { @@ -46,58 +40,39 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { } vc.received = true } - return vc.ExtendedReader.ReadBuffer(buffer) + return vc.ExtendedConn.ReadBuffer(buffer) } func (vc *Conn) Write(p []byte) (int, error) { - if vc.needHandshake { - vc.handshakeMutex.Lock() - if vc.needHandshake { - vc.needHandshake = false - if vc.sendRequest(p) { - vc.handshakeMutex.Unlock() - if vc.err != nil { - return 0, vc.err - } - return len(p), vc.err - } - if vc.err != nil { - vc.handshakeMutex.Unlock() - return 0, vc.err - } + if !vc.sent { + if err := vc.sendRequest(p); err != nil { + return 0, err } - vc.handshakeMutex.Unlock() + vc.sent = true + return len(p), nil } - return vc.ExtendedWriter.Write(p) + return vc.ExtendedConn.Write(p) } func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { - if vc.needHandshake { - vc.handshakeMutex.Lock() - if vc.needHandshake { - vc.needHandshake = false - if vc.sendRequest(buffer.Bytes()) { - vc.handshakeMutex.Unlock() - return vc.err - } - if vc.err != nil { - vc.handshakeMutex.Unlock() - return vc.err - } + if !vc.sent { + if err := vc.sendRequest(buffer.Bytes()); err != nil { + return err } - vc.handshakeMutex.Unlock() + vc.sent = true + return nil } - return vc.ExtendedWriter.WriteBuffer(buffer) + return vc.ExtendedConn.WriteBuffer(buffer) } -func (vc *Conn) sendRequest(p []byte) bool { +func (vc *Conn) sendRequest(p []byte) (err error) { var addonsBytes []byte if vc.addons != nil { - addonsBytes, vc.err = proto.Marshal(vc.addons) - if vc.err != nil { - return true + addonsBytes, err = proto.Marshal(vc.addons) + if err != nil { + return } } @@ -141,15 +116,15 @@ func (vc *Conn) sendRequest(p []byte) bool { buf.Must(buf.Error(buffer.Write(p))) - _, vc.err = vc.ExtendedWriter.Write(buffer.Bytes()) - return true + _, err = vc.ExtendedConn.Write(buffer.Bytes()) + return } -func (vc *Conn) recvResponse() error { +func (vc *Conn) recvResponse() (err error) { var buffer [2]byte - _, vc.err = io.ReadFull(vc.ExtendedReader, buffer[:]) - if vc.err != nil { - return vc.err + _, err = io.ReadFull(vc.ExtendedConn, buffer[:]) + if err != nil { + return err } if buffer[0] != Version { @@ -158,29 +133,35 @@ func (vc *Conn) recvResponse() error { length := int64(buffer[1]) if length != 0 { // addon data length > 0 - io.CopyN(io.Discard, vc.ExtendedReader, length) // just discard + io.CopyN(io.Discard, vc.ExtendedConn, length) // just discard } - return nil + return } func (vc *Conn) Upstream() any { - return vc.Conn + return vc.ExtendedConn +} + +func (vc *Conn) ReaderReplaceable() bool { + return vc.received +} + +func (vc *Conn) WriterReplaceable() bool { + return vc.sent } func (vc *Conn) NeedHandshake() bool { - return vc.needHandshake + return !vc.sent } // newConn return a Conn instance func newConn(conn net.Conn, client *Client, dst *DstAddr) (net.Conn, error) { c := &Conn{ - ExtendedReader: N.NewExtendedReader(conn), - ExtendedWriter: N.NewExtendedWriter(conn), - Conn: conn, - id: client.uuid, - dst: dst, - needHandshake: true, + ExtendedConn: N.NewExtendedConn(conn), + id: client.uuid, + addons: client.Addons, + dst: dst, } if client.Addons != nil { @@ -190,7 +171,6 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (net.Conn, error) { if err != nil { return nil, err } - c.addons = client.Addons return visionConn, nil } } diff --git a/transport/vless/vision/conn.go b/transport/vless/vision/conn.go index 0e3667b8..2801091c 100644 --- a/transport/vless/vision/conn.go +++ b/transport/vless/vision/conn.go @@ -24,7 +24,7 @@ type Conn struct { net.Conn // should be *vless.Conn N.ExtendedReader N.ExtendedWriter - userUUID *uuid.UUID + userUUID uuid.UUID // [*tls.Conn] or other tls-like [net.Conn]'s internal variables netConn net.Conn // tlsConn.NetConn() diff --git a/transport/vless/vision/vision.go b/transport/vless/vision/vision.go index c49253ec..e785c6ad 100644 --- a/transport/vless/vision/vision.go +++ b/transport/vless/vision/vision.go @@ -21,7 +21,7 @@ import ( var ErrNotHandshakeComplete = errors.New("tls connection not handshake complete") var ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection") -func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error) { +func NewConn(conn net.Conn, tlsConn net.Conn, userUUID uuid.UUID) (*Conn, error) { c := &Conn{ ExtendedReader: N.NewExtendedReader(conn), ExtendedWriter: N.NewExtendedWriter(conn), diff --git a/transport/vless/vless.go b/transport/vless/vless.go index 9fb54f92..4e99b9ba 100644 --- a/transport/vless/vless.go +++ b/transport/vless/vless.go @@ -42,7 +42,7 @@ type DstAddr struct { // Client is vless connection generator type Client struct { - uuid *uuid.UUID + uuid uuid.UUID Addons *Addons } @@ -63,7 +63,7 @@ func NewClient(uuidStr string, addons *Addons) (*Client, error) { } return &Client{ - uuid: &uid, + uuid: uid, Addons: addons, }, nil }