chore: cleanup vision code

This commit is contained in:
wwqgtxx 2025-09-08 10:54:59 +08:00
parent 02d954bfa8
commit 0336d64e52
4 changed files with 101 additions and 63 deletions

View File

@ -26,11 +26,20 @@ type ReadWaitOptions = network.ReadWaitOptions
var NewReadWaitOptions = network.NewReadWaitOptions var NewReadWaitOptions = network.NewReadWaitOptions
type ReaderWithUpstream = network.ReaderWithUpstream
type WithUpstreamReader = network.WithUpstreamReader
type WriterWithUpstream = network.WriterWithUpstream
type WithUpstreamWriter = network.WithUpstreamWriter
type WithUpstream = common.WithUpstream
var UnwrapReader = network.UnwrapReader
var UnwrapWriter = network.UnwrapWriter
func NewDeadlineConn(conn net.Conn) ExtendedConn { func NewDeadlineConn(conn net.Conn) ExtendedConn {
if deadline.IsPipe(conn) || deadline.IsPipe(network.UnwrapReader(conn)) { if deadline.IsPipe(conn) || deadline.IsPipe(UnwrapReader(conn)) {
return NewExtendedConn(conn) // pipe always have correctly deadline implement return NewExtendedConn(conn) // pipe always have correctly deadline implement
} }
if deadline.IsConn(conn) || deadline.IsConn(network.UnwrapReader(conn)) { if deadline.IsConn(conn) || deadline.IsConn(UnwrapReader(conn)) {
return NewExtendedConn(conn) // was a *deadline.Conn return NewExtendedConn(conn) // was a *deadline.Conn
} }
return deadline.NewConn(conn) return deadline.NewConn(conn)

View File

@ -26,13 +26,11 @@ type Conn struct {
N.ExtendedWriter N.ExtendedWriter
userUUID *uuid.UUID userUUID *uuid.UUID
// tlsConn and it's internal variables // [*tls.Conn] or other tls-like [net.Conn]'s internal variables
tlsConn net.Conn // maybe [*tls.Conn] or other tls-like conn
netConn net.Conn // tlsConn.NetConn() netConn net.Conn // tlsConn.NetConn()
input *bytes.Reader // &tlsConn.input or nil input *bytes.Reader // &tlsConn.input or nil
rawInput *bytes.Buffer // &tlsConn.rawInput or nil rawInput *bytes.Buffer // &tlsConn.rawInput or nil
needHandshake bool
packetsToFilter int packetsToFilter int
isTLS bool isTLS bool
isTLS12orAbove bool isTLS12orAbove bool
@ -46,6 +44,7 @@ type Conn struct {
readLastCommand byte readLastCommand byte
writeFilterApplicationData bool writeFilterApplicationData bool
writeDirect bool writeDirect bool
writeOnceUserUUID []byte
} }
func (vc *Conn) Read(b []byte) (int, error) { func (vc *Conn) Read(b []byte) (int, error) {
@ -169,29 +168,12 @@ func (vc *Conn) Write(p []byte) (int, error) {
} }
func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) { func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) {
if vc.needHandshake {
vc.needHandshake = false
if buffer.IsEmpty() {
ApplyPadding(buffer, commandPaddingContinue, vc.userUUID, true) // we do a long padding to hide vless header
} else {
vc.FilterTLS(buffer.Bytes())
ApplyPadding(buffer, commandPaddingContinue, vc.userUUID, vc.isTLS)
}
err = vc.ExtendedWriter.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return err
}
err = vc.checkTLSVersion()
if err != nil {
buffer.Release()
return err
}
vc.tlsConn = nil
return nil
}
if vc.writeFilterApplicationData { if vc.writeFilterApplicationData {
if buffer.IsEmpty() {
ApplyPadding(buffer, commandPaddingContinue, &vc.writeOnceUserUUID, true) // we do a long padding to hide vless header
return vc.ExtendedWriter.WriteBuffer(buffer)
}
vc.FilterTLS(buffer.Bytes()) vc.FilterTLS(buffer.Bytes())
buffers := vc.ReshapeBuffer(buffer) buffers := vc.ReshapeBuffer(buffer)
applyPadding := true applyPadding := true
@ -211,7 +193,7 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) {
vc.writeFilterApplicationData = false vc.writeFilterApplicationData = false
applyPadding = false applyPadding = false
} }
ApplyPadding(buffer, command, nil, vc.isTLS) ApplyPadding(buffer, command, &vc.writeOnceUserUUID, vc.isTLS)
} }
err = vc.ExtendedWriter.WriteBuffer(buffer) err = vc.ExtendedWriter.WriteBuffer(buffer)
@ -234,7 +216,7 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) {
} }
func (vc *Conn) FrontHeadroom() int { func (vc *Conn) FrontHeadroom() int {
if vc.readFilterUUID { if vc.readFilterUUID || vc.writeOnceUserUUID != nil {
return PaddingHeaderLen return PaddingHeaderLen
} }
return PaddingHeaderLen - uuid.Size return PaddingHeaderLen - uuid.Size
@ -245,7 +227,7 @@ func (vc *Conn) RearHeadroom() int {
} }
func (vc *Conn) NeedHandshake() bool { func (vc *Conn) NeedHandshake() bool {
return vc.needHandshake return vc.writeOnceUserUUID != nil
} }
func (vc *Conn) Upstream() any { func (vc *Conn) Upstream() any {

View File

@ -20,7 +20,7 @@ const (
commandPaddingDirect byte = 0x02 commandPaddingDirect byte = 0x02
) )
func ApplyPadding(buffer *buf.Buffer, command byte, userUUID *uuid.UUID, paddingTLS bool) { func ApplyPadding(buffer *buf.Buffer, command byte, userUUID *[]byte, paddingTLS bool) {
contentLen := int32(buffer.Len()) contentLen := int32(buffer.Len())
var paddingLen int32 var paddingLen int32
if contentLen < 900 { if contentLen < 900 {
@ -35,8 +35,9 @@ func ApplyPadding(buffer *buf.Buffer, command byte, userUUID *uuid.UUID, padding
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(paddingLen)) binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(paddingLen))
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(contentLen)) binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(contentLen))
buffer.ExtendHeader(1)[0] = command buffer.ExtendHeader(1)[0] = command
if userUUID != nil { if userUUID != nil && *userUUID != nil {
copy(buffer.ExtendHeader(uuid.Size), userUUID.Bytes()) copy(buffer.ExtendHeader(uuid.Size), *userUUID)
*userUUID = nil
} }
buffer.Extend(int(paddingLen)) buffer.Extend(int(paddingLen))

View File

@ -12,11 +12,13 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
tlsC "github.com/metacubex/mihomo/component/tls" tlsC "github.com/metacubex/mihomo/component/tls"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/transport/vless/encryption" "github.com/metacubex/mihomo/transport/vless/encryption"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
) )
var ErrNotHandshakeComplete = errors.New("tls connection not handshake complete")
var ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection") var ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection")
func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error) { func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error) {
@ -25,40 +27,72 @@ func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error
ExtendedWriter: N.NewExtendedWriter(conn), ExtendedWriter: N.NewExtendedWriter(conn),
Conn: conn, Conn: conn,
userUUID: userUUID, userUUID: userUUID,
tlsConn: tlsConn,
packetsToFilter: 6, packetsToFilter: 6,
needHandshake: true,
readProcess: true, readProcess: true,
readFilterUUID: true, readFilterUUID: true,
writeFilterApplicationData: true, writeFilterApplicationData: true,
writeOnceUserUUID: userUUID.Bytes(),
} }
var t reflect.Type var t reflect.Type
var p unsafe.Pointer var p unsafe.Pointer
switch underlying := tlsConn.(type) { var upstream any = tlsConn
case *gotls.Conn: for {
//log.Debugln("type tls") switch underlying := upstream.(type) {
c.netConn = underlying.NetConn() case *gotls.Conn:
t = reflect.TypeOf(underlying).Elem() //log.Debugln("type tls")
p = unsafe.Pointer(underlying) tlsConn = underlying
case *tlsC.Conn: c.netConn = underlying.NetConn()
//log.Debugln("type *tlsC.Conn") t = reflect.TypeOf(underlying).Elem()
c.netConn = underlying.NetConn() p = unsafe.Pointer(underlying)
t = reflect.TypeOf(underlying).Elem() break
p = unsafe.Pointer(underlying) case *tlsC.Conn:
case *tlsC.UConn: //log.Debugln("type *tlsC.Conn")
//log.Debugln("type *tlsC.UConn") tlsConn = underlying
c.netConn = underlying.NetConn() c.netConn = underlying.NetConn()
t = reflect.TypeOf(underlying.Conn).Elem() t = reflect.TypeOf(underlying).Elem()
//log.Debugln("t:%v", t) p = unsafe.Pointer(underlying)
p = unsafe.Pointer(underlying.Conn) break
case *encryption.CommonConn: case *tlsC.UConn:
//log.Debugln("type *encryption.CommonConn") //log.Debugln("type *tlsC.UConn")
c.netConn = underlying.Conn tlsConn = underlying
t = reflect.TypeOf(underlying).Elem() c.netConn = underlying.NetConn()
p = unsafe.Pointer(underlying) t = reflect.TypeOf(underlying.Conn).Elem()
default: //log.Debugln("t:%v", t)
p = unsafe.Pointer(underlying.Conn)
break
case *encryption.CommonConn:
//log.Debugln("type *encryption.CommonConn")
tlsConn = underlying
c.netConn = underlying.Conn
t = reflect.TypeOf(underlying).Elem()
p = unsafe.Pointer(underlying)
break
}
if u, ok := upstream.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { // must replaceable
break
}
if u, ok := upstream.(N.WithUpstreamReader); ok {
upstream = u.UpstreamReader()
continue
}
if u, ok := upstream.(N.WithUpstream); ok {
upstream = u.Upstream()
continue
}
}
if t == nil || p == nil {
log.Warnln("vision: not a valid supported TLS connection: %s", reflect.TypeOf(tlsConn))
return nil, fmt.Errorf(`failed to use vision, maybe "security" is not "tls" or "utls"`) return nil, fmt.Errorf(`failed to use vision, maybe "security" is not "tls" or "utls"`)
} }
if err := checkTLSVersion(tlsConn); err != nil {
if errors.Is(err, ErrNotHandshakeComplete) {
log.Warnln("vision: TLS connection not handshake complete: %s", reflect.TypeOf(tlsConn))
} else {
return nil, err
}
}
if i, ok := t.FieldByName("input"); ok { if i, ok := t.FieldByName("input"); ok {
c.input = (*bytes.Reader)(unsafe.Add(p, i.Offset)) c.input = (*bytes.Reader)(unsafe.Add(p, i.Offset))
} }
@ -68,18 +102,30 @@ func NewConn(conn net.Conn, tlsConn net.Conn, userUUID *uuid.UUID) (*Conn, error
return c, nil return c, nil
} }
func (vc *Conn) checkTLSVersion() error { func checkTLSVersion(tlsConn net.Conn) error {
switch underlying := vc.tlsConn.(type) { switch underlying := tlsConn.(type) {
case *gotls.Conn: case *gotls.Conn:
if underlying.ConnectionState().Version != gotls.VersionTLS13 { state := underlying.ConnectionState()
if !state.HandshakeComplete {
return ErrNotHandshakeComplete
}
if state.Version != gotls.VersionTLS13 {
return ErrNotTLS13 return ErrNotTLS13
} }
case *tlsC.Conn: case *tlsC.Conn:
if underlying.ConnectionState().Version != tlsC.VersionTLS13 { state := underlying.ConnectionState()
if !state.HandshakeComplete {
return ErrNotHandshakeComplete
}
if state.Version != tlsC.VersionTLS13 {
return ErrNotTLS13 return ErrNotTLS13
} }
case *tlsC.UConn: case *tlsC.UConn:
if underlying.ConnectionState().Version != tlsC.VersionTLS13 { state := underlying.ConnectionState()
if !state.HandshakeComplete {
return ErrNotHandshakeComplete
}
if state.Version != tlsC.VersionTLS13 {
return ErrNotTLS13 return ErrNotTLS13
} }
} }