diff --git a/transport/vless/encryption/client.go b/transport/vless/encryption/client.go index 8debb06f..ee73b63d 100644 --- a/transport/vless/encryption/client.go +++ b/transport/vless/encryption/client.go @@ -7,11 +7,23 @@ import ( "errors" "io" "net" + "runtime" "sync" "time" "github.com/metacubex/blake3" "github.com/metacubex/utls/mlkem" + "golang.org/x/sys/cpu" +) + +var ( + // Keep in sync with crypto/tls/cipher_suites.go. + hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ && cpu.X86.HasSSE41 && cpu.X86.HasSSSE3 + hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCTR && cpu.S390X.HasGHASH + hasGCMAsmPPC64 = runtime.GOARCH == "ppc64" || runtime.GOARCH == "ppc64le" + + HasAESGCMHardwareSupport = hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X || hasGCMAsmPPC64 ) type ClientInstance struct { @@ -65,7 +77,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { if i.NfsPKeys == nil { return nil, errors.New("uninitialized") } - c := NewCommonConn(conn) + c := NewCommonConn(conn, HasAESGCMHardwareSupport) ivAndRealysLength := 16 + i.RelaysLength pfsKeyExchangeLength := 18 + 1184 + 32 + 16 @@ -107,18 +119,18 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { lastCTR.XORKeyStream(relays[index:], i.Hash32s[j+1][:]) relays = relays[index+32:] } - nfsGCM := NewGCM(iv, nfsKey) + nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES) if i.Seconds > 0 { i.RWLock.RLock() if time.Now().Before(i.Expire) { c.Client = i c.UnitedKey = append(i.PfsKey, nfsKey...) // different unitedKey for each connection - nfsGCM.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil) - nfsGCM.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil) + nfsAEAD.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil) + nfsAEAD.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil) i.RWLock.RUnlock() c.PreWrite = clientHello[:ivAndRealysLength+18+32] - c.GCM = NewGCM(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey) + c.AEAD = NewAEAD(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey, c.UseAES) if i.XorMode == 2 { c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, iv), nil, len(c.PreWrite), 16) } @@ -128,15 +140,15 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { } pfsKeyExchange := clientHello[ivAndRealysLength : ivAndRealysLength+pfsKeyExchangeLength] - nfsGCM.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil) + nfsAEAD.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil) mlkem768DKey, _ := mlkem.GenerateKey768() x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader) pfsPublicKey := append(mlkem768DKey.EncapsulationKey().Bytes(), x25519SKey.PublicKey().Bytes()...) - nfsGCM.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil) + nfsAEAD.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil) padding := clientHello[ivAndRealysLength+pfsKeyExchangeLength:] - nfsGCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) - nfsGCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) + nfsAEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) + nfsAEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) if _, err := conn.Write(clientHello); err != nil { return nil, err @@ -147,7 +159,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil { return nil, err } - nfsGCM.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil) + nfsAEAD.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil) mlkem768Key, err := mlkem768DKey.Decapsulate(encryptedPfsPublicKey[:1088]) if err != nil { return nil, err @@ -164,14 +176,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { copy(pfsKey, mlkem768Key) copy(pfsKey[32:], x25519Key) c.UnitedKey = append(pfsKey, nfsKey...) - c.GCM = NewGCM(pfsPublicKey, c.UnitedKey) - c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1088+32], c.UnitedKey) + c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES) + c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1088+32], c.UnitedKey, c.UseAES) encryptedTicket := make([]byte, 32) if _, err := io.ReadFull(conn, encryptedTicket); err != nil { return nil, err } - if _, err := c.PeerGCM.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil { + if _, err := c.PeerAEAD.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil { return nil, err } seconds := DecodeLength(encryptedTicket) @@ -188,7 +200,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedLength); err != nil { return nil, err } - if _, err := c.PeerGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { + if _, err := c.PeerAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { return nil, err } length := DecodeLength(encryptedLength[:2]) diff --git a/transport/vless/encryption/common.go b/transport/vless/encryption/common.go index 0275a21f..07ea57c3 100644 --- a/transport/vless/encryption/common.go +++ b/transport/vless/encryption/common.go @@ -14,23 +14,26 @@ import ( "github.com/metacubex/blake3" "github.com/metacubex/randv2" + "golang.org/x/crypto/chacha20poly1305" ) type CommonConn struct { net.Conn + UseAES bool Client *ClientInstance UnitedKey []byte PreWrite []byte - GCM *GCM + AEAD *AEAD + PeerAEAD *AEAD PeerPadding []byte rawInput bytes.Buffer // PeerInBytes - PeerGCM *GCM input bytes.Reader // PeerCache } -func NewCommonConn(conn net.Conn) *CommonConn { +func NewCommonConn(conn net.Conn, useAES bool) *CommonConn { return &CommonConn{ - Conn: conn, + Conn: conn, + UseAES: useAES, } } @@ -49,12 +52,12 @@ func (c *CommonConn) Write(b []byte) (int, error) { headerAndData := outBytes[:5+len(b)+16] EncodeHeader(headerAndData, len(b)+16) max := false - if bytes.Equal(c.GCM.Nonce[:], MaxNonce) { + if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) { max = true } - c.GCM.Seal(headerAndData[:5], nil, b, headerAndData[:5]) + c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5]) if max { - c.GCM = NewGCM(headerAndData, c.UnitedKey) + c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES) } if c.PreWrite != nil { headerAndData = append(c.PreWrite, headerAndData...) @@ -71,12 +74,12 @@ func (c *CommonConn) Read(b []byte) (int, error) { if len(b) == 0 { return 0, nil } - if c.PeerGCM == nil { // client's 0-RTT + if c.PeerAEAD == nil { // client's 0-RTT serverRandom := make([]byte, 16) if _, err := io.ReadFull(c.Conn, serverRandom); err != nil { return 0, err } - c.PeerGCM = NewGCM(serverRandom, c.UnitedKey) + c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES) if xorConn, ok := c.Conn.(*XorConn); ok { xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom) } @@ -85,7 +88,7 @@ func (c *CommonConn) Read(b []byte) (int, error) { if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil { return 0, err } - if _, err := c.PeerGCM.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil { + if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil { return 0, err } c.PeerPadding = nil @@ -119,13 +122,13 @@ func (c *CommonConn) Read(b []byte) (int, error) { if len(dst) <= len(b) { dst = b[:len(dst)] // avoids another copy() } - var newGCM *GCM - if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) { - newGCM = NewGCM(append(peerHeader, peerData...), c.UnitedKey) + var newAEAD *AEAD + if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) { + newAEAD = NewAEAD(append(peerHeader, peerData...), c.UnitedKey, c.UseAES) } - _, err = c.PeerGCM.Open(dst[:0], nil, peerData, peerHeader) - if newGCM != nil { - c.PeerGCM = newGCM + _, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader) + if newAEAD != nil { + c.PeerAEAD = newAEAD } if err != nil { return 0, err @@ -137,28 +140,32 @@ func (c *CommonConn) Read(b []byte) (int, error) { return len(dst), nil } -type GCM struct { +type AEAD struct { cipher.AEAD Nonce [12]byte } -func NewGCM(ctx, key []byte) *GCM { +func NewAEAD(ctx, key []byte, useAES bool) *AEAD { k := make([]byte, 32) blake3.DeriveKey(k, string(ctx), key) - block, _ := aes.NewCipher(k) - aead, _ := cipher.NewGCM(block) - return &GCM{AEAD: aead} - //chacha20poly1305.New() + var aead cipher.AEAD + if useAES { + block, _ := aes.NewCipher(k) + aead, _ = cipher.NewGCM(block) + } else { + aead, _ = chacha20poly1305.New(k) + } + return &AEAD{AEAD: aead} } -func (a *GCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte { +func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte { if nonce == nil { nonce = IncreaseNonce(a.Nonce[:]) } return a.AEAD.Seal(dst, nonce, plaintext, additionalData) } -func (a *GCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { +func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { if nonce == nil { nonce = IncreaseNonce(a.Nonce[:]) } diff --git a/transport/vless/encryption/doc.go b/transport/vless/encryption/doc.go index d3d72e5a..a9462a83 100644 --- a/transport/vless/encryption/doc.go +++ b/transport/vless/encryption/doc.go @@ -21,4 +21,5 @@ // https://github.com/XTLS/Xray-core/commit/0199dea39988a1a1b846d0bf8598631bade40902 // https://github.com/XTLS/Xray-core/commit/fce1195b60f48ca18a953dbd5c7d991869de9a5e // https://github.com/XTLS/Xray-core/commit/b0b220985c9c1bc832665458d5fd6e0c287b67ae +// https://github.com/XTLS/Xray-core/commit/82ea7a3cc5ff23280b87e3052f0f83b04f0267fa package encryption diff --git a/transport/vless/encryption/server.go b/transport/vless/encryption/server.go index 14ffd1e5..0dc6da15 100644 --- a/transport/vless/encryption/server.go +++ b/transport/vless/encryption/server.go @@ -101,7 +101,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { if i.NfsSKeys == nil { return nil, errors.New("uninitialized") } - c := NewCommonConn(conn) + c := NewCommonConn(conn, true) ivAndRelays := make([]byte, 16+i.RelaysLength) if _, err := io.ReadFull(conn, ivAndRelays); err != nil { @@ -150,16 +150,21 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { } relays = relays[32:] } - nfsGCM := NewGCM(iv, nfsKey) + nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES) encryptedLength := make([]byte, 18) if _, err := io.ReadFull(conn, encryptedLength); err != nil { return nil, err } - if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { - return nil, err + decryptedLength := make([]byte, 2) + if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil { + c.UseAES = !c.UseAES + nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES) + if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil { + return nil, err + } } - length := DecodeLength(encryptedLength[:2]) + length := DecodeLength(decryptedLength) if length == 32 { if i.Seconds == 0 { @@ -169,7 +174,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedTicket); err != nil { return nil, err } - ticket, err := nfsGCM.Open(nil, nil, encryptedTicket, nil) + ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil) if err != nil { return nil, err } @@ -192,8 +197,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request) c.PreWrite = make([]byte, 16) rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub") - c.GCM = NewGCM(c.PreWrite, c.UnitedKey) - c.PeerGCM = NewGCM(encryptedTicket, c.UnitedKey) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client) + c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES) + c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client) if i.XorMode == 2 { c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client } @@ -207,7 +212,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil { return nil, err } - if _, err := nfsGCM.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil { + if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil { return nil, err } mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184]) @@ -229,8 +234,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { copy(pfsKey[32:], x25519Key) pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...) c.UnitedKey = append(pfsKey, nfsKey...) - c.GCM = NewGCM(pfsPublicKey, c.UnitedKey) - c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1184+32], c.UnitedKey) + c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES) + c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES) ticket := make([]byte, 16) rand.Read(ticket) copy(ticket, EncodeLength(int(i.Seconds*4/5))) @@ -239,11 +244,11 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { encryptedTicketLength := 32 paddingLength := int(randBetween(100, 1000)) serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength) - nfsGCM.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil) - c.GCM.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil) + nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil) + c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil) padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:] - c.GCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) - c.GCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) + c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil) + c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil) if _, err := conn.Write(serverHello); err != nil { return nil, err @@ -263,14 +268,14 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) { if _, err := io.ReadFull(conn, encryptedLength); err != nil { return nil, err } - if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { + if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil { return nil, err } encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2])) if _, err := io.ReadFull(conn, encryptedPadding); err != nil { return nil, err } - if _, err := nfsGCM.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil { + if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil { return nil, err }