diff --git a/transport/vless/encryption/client.go b/transport/vless/encryption/client.go index 4f3b61fc..c5a3df98 100644 --- a/transport/vless/encryption/client.go +++ b/transport/vless/encryption/client.go @@ -10,6 +10,7 @@ import ( "io" "net" "runtime" + "strings" "sync" "time" @@ -37,13 +38,12 @@ func init() { type ClientInstance struct { sync.RWMutex - nfsEKey *mlkem.EncapsulationKey768 - nfsEKeySha256 [32]byte - xor uint32 - minutes time.Duration - expire time.Time - baseKey []byte - ticket []byte + nfsEKey *mlkem.EncapsulationKey768 + xorKey []byte + minutes time.Duration + expire time.Time + baseKey []byte + ticket []byte } type ClientConn struct { @@ -60,10 +60,17 @@ type ClientConn struct { } func (i *ClientInstance) Init(nfsEKeyBytes []byte, xor uint32, minutes time.Duration) (err error) { + if i.nfsEKey != nil { + err = errors.New("already initialized") + return + } i.nfsEKey, err = mlkem.NewEncapsulationKey768(nfsEKeyBytes) + if err != nil { + return + } if xor > 0 { - i.nfsEKeySha256 = sha256.Sum256(nfsEKeyBytes) - i.xor = xor + xorKey := sha256.Sum256(nfsEKeyBytes) + i.xorKey = xorKey[:] } i.minutes = minutes return @@ -73,8 +80,8 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) { if i.nfsEKey == nil { return nil, errors.New("uninitialized") } - if i.xor > 0 { - conn = NewXorConn(conn, i.nfsEKeySha256[:]) + if i.xorKey != nil { + conn = NewXorConn(conn, i.xorKey) } c := &ClientConn{Conn: conn} @@ -110,14 +117,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) { } // client can send more padding / NFS AEAD messages if needed - _, t, l, err := ReadAndDecodeHeader(c.Conn) + _, t, l, err := ReadAndDiscardPaddings(c.Conn) if err != nil { return nil, err } + if t != 1 { return nil, fmt.Errorf("unexpected type %v, expect random hello", t) } - peerRandomHello := make([]byte, 1088+21) if l != len(peerRandomHello) { return nil, fmt.Errorf("unexpected length %v for random hello", l) @@ -194,27 +201,9 @@ func (c *ClientConn) Read(b []byte) (int, error) { return 0, nil } if c.peerAead == nil { - var t byte - var l int - var err error - if c.instance == nil { // from 1-RTT - for { - if _, t, l, err = ReadAndDecodeHeader(c.Conn); err != nil { - return 0, err - } - if t != 23 { - break - } - if _, err := io.ReadFull(c.Conn, make([]byte, l)); err != nil { - return 0, err - } - } - } else { - h := make([]byte, 5) - if _, err := io.ReadFull(c.Conn, h); err != nil { - return 0, err - } - if t, l, err = DecodeHeader(h); err != nil { + _, t, l, err := ReadAndDiscardPaddings(c.Conn) + if err != nil { + if c.instance != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // from 0-RTT c.instance.Lock() if bytes.Equal(c.ticket, c.instance.ticket) { c.instance.expire = time.Now() // expired @@ -222,6 +211,7 @@ func (c *ClientConn) Read(b []byte) (int, error) { c.instance.Unlock() return 0, errors.New("new handshake needed") } + return 0, err } if t != 0 { return 0, fmt.Errorf("unexpected type %v, expect server random", t) diff --git a/transport/vless/encryption/common.go b/transport/vless/encryption/common.go index cd4a93c5..2a1a5a00 100644 --- a/transport/vless/encryption/common.go +++ b/transport/vless/encryption/common.go @@ -45,10 +45,10 @@ func DecodeHeader(h []byte) (t byte, l int, err error) { } else if h[0] == 1 && h[1] == 1 && h[2] == 1 { t = 1 } else { - h = nil + l = 0 } - if h == nil || l < 17 || l > 17000 { // TODO: TLSv1.3 max length - err = fmt.Errorf("invalid header: %v", h[:5]) + if l < 17 || l > 17000 { // TODO: TLSv1.3 max length + err = fmt.Errorf("invalid header: %v", h[:5]) // relied by client's Read() } return } @@ -62,6 +62,17 @@ func ReadAndDecodeHeader(conn net.Conn) (h []byte, t byte, l int, err error) { return } +func ReadAndDiscardPaddings(conn net.Conn) (h []byte, t byte, l int, err error) { + for { + if h, t, l, err = ReadAndDecodeHeader(conn); err != nil || t != 23 { + return + } + if _, err = io.ReadFull(conn, make([]byte, l)); err != nil { + return + } + } +} + func NewAead(c byte, secret, salt, info []byte) (aead cipher.AEAD) { key := make([]byte, 32) hkdf.New(sha256.New, secret, salt, info).Read(key) diff --git a/transport/vless/encryption/doc.go b/transport/vless/encryption/doc.go index fa36a8cd..43fe2c44 100644 --- a/transport/vless/encryption/doc.go +++ b/transport/vless/encryption/doc.go @@ -9,5 +9,5 @@ // https://github.com/XTLS/Xray-core/commit/1720be168fa069332c418503d30341fc6e01df7f // https://github.com/XTLS/Xray-core/commit/0fd7691d6b28e05922d7a5a9313d97745a51ea63 // https://github.com/XTLS/Xray-core/commit/09cc92c61d9067e0d65c1cae9124664ecfc78f43 -// https://github.com/XTLS/Xray-core/commit/7f778a4e2f123dc03fe57fbf24da59dcaf270f8a +// https://github.com/XTLS/Xray-core/commit/2807ee432a1fbeb301815647189eacd650b12a8b package encryption diff --git a/transport/vless/encryption/server.go b/transport/vless/encryption/server.go index 4619a10c..0e2f4277 100644 --- a/transport/vless/encryption/server.go +++ b/transport/vless/encryption/server.go @@ -24,12 +24,11 @@ type ServerSession struct { type ServerInstance struct { sync.RWMutex - nfsDKey *mlkem.DecapsulationKey768 - nfsEKeySha256 [32]byte - xor uint32 - minutes time.Duration - sessions map[[21]byte]*ServerSession - closed bool + nfsDKey *mlkem.DecapsulationKey768 + xorKey []byte + minutes time.Duration + sessions map[[21]byte]*ServerSession + closed bool } type ServerConn struct { @@ -46,10 +45,17 @@ type ServerConn struct { } func (i *ServerInstance) Init(nfsDKeySeed []byte, xor uint32, minutes time.Duration) (err error) { + if i.nfsDKey != nil { + err = errors.New("already initialized") + return + } i.nfsDKey, err = mlkem.NewDecapsulationKey768(nfsDKeySeed) + if err != nil { + return + } if xor > 0 { - i.nfsEKeySha256 = sha256.Sum256(i.nfsDKey.EncapsulationKey().Bytes()) - i.xor = xor + xorKey := sha256.Sum256(i.nfsDKey.EncapsulationKey().Bytes()) + i.xorKey = xorKey[:] } if minutes > 0 { i.minutes = minutes @@ -86,18 +92,15 @@ func (i *ServerInstance) Handshake(conn net.Conn) (net.Conn, error) { if i.nfsDKey == nil { return nil, errors.New("uninitialized") } - if i.xor > 0 { - conn = NewXorConn(conn, i.nfsEKeySha256[:]) + if i.xorKey != nil { + conn = NewXorConn(conn, i.xorKey) } c := &ServerConn{Conn: conn} - _, t, l, err := ReadAndDecodeHeader(c.Conn) + _, t, l, err := ReadAndDiscardPaddings(c.Conn) if err != nil { return nil, err } - if t == 23 { - return nil, errors.New("unexpected data") - } if t == 0 { if i.minutes == 0 { @@ -187,19 +190,9 @@ func (c *ServerConn) Read(b []byte) (int, error) { } if c.peerAead == nil { if c.peerRandom == nil { // from 1-RTT - var t byte - var l int - var err error - for { - if _, t, l, err = ReadAndDecodeHeader(c.Conn); err != nil { - return 0, err - } - if t != 23 { - break - } - if _, err := io.ReadFull(c.Conn, make([]byte, l)); err != nil { - return 0, err - } + _, t, l, err := ReadAndDiscardPaddings(c.Conn) + if err != nil { + return 0, err } if t != 0 { return 0, fmt.Errorf("unexpected type %v, expect ticket hello", t) diff --git a/transport/vless/encryption/xor.go b/transport/vless/encryption/xor.go index 696702bc..bbe489ef 100644 --- a/transport/vless/encryption/xor.go +++ b/transport/vless/encryption/xor.go @@ -18,7 +18,7 @@ type XorConn struct { } func NewXorConn(conn net.Conn, key []byte) *XorConn { - return &XorConn{Conn: conn, key: key[:16]} + return &XorConn{Conn: conn, key: key} //chacha20.NewUnauthenticatedCipher() }