From 0336d64e5208ef7dceb35478120e61eb5db7b58e Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 8 Sep 2025 10:54:59 +0800 Subject: [PATCH] chore: cleanup vision code --- common/net/sing.go | 13 +++- transport/vless/vision/conn.go | 38 +++-------- transport/vless/vision/padding.go | 7 +- transport/vless/vision/vision.go | 106 +++++++++++++++++++++--------- 4 files changed, 101 insertions(+), 63 deletions(-) diff --git a/common/net/sing.go b/common/net/sing.go index 5cf97594..3545b6a4 100644 --- a/common/net/sing.go +++ b/common/net/sing.go @@ -26,11 +26,20 @@ type ReadWaitOptions = network.ReadWaitOptions var NewReadWaitOptions = network.NewReadWaitOptions +type ReaderWithUpstream = network.ReaderWithUpstream +type WithUpstreamReader = network.WithUpstreamReader +type WriterWithUpstream = network.WriterWithUpstream +type WithUpstreamWriter = network.WithUpstreamWriter +type WithUpstream = common.WithUpstream + +var UnwrapReader = network.UnwrapReader +var UnwrapWriter = network.UnwrapWriter + func NewDeadlineConn(conn net.Conn) ExtendedConn { - if deadline.IsPipe(conn) || deadline.IsPipe(network.UnwrapReader(conn)) { + if deadline.IsPipe(conn) || deadline.IsPipe(UnwrapReader(conn)) { return NewExtendedConn(conn) // pipe always have correctly deadline implement } - if deadline.IsConn(conn) || deadline.IsConn(network.UnwrapReader(conn)) { + if deadline.IsConn(conn) || deadline.IsConn(UnwrapReader(conn)) { return NewExtendedConn(conn) // was a *deadline.Conn } return deadline.NewConn(conn) diff --git a/transport/vless/vision/conn.go b/transport/vless/vision/conn.go index 7e778cf8..0e3667b8 100644 --- a/transport/vless/vision/conn.go +++ b/transport/vless/vision/conn.go @@ -26,13 +26,11 @@ type Conn struct { N.ExtendedWriter userUUID *uuid.UUID - // tlsConn and it's internal variables - tlsConn net.Conn // maybe [*tls.Conn] or other tls-like conn + // [*tls.Conn] or other tls-like [net.Conn]'s internal variables netConn net.Conn // tlsConn.NetConn() input *bytes.Reader // &tlsConn.input or nil rawInput *bytes.Buffer // &tlsConn.rawInput or nil - needHandshake bool packetsToFilter int isTLS bool isTLS12orAbove bool @@ -46,6 +44,7 @@ type Conn struct { readLastCommand byte writeFilterApplicationData bool writeDirect bool + writeOnceUserUUID []byte } func (vc *Conn) Read(b []byte) (int, error) { @@ -169,29 +168,12 @@ func (vc *Conn) Write(p []byte) (int, error) { } func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) { - if vc.needHandshake { - vc.needHandshake = false - if buffer.IsEmpty() { - ApplyPadding(buffer, commandPaddingContinue, vc.userUUID, true) // we do a long padding to hide vless header - } else { - vc.FilterTLS(buffer.Bytes()) - ApplyPadding(buffer, commandPaddingContinue, vc.userUUID, vc.isTLS) - } - err = vc.ExtendedWriter.WriteBuffer(buffer) - if err != nil { - buffer.Release() - return err - } - err = vc.checkTLSVersion() - if err != nil { - buffer.Release() - return err - } - vc.tlsConn = nil - return nil - } - if vc.writeFilterApplicationData { + if buffer.IsEmpty() { + ApplyPadding(buffer, commandPaddingContinue, &vc.writeOnceUserUUID, true) // we do a long padding to hide vless header + return vc.ExtendedWriter.WriteBuffer(buffer) + } + vc.FilterTLS(buffer.Bytes()) buffers := vc.ReshapeBuffer(buffer) applyPadding := true @@ -211,7 +193,7 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) { vc.writeFilterApplicationData = false applyPadding = false } - ApplyPadding(buffer, command, nil, vc.isTLS) + ApplyPadding(buffer, command, &vc.writeOnceUserUUID, vc.isTLS) } err = vc.ExtendedWriter.WriteBuffer(buffer) @@ -234,7 +216,7 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) { } func (vc *Conn) FrontHeadroom() int { - if vc.readFilterUUID { + if vc.readFilterUUID || vc.writeOnceUserUUID != nil { return PaddingHeaderLen } return PaddingHeaderLen - uuid.Size @@ -245,7 +227,7 @@ func (vc *Conn) RearHeadroom() int { } func (vc *Conn) NeedHandshake() bool { - return vc.needHandshake + return vc.writeOnceUserUUID != nil } func (vc *Conn) Upstream() any { diff --git a/transport/vless/vision/padding.go b/transport/vless/vision/padding.go index 710f64c2..3152139a 100644 --- a/transport/vless/vision/padding.go +++ b/transport/vless/vision/padding.go @@ -20,7 +20,7 @@ const ( commandPaddingDirect byte = 0x02 ) -func ApplyPadding(buffer *buf.Buffer, command byte, userUUID *uuid.UUID, paddingTLS bool) { +func ApplyPadding(buffer *buf.Buffer, command byte, userUUID *[]byte, paddingTLS bool) { contentLen := int32(buffer.Len()) var paddingLen int32 if contentLen < 900 { @@ -35,8 +35,9 @@ func ApplyPadding(buffer *buf.Buffer, command byte, userUUID *uuid.UUID, padding binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(paddingLen)) binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(contentLen)) buffer.ExtendHeader(1)[0] = command - if userUUID != nil { - copy(buffer.ExtendHeader(uuid.Size), userUUID.Bytes()) + if userUUID != nil && *userUUID != nil { + copy(buffer.ExtendHeader(uuid.Size), *userUUID) + *userUUID = nil } buffer.Extend(int(paddingLen)) diff --git a/transport/vless/vision/vision.go b/transport/vless/vision/vision.go index f9158ca4..c49253ec 100644 --- a/transport/vless/vision/vision.go +++ b/transport/vless/vision/vision.go @@ -12,11 +12,13 @@ import ( N "github.com/metacubex/mihomo/common/net" tlsC "github.com/metacubex/mihomo/component/tls" + "github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/transport/vless/encryption" "github.com/gofrs/uuid/v5" ) +var ErrNotHandshakeComplete = errors.New("tls connection not handshake complete") var ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection") func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error) { @@ -25,40 +27,72 @@ func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error ExtendedWriter: N.NewExtendedWriter(conn), Conn: conn, userUUID: userUUID, - tlsConn: tlsConn, packetsToFilter: 6, - needHandshake: true, readProcess: true, readFilterUUID: true, writeFilterApplicationData: true, + writeOnceUserUUID: userUUID.Bytes(), } var t reflect.Type var p unsafe.Pointer - switch underlying := tlsConn.(type) { - case *gotls.Conn: - //log.Debugln("type tls") - c.netConn = underlying.NetConn() - t = reflect.TypeOf(underlying).Elem() - p = unsafe.Pointer(underlying) - case *tlsC.Conn: - //log.Debugln("type *tlsC.Conn") - c.netConn = underlying.NetConn() - t = reflect.TypeOf(underlying).Elem() - p = unsafe.Pointer(underlying) - case *tlsC.UConn: - //log.Debugln("type *tlsC.UConn") - c.netConn = underlying.NetConn() - t = reflect.TypeOf(underlying.Conn).Elem() - //log.Debugln("t:%v", t) - p = unsafe.Pointer(underlying.Conn) - case *encryption.CommonConn: - //log.Debugln("type *encryption.CommonConn") - c.netConn = underlying.Conn - t = reflect.TypeOf(underlying).Elem() - p = unsafe.Pointer(underlying) - default: + var upstream any = tlsConn + for { + switch underlying := upstream.(type) { + case *gotls.Conn: + //log.Debugln("type tls") + tlsConn = underlying + c.netConn = underlying.NetConn() + t = reflect.TypeOf(underlying).Elem() + p = unsafe.Pointer(underlying) + break + case *tlsC.Conn: + //log.Debugln("type *tlsC.Conn") + tlsConn = underlying + c.netConn = underlying.NetConn() + t = reflect.TypeOf(underlying).Elem() + p = unsafe.Pointer(underlying) + break + case *tlsC.UConn: + //log.Debugln("type *tlsC.UConn") + tlsConn = underlying + c.netConn = underlying.NetConn() + t = reflect.TypeOf(underlying.Conn).Elem() + //log.Debugln("t:%v", t) + p = unsafe.Pointer(underlying.Conn) + break + case *encryption.CommonConn: + //log.Debugln("type *encryption.CommonConn") + tlsConn = underlying + c.netConn = underlying.Conn + t = reflect.TypeOf(underlying).Elem() + p = unsafe.Pointer(underlying) + break + } + if u, ok := upstream.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { // must replaceable + break + } + if u, ok := upstream.(N.WithUpstreamReader); ok { + upstream = u.UpstreamReader() + continue + } + if u, ok := upstream.(N.WithUpstream); ok { + upstream = u.Upstream() + continue + } + } + if t == nil || p == nil { + log.Warnln("vision: not a valid supported TLS connection: %s", reflect.TypeOf(tlsConn)) return nil, fmt.Errorf(`failed to use vision, maybe "security" is not "tls" or "utls"`) } + + if err := checkTLSVersion(tlsConn); err != nil { + if errors.Is(err, ErrNotHandshakeComplete) { + log.Warnln("vision: TLS connection not handshake complete: %s", reflect.TypeOf(tlsConn)) + } else { + return nil, err + } + } + if i, ok := t.FieldByName("input"); ok { c.input = (*bytes.Reader)(unsafe.Add(p, i.Offset)) } @@ -68,18 +102,30 @@ func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error return c, nil } -func (vc *Conn) checkTLSVersion() error { - switch underlying := vc.tlsConn.(type) { +func checkTLSVersion(tlsConn net.Conn) error { + switch underlying := tlsConn.(type) { case *gotls.Conn: - if underlying.ConnectionState().Version != gotls.VersionTLS13 { + state := underlying.ConnectionState() + if !state.HandshakeComplete { + return ErrNotHandshakeComplete + } + if state.Version != gotls.VersionTLS13 { return ErrNotTLS13 } case *tlsC.Conn: - if underlying.ConnectionState().Version != tlsC.VersionTLS13 { + state := underlying.ConnectionState() + if !state.HandshakeComplete { + return ErrNotHandshakeComplete + } + if state.Version != tlsC.VersionTLS13 { return ErrNotTLS13 } case *tlsC.UConn: - if underlying.ConnectionState().Version != tlsC.VersionTLS13 { + state := underlying.ConnectionState() + if !state.HandshakeComplete { + return ErrNotHandshakeComplete + } + if state.Version != tlsC.VersionTLS13 { return ErrNotTLS13 } }