chore: cleanup vision code

This commit is contained in:
wwqgtxx 2025-08-20 09:46:36 +08:00
parent 2790481709
commit 12c30acdda
3 changed files with 38 additions and 44 deletions

View File

@ -30,26 +30,22 @@ type Conn struct {
} }
func (vc *Conn) Read(b []byte) (int, error) { func (vc *Conn) Read(b []byte) (int, error) {
if vc.received { if !vc.received {
return vc.ExtendedReader.Read(b) if err := vc.recvResponse(); err != nil {
return 0, err
}
vc.received = true
} }
if err := vc.recvResponse(); err != nil {
return 0, err
}
vc.received = true
return vc.ExtendedReader.Read(b) return vc.ExtendedReader.Read(b)
} }
func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
if vc.received { if !vc.received {
return vc.ExtendedReader.ReadBuffer(buffer) if err := vc.recvResponse(); err != nil {
return err
}
vc.received = true
} }
if err := vc.recvResponse(); err != nil {
return err
}
vc.received = true
return vc.ExtendedReader.ReadBuffer(buffer) return vc.ExtendedReader.ReadBuffer(buffer)
} }
@ -190,7 +186,7 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (net.Conn, error) {
if client.Addons != nil { if client.Addons != nil {
switch client.Addons.Flow { switch client.Addons.Flow {
case XRV: case XRV:
visionConn, err := vision.NewConn(c, c.id) visionConn, err := vision.NewConn(c, conn, c.id)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -21,15 +21,16 @@ var (
) )
type Conn struct { type Conn struct {
net.Conn net.Conn // should be *vless.Conn
N.ExtendedReader N.ExtendedReader
N.ExtendedWriter N.ExtendedWriter
upstream net.Conn
userUUID *uuid.UUID userUUID *uuid.UUID
tlsConn net.Conn // tlsConn and it's internal variables
input *bytes.Reader tlsConn net.Conn // maybe [*tls.Conn] or other tls-like conn
rawInput *bytes.Buffer netConn net.Conn // tlsConn.NetConn()
input *bytes.Reader // &tlsConn.input or nil
rawInput *bytes.Buffer // &tlsConn.rawInput or nil
needHandshake bool needHandshake bool
packetsToFilter int packetsToFilter int
@ -143,7 +144,7 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
} }
if vc.input == nil && vc.rawInput == nil { if vc.input == nil && vc.rawInput == nil {
vc.readProcess = false vc.readProcess = false
vc.ExtendedReader = N.NewExtendedReader(vc.Conn) vc.ExtendedReader = N.NewExtendedReader(vc.netConn)
log.Debugln("XTLS Vision direct read start") log.Debugln("XTLS Vision direct read start")
} }
if needReturn { if needReturn {
@ -214,7 +215,7 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) {
return err return err
} }
if vc.writeDirect { if vc.writeDirect {
vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) vc.ExtendedWriter = N.NewExtendedWriter(vc.netConn)
log.Debugln("XTLS Vision direct write start") log.Debugln("XTLS Vision direct write start")
//time.Sleep(5 * time.Millisecond) //time.Sleep(5 * time.Millisecond)
} }
@ -235,7 +236,7 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) {
ApplyPadding(buffer2, command, nil, vc.isTLS) ApplyPadding(buffer2, command, nil, vc.isTLS)
err = vc.ExtendedWriter.WriteBuffer(buffer2) err = vc.ExtendedWriter.WriteBuffer(buffer2)
if vc.writeDirect { if vc.writeDirect {
vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) vc.ExtendedWriter = N.NewExtendedWriter(vc.netConn)
log.Debugln("XTLS Vision direct write start") log.Debugln("XTLS Vision direct write start")
//time.Sleep(10 * time.Millisecond) //time.Sleep(10 * time.Millisecond)
} }
@ -266,9 +267,9 @@ func (vc *Conn) NeedHandshake() bool {
func (vc *Conn) Upstream() any { func (vc *Conn) Upstream() any {
if vc.writeDirect || if vc.writeDirect ||
vc.readLastCommand == commandPaddingDirect { vc.readLastCommand == commandPaddingDirect {
return vc.Conn return vc.netConn
} }
return vc.upstream return vc.Conn
} }
func (vc *Conn) ReaderPossiblyReplaceable() bool { func (vc *Conn) ReaderPossiblyReplaceable() bool {
@ -293,3 +294,10 @@ func (vc *Conn) WriterReplaceable() bool {
} }
return false return false
} }
func (vc *Conn) Close() error {
if vc.ReaderReplaceable() || vc.WriterReplaceable() { // ignore send closeNotify alert in tls.Conn
return vc.netConn.Close()
}
return vc.Conn.Close()
}

View File

@ -15,22 +15,17 @@ import (
"github.com/metacubex/mihomo/transport/vless/encryption" "github.com/metacubex/mihomo/transport/vless/encryption"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
"github.com/metacubex/sing/common"
) )
var ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection") var ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection")
type connWithUpstream interface { func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error) {
net.Conn
common.WithUpstream
}
func NewConn(conn connWithUpstream, userUUID *uuid.UUID) (*Conn, error) {
c := &Conn{ c := &Conn{
ExtendedReader: N.NewExtendedReader(conn), ExtendedReader: N.NewExtendedReader(conn),
ExtendedWriter: N.NewExtendedWriter(conn), ExtendedWriter: N.NewExtendedWriter(conn),
upstream: conn, Conn: conn,
userUUID: userUUID, userUUID: userUUID,
tlsConn: tlsConn,
packetsToFilter: 6, packetsToFilter: 6,
needHandshake: true, needHandshake: true,
readProcess: true, readProcess: true,
@ -39,36 +34,31 @@ func NewConn(conn connWithUpstream, userUUID *uuid.UUID) (*Conn, error) {
} }
var t reflect.Type var t reflect.Type
var p unsafe.Pointer var p unsafe.Pointer
switch underlying := conn.Upstream().(type) { switch underlying := tlsConn.(type) {
case *gotls.Conn: case *gotls.Conn:
//log.Debugln("type tls") //log.Debugln("type tls")
c.Conn = underlying.NetConn() c.netConn = underlying.NetConn()
c.tlsConn = underlying
t = reflect.TypeOf(underlying).Elem() t = reflect.TypeOf(underlying).Elem()
p = unsafe.Pointer(underlying) p = unsafe.Pointer(underlying)
case *tlsC.Conn: case *tlsC.Conn:
//log.Debugln("type *tlsC.Conn") //log.Debugln("type *tlsC.Conn")
c.Conn = underlying.NetConn() c.netConn = underlying.NetConn()
c.tlsConn = underlying
t = reflect.TypeOf(underlying).Elem() t = reflect.TypeOf(underlying).Elem()
p = unsafe.Pointer(underlying) p = unsafe.Pointer(underlying)
case *tlsC.UConn: case *tlsC.UConn:
//log.Debugln("type *tlsC.UConn") //log.Debugln("type *tlsC.UConn")
c.Conn = underlying.NetConn() c.netConn = underlying.NetConn()
c.tlsConn = underlying
t = reflect.TypeOf(underlying.Conn).Elem() t = reflect.TypeOf(underlying.Conn).Elem()
//log.Debugln("t:%v", t) //log.Debugln("t:%v", t)
p = unsafe.Pointer(underlying.Conn) p = unsafe.Pointer(underlying.Conn)
case *encryption.ClientConn: case *encryption.ClientConn:
//log.Debugln("type *encryption.ClientConn") //log.Debugln("type *encryption.ClientConn")
c.Conn = underlying.Conn c.netConn = underlying.Conn
c.tlsConn = underlying
t = reflect.TypeOf(underlying).Elem() t = reflect.TypeOf(underlying).Elem()
p = unsafe.Pointer(underlying) p = unsafe.Pointer(underlying)
case *encryption.ServerConn: case *encryption.ServerConn:
//log.Debugln("type *encryption.ServerConn") //log.Debugln("type *encryption.ServerConn")
c.Conn = underlying.Conn c.netConn = underlying.Conn
c.tlsConn = underlying
t = reflect.TypeOf(underlying).Elem() t = reflect.TypeOf(underlying).Elem()
p = unsafe.Pointer(underlying) p = unsafe.Pointer(underlying)
default: default: