diff --git a/transport/vless/encryption/client.go b/transport/vless/encryption/client.go index 465f1626..8debb06f 100644 --- a/transport/vless/encryption/client.go +++ b/transport/vless/encryption/client.go @@ -65,7 +65,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { if i.NfsPKeys == nil { return nil, errors.New("uninitialized") } - c := &CommonConn{Conn: conn} + c := NewCommonConn(conn) ivAndRealysLength := 16 + i.RelaysLength pfsKeyExchangeLength := 18 + 1184 + 32 + 16 diff --git a/transport/vless/encryption/common.go b/transport/vless/encryption/common.go index 1b2aa182..53c9c51d 100644 --- a/transport/vless/encryption/common.go +++ b/transport/vless/encryption/common.go @@ -11,11 +11,18 @@ import ( "math/big" "net" "strings" + "sync" "time" "github.com/metacubex/blake3" ) +var OutBytesPool = sync.Pool{ + New: func() any { + return make([]byte, 5+8192+16) + }, +} + type CommonConn struct { net.Conn Client *ClientInstance @@ -23,36 +30,45 @@ type CommonConn struct { PreWrite []byte GCM *GCM PeerPadding []byte + PeerInBytes []byte PeerGCM *GCM input bytes.Reader // PeerCache } +func NewCommonConn(conn net.Conn) *CommonConn { + return &CommonConn{ + Conn: conn, + PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading + } +} + func (c *CommonConn) Write(b []byte) (int, error) { if len(b) == 0 { return 0, nil } - var data []byte + outBytes := OutBytesPool.Get().([]byte) + defer OutBytesPool.Put(outBytes) for n := 0; n < len(b); { b := b[n:] if len(b) > 8192 { b = b[:8192] // for avoiding another copy() in peer's Read() } n += len(b) - data = make([]byte, 5+len(b)+16) - EncodeHeader(data, len(b)+16) + headerAndData := outBytes[:5+len(b)+16] + EncodeHeader(headerAndData, len(b)+16) max := false if bytes.Equal(c.GCM.Nonce[:], MaxNonce) { max = true } - c.GCM.Seal(data[:5], nil, b, data[:5]) + c.GCM.Seal(headerAndData[:5], nil, b, headerAndData[:5]) if max { - c.GCM = NewGCM(data[5:], c.UnitedKey) + c.GCM = NewGCM(headerAndData, c.UnitedKey) } if c.PreWrite != nil { - data = append(c.PreWrite, data...) + headerAndData = append(c.PreWrite, headerAndData...) c.PreWrite = nil } - if _, err := c.Conn.Write(data); err != nil { + if _, err := c.Conn.Write(headerAndData); err != nil { return 0, err } } @@ -85,9 +101,13 @@ func (c *CommonConn) Read(b []byte) (int, error) { if c.input.Len() > 0 { return c.input.Read(b) } - h, l, err := ReadAndDecodeHeader(c.Conn) // l: 17~17000 + peerHeader := c.PeerInBytes[:5] + if _, err := io.ReadFull(c.Conn, peerHeader); err != nil { + return 0, err + } + l, err := DecodeHeader(c.PeerInBytes[:5]) // l: 17~17000 if err != nil { - if c.Client != nil && strings.HasPrefix(err.Error(), "invalid header: ") { // client's 0-RTT + if c.Client != nil && strings.Contains(err.Error(), "invalid header: ") { // client's 0-RTT c.Client.RWLock.Lock() if bytes.HasPrefix(c.UnitedKey, c.Client.PfsKey) { c.Client.Expire = time.Now() // expired @@ -98,7 +118,7 @@ func (c *CommonConn) Read(b []byte) (int, error) { return 0, err } c.Client = nil - peerData := make([]byte, l) + peerData := c.PeerInBytes[5 : 5+l] if _, err := io.ReadFull(c.Conn, peerData); err != nil { return 0, err } @@ -108,9 +128,9 @@ func (c *CommonConn) Read(b []byte) (int, error) { } var newGCM *GCM if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) { - newGCM = NewGCM(peerData, c.UnitedKey) + newGCM = NewGCM(c.PeerInBytes[:5+l], c.UnitedKey) } - _, err = c.PeerGCM.Open(dst[:0], nil, peerData, h) + _, err = c.PeerGCM.Open(dst[:0], nil, peerData, peerHeader) if newGCM != nil { c.PeerGCM = newGCM } @@ -191,26 +211,6 @@ func DecodeHeader(h []byte) (l int, err error) { return } -func ReadAndDecodeHeader(conn net.Conn) (h []byte, l int, err error) { - h = make([]byte, 5) - if _, err = io.ReadFull(conn, h); err != nil { - return - } - l, err = DecodeHeader(h) - return -} - -func ReadAndDiscardPaddings(conn net.Conn) (h []byte, l int, err error) { - for { - if h, l, err = ReadAndDecodeHeader(conn); err != nil { - return - } - if _, err = io.ReadFull(conn, make([]byte, l)); err != nil { - return - } - } -} - func randBetween(from int64, to int64) int64 { if from == to { return from diff --git a/transport/vless/encryption/doc.go b/transport/vless/encryption/doc.go index 177839c6..d3d72e5a 100644 --- a/transport/vless/encryption/doc.go +++ b/transport/vless/encryption/doc.go @@ -20,4 +20,5 @@ // https://github.com/XTLS/Xray-core/commit/ad7140641c44239c9dcdc3d7215ea639b1f0841c // https://github.com/XTLS/Xray-core/commit/0199dea39988a1a1b846d0bf8598631bade40902 // https://github.com/XTLS/Xray-core/commit/fce1195b60f48ca18a953dbd5c7d991869de9a5e +// https://github.com/XTLS/Xray-core/commit/b0b220985c9c1bc832665458d5fd6e0c287b67ae package encryption diff --git a/transport/vless/encryption/server.go b/transport/vless/encryption/server.go index 36ce399a..14ffd1e5 100644 --- a/transport/vless/encryption/server.go +++ b/transport/vless/encryption/server.go @@ -101,7 +101,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { if i.NfsSKeys == nil { return nil, errors.New("uninitialized") } - c := &CommonConn{Conn: conn} + c := NewCommonConn(conn) ivAndRelays := make([]byte, 16+i.RelaysLength) if _, err := io.ReadFull(conn, ivAndRelays); err != nil {