chore: sync vless encryption code

This commit is contained in:
wwqgtxx 2025-08-28 08:26:46 +08:00
parent 0ced98da4d
commit cdd02a90c3
4 changed files with 35 additions and 34 deletions

View File

@ -65,7 +65,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if i.NfsPKeys == nil { if i.NfsPKeys == nil {
return nil, errors.New("uninitialized") return nil, errors.New("uninitialized")
} }
c := &CommonConn{Conn: conn} c := NewCommonConn(conn)
ivAndRealysLength := 16 + i.RelaysLength ivAndRealysLength := 16 + i.RelaysLength
pfsKeyExchangeLength := 18 + 1184 + 32 + 16 pfsKeyExchangeLength := 18 + 1184 + 32 + 16

View File

@ -11,11 +11,18 @@ import (
"math/big" "math/big"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/metacubex/blake3" "github.com/metacubex/blake3"
) )
var OutBytesPool = sync.Pool{
New: func() any {
return make([]byte, 5+8192+16)
},
}
type CommonConn struct { type CommonConn struct {
net.Conn net.Conn
Client *ClientInstance Client *ClientInstance
@ -23,36 +30,45 @@ type CommonConn struct {
PreWrite []byte PreWrite []byte
GCM *GCM GCM *GCM
PeerPadding []byte PeerPadding []byte
PeerInBytes []byte
PeerGCM *GCM PeerGCM *GCM
input bytes.Reader // PeerCache input bytes.Reader // PeerCache
} }
func NewCommonConn(conn net.Conn) *CommonConn {
return &CommonConn{
Conn: conn,
PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading
}
}
func (c *CommonConn) Write(b []byte) (int, error) { func (c *CommonConn) Write(b []byte) (int, error) {
if len(b) == 0 { if len(b) == 0 {
return 0, nil return 0, nil
} }
var data []byte outBytes := OutBytesPool.Get().([]byte)
defer OutBytesPool.Put(outBytes)
for n := 0; n < len(b); { for n := 0; n < len(b); {
b := b[n:] b := b[n:]
if len(b) > 8192 { if len(b) > 8192 {
b = b[:8192] // for avoiding another copy() in peer's Read() b = b[:8192] // for avoiding another copy() in peer's Read()
} }
n += len(b) n += len(b)
data = make([]byte, 5+len(b)+16) headerAndData := outBytes[:5+len(b)+16]
EncodeHeader(data, len(b)+16) EncodeHeader(headerAndData, len(b)+16)
max := false max := false
if bytes.Equal(c.GCM.Nonce[:], MaxNonce) { if bytes.Equal(c.GCM.Nonce[:], MaxNonce) {
max = true max = true
} }
c.GCM.Seal(data[:5], nil, b, data[:5]) c.GCM.Seal(headerAndData[:5], nil, b, headerAndData[:5])
if max { if max {
c.GCM = NewGCM(data[5:], c.UnitedKey) c.GCM = NewGCM(headerAndData, c.UnitedKey)
} }
if c.PreWrite != nil { if c.PreWrite != nil {
data = append(c.PreWrite, data...) headerAndData = append(c.PreWrite, headerAndData...)
c.PreWrite = nil c.PreWrite = nil
} }
if _, err := c.Conn.Write(data); err != nil { if _, err := c.Conn.Write(headerAndData); err != nil {
return 0, err return 0, err
} }
} }
@ -85,9 +101,13 @@ func (c *CommonConn) Read(b []byte) (int, error) {
if c.input.Len() > 0 { if c.input.Len() > 0 {
return c.input.Read(b) return c.input.Read(b)
} }
h, l, err := ReadAndDecodeHeader(c.Conn) // l: 17~17000 peerHeader := c.PeerInBytes[:5]
if _, err := io.ReadFull(c.Conn, peerHeader); err != nil {
return 0, err
}
l, err := DecodeHeader(c.PeerInBytes[:5]) // l: 17~17000
if err != nil { if err != nil {
if c.Client != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // client's 0-RTT if c.Client != nil && strings.Contains(err.Error(), "invalid header: ") { // client's 0-RTT
c.Client.RWLock.Lock() c.Client.RWLock.Lock()
if bytes.HasPrefix(c.UnitedKey, c.Client.PfsKey) { if bytes.HasPrefix(c.UnitedKey, c.Client.PfsKey) {
c.Client.Expire = time.Now() // expired c.Client.Expire = time.Now() // expired
@ -98,7 +118,7 @@ func (c *CommonConn) Read(b []byte) (int, error) {
return 0, err return 0, err
} }
c.Client = nil c.Client = nil
peerData := make([]byte, l) peerData := c.PeerInBytes[5 : 5+l]
if _, err := io.ReadFull(c.Conn, peerData); err != nil { if _, err := io.ReadFull(c.Conn, peerData); err != nil {
return 0, err return 0, err
} }
@ -108,9 +128,9 @@ func (c *CommonConn) Read(b []byte) (int, error) {
} }
var newGCM *GCM var newGCM *GCM
if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) { if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) {
newGCM = NewGCM(peerData, c.UnitedKey) newGCM = NewGCM(c.PeerInBytes[:5+l], c.UnitedKey)
} }
_, err = c.PeerGCM.Open(dst[:0], nil, peerData, h) _, err = c.PeerGCM.Open(dst[:0], nil, peerData, peerHeader)
if newGCM != nil { if newGCM != nil {
c.PeerGCM = newGCM c.PeerGCM = newGCM
} }
@ -191,26 +211,6 @@ func DecodeHeader(h []byte) (l int, err error) {
return return
} }
func ReadAndDecodeHeader(conn net.Conn) (h []byte, l int, err error) {
h = make([]byte, 5)
if _, err = io.ReadFull(conn, h); err != nil {
return
}
l, err = DecodeHeader(h)
return
}
func ReadAndDiscardPaddings(conn net.Conn) (h []byte, l int, err error) {
for {
if h, l, err = ReadAndDecodeHeader(conn); err != nil {
return
}
if _, err = io.ReadFull(conn, make([]byte, l)); err != nil {
return
}
}
}
func randBetween(from int64, to int64) int64 { func randBetween(from int64, to int64) int64 {
if from == to { if from == to {
return from return from

View File

@ -20,4 +20,5 @@
// https://github.com/XTLS/Xray-core/commit/ad7140641c44239c9dcdc3d7215ea639b1f0841c // https://github.com/XTLS/Xray-core/commit/ad7140641c44239c9dcdc3d7215ea639b1f0841c
// https://github.com/XTLS/Xray-core/commit/0199dea39988a1a1b846d0bf8598631bade40902 // https://github.com/XTLS/Xray-core/commit/0199dea39988a1a1b846d0bf8598631bade40902
// https://github.com/XTLS/Xray-core/commit/fce1195b60f48ca18a953dbd5c7d991869de9a5e // https://github.com/XTLS/Xray-core/commit/fce1195b60f48ca18a953dbd5c7d991869de9a5e
// https://github.com/XTLS/Xray-core/commit/b0b220985c9c1bc832665458d5fd6e0c287b67ae
package encryption package encryption

View File

@ -101,7 +101,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if i.NfsSKeys == nil { if i.NfsSKeys == nil {
return nil, errors.New("uninitialized") return nil, errors.New("uninitialized")
} }
c := &CommonConn{Conn: conn} c := NewCommonConn(conn)
ivAndRelays := make([]byte, 16+i.RelaysLength) ivAndRelays := make([]byte, 16+i.RelaysLength)
if _, err := io.ReadFull(conn, ivAndRelays); err != nil { if _, err := io.ReadFull(conn, ivAndRelays); err != nil {