diff --git a/transport/vless/conn.go b/transport/vless/conn.go index 180a9969..94ae71ee 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -30,26 +30,22 @@ type Conn struct { } func (vc *Conn) Read(b []byte) (int, error) { - if vc.received { - return vc.ExtendedReader.Read(b) + if !vc.received { + if err := vc.recvResponse(); err != nil { + return 0, err + } + vc.received = true } - - if err := vc.recvResponse(); err != nil { - return 0, err - } - vc.received = true return vc.ExtendedReader.Read(b) } func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { - if vc.received { - return vc.ExtendedReader.ReadBuffer(buffer) + if !vc.received { + if err := vc.recvResponse(); err != nil { + return err + } + vc.received = true } - - if err := vc.recvResponse(); err != nil { - return err - } - vc.received = true return vc.ExtendedReader.ReadBuffer(buffer) } @@ -190,7 +186,7 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (net.Conn, error) { if client.Addons != nil { switch client.Addons.Flow { case XRV: - visionConn, err := vision.NewConn(c, c.id) + visionConn, err := vision.NewConn(c, conn, c.id) if err != nil { return nil, err } diff --git a/transport/vless/vision/conn.go b/transport/vless/vision/conn.go index a0e83f71..b8367fb9 100644 --- a/transport/vless/vision/conn.go +++ b/transport/vless/vision/conn.go @@ -21,15 +21,16 @@ var ( ) type Conn struct { - net.Conn + net.Conn // should be *vless.Conn N.ExtendedReader N.ExtendedWriter - upstream net.Conn userUUID *uuid.UUID - tlsConn net.Conn - input *bytes.Reader - rawInput *bytes.Buffer + // tlsConn and it's internal variables + tlsConn net.Conn // maybe [*tls.Conn] or other tls-like conn + netConn net.Conn // tlsConn.NetConn() + input *bytes.Reader // &tlsConn.input or nil + rawInput *bytes.Buffer // &tlsConn.rawInput or nil needHandshake bool packetsToFilter int @@ -143,7 +144,7 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { } if vc.input == nil && vc.rawInput == nil { vc.readProcess = false - vc.ExtendedReader = N.NewExtendedReader(vc.Conn) + vc.ExtendedReader = N.NewExtendedReader(vc.netConn) log.Debugln("XTLS Vision direct read start") } if needReturn { @@ -214,7 +215,7 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) { return err } if vc.writeDirect { - vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) + vc.ExtendedWriter = N.NewExtendedWriter(vc.netConn) log.Debugln("XTLS Vision direct write start") //time.Sleep(5 * time.Millisecond) } @@ -235,7 +236,7 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) { ApplyPadding(buffer2, command, nil, vc.isTLS) err = vc.ExtendedWriter.WriteBuffer(buffer2) if vc.writeDirect { - vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) + vc.ExtendedWriter = N.NewExtendedWriter(vc.netConn) log.Debugln("XTLS Vision direct write start") //time.Sleep(10 * time.Millisecond) } @@ -266,9 +267,9 @@ func (vc *Conn) NeedHandshake() bool { func (vc *Conn) Upstream() any { if vc.writeDirect || vc.readLastCommand == commandPaddingDirect { - return vc.Conn + return vc.netConn } - return vc.upstream + return vc.Conn } func (vc *Conn) ReaderPossiblyReplaceable() bool { @@ -293,3 +294,10 @@ func (vc *Conn) WriterReplaceable() bool { } return false } + +func (vc *Conn) Close() error { + if vc.ReaderReplaceable() || vc.WriterReplaceable() { // ignore send closeNotify alert in tls.Conn + return vc.netConn.Close() + } + return vc.Conn.Close() +} diff --git a/transport/vless/vision/vision.go b/transport/vless/vision/vision.go index 32634f0c..108e0177 100644 --- a/transport/vless/vision/vision.go +++ b/transport/vless/vision/vision.go @@ -15,22 +15,17 @@ import ( "github.com/metacubex/mihomo/transport/vless/encryption" "github.com/gofrs/uuid/v5" - "github.com/metacubex/sing/common" ) var ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection") -type connWithUpstream interface { - net.Conn - common.WithUpstream -} - -func NewConn(conn connWithUpstream, userUUID *uuid.UUID) (*Conn, error) { +func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error) { c := &Conn{ ExtendedReader: N.NewExtendedReader(conn), ExtendedWriter: N.NewExtendedWriter(conn), - upstream: conn, + Conn: conn, userUUID: userUUID, + tlsConn: tlsConn, packetsToFilter: 6, needHandshake: true, readProcess: true, @@ -39,36 +34,31 @@ func NewConn(conn connWithUpstream, userUUID *uuid.UUID) (*Conn, error) { } var t reflect.Type var p unsafe.Pointer - switch underlying := conn.Upstream().(type) { + switch underlying := tlsConn.(type) { case *gotls.Conn: //log.Debugln("type tls") - c.Conn = underlying.NetConn() - c.tlsConn = underlying + c.netConn = underlying.NetConn() t = reflect.TypeOf(underlying).Elem() p = unsafe.Pointer(underlying) case *tlsC.Conn: //log.Debugln("type *tlsC.Conn") - c.Conn = underlying.NetConn() - c.tlsConn = underlying + c.netConn = underlying.NetConn() t = reflect.TypeOf(underlying).Elem() p = unsafe.Pointer(underlying) case *tlsC.UConn: //log.Debugln("type *tlsC.UConn") - c.Conn = underlying.NetConn() - c.tlsConn = underlying + c.netConn = underlying.NetConn() t = reflect.TypeOf(underlying.Conn).Elem() //log.Debugln("t:%v", t) p = unsafe.Pointer(underlying.Conn) case *encryption.ClientConn: //log.Debugln("type *encryption.ClientConn") - c.Conn = underlying.Conn - c.tlsConn = underlying + c.netConn = underlying.Conn t = reflect.TypeOf(underlying).Elem() p = unsafe.Pointer(underlying) case *encryption.ServerConn: //log.Debugln("type *encryption.ServerConn") - c.Conn = underlying.Conn - c.tlsConn = underlying + c.netConn = underlying.Conn t = reflect.TypeOf(underlying).Elem() p = unsafe.Pointer(underlying) default: