diff --git a/component/generater/cmd.go b/component/generater/cmd.go index 2bd5b9cd..ad219880 100644 --- a/component/generater/cmd.go +++ b/component/generater/cmd.go @@ -12,7 +12,7 @@ import ( func Main(args []string) { if len(args) < 1 { - panic("Using: generate uuid/reality-keypair/wg-keypair/ech-keypair/vless-mlkem768") + panic("Using: generate uuid/reality-keypair/wg-keypair/ech-keypair/vless-mlkem768/vless-x25519") } switch args[0] { case "uuid": @@ -57,5 +57,16 @@ func Main(args []string) { } fmt.Println("Seed: " + seedBase64) fmt.Println("Client: " + clientBase64) + case "vless-x25519": + var privateKey string + if len(args) > 1 { + privateKey = args[1] + } + privateKeyBase64, passwordBase64, err := encryption.GenX25519(privateKey) + if err != nil { + panic(err) + } + fmt.Println("PrivateKey:" + privateKeyBase64) + fmt.Println("Password:" + passwordBase64) } } diff --git a/docs/config.yaml b/docs/config.yaml index 80af843b..0313bee7 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -638,8 +638,12 @@ proxies: # socks5 port: 443 uuid: uuid network: tcp - encryption: "8min-vless-mlkem768client-bas64RawURLEncoding" # 复用八分钟后协商新的 sharedKey,需小于服务端的值 - # encryption: "8min-xored-mlkem768client-bas64RawURLEncoding" + # ------------------------- + # vless encryption客户端配置: + # (只使用 1-RTT 模式 / 复用八分钟后协商新的 baseKey,周期需小于服务端的值) + # / 是只能选一个,后面是 base64RawURLEncoding,使用 mihomo generate vless-x25519 和 mihomo generate vless-mlkem768 生成,替换值时需去掉括号 + # ------------------------- + encryption: "1rtt/8min.native/divide/random.mlkem768Client.(X25519 Password).(ML-KEM-768 Client)" tls: false #可以不开启tls udp: true @@ -1359,8 +1363,12 @@ listeners: flow: xtls-rprx-vision # ws-path: "/" # 如果不为空则开启 websocket 传输层 # grpc-service-name: "GunService" # 如果不为空则开启 grpc 传输层 - # decryption: "10min-vless-mlkem768seed-bas64RawURLEncoding" # 同时允许 1-RTT 模式与十分钟复用的 0-RTT 模式, 后面base64字符串可由可由 mihomo generate vless-mlkem768 命令生成 - # decryption: "10min-xored-mlkem768seed-bas64RawURLEncoding" + # ------------------------- + # vless encryption服务端配置: + # (只允许 1-RTT 模式 / 同时允许 1-RTT 模式与十分钟复用的 0-RTT 模式;原生外观 / ECH 式 XOR / 全随机数) + # / 是只能选一个,后面是 base64RawURLEncoding,使用 mihomo generate vless-x25519 和 mihomo generate vless-mlkem768 生成,替换值时需去掉括号 + # ------------------------- + # decryption: "1rtt/10min.native/divide/random.mlkem768Seed.(X25519 PrivateKey).(ML-KEM-768 Seed)" # 下面两项如果填写则开启 tls(需要同时填写) # certificate: ./server.crt # private-key: ./server.key diff --git a/listener/inbound/vless_test.go b/listener/inbound/vless_test.go index f3fcd39c..b58c9752 100644 --- a/listener/inbound/vless_test.go +++ b/listener/inbound/vless_test.go @@ -94,34 +94,33 @@ func TestInboundVless_Encryption(t *testing.T) { t.Fatal(err) return } - t.Run("-vless-", func(t *testing.T) { - inboundOptions := inbound.VlessOption{ - Decryption: "10min-vless-mlkem768seed-" + seedBase64, - } - outboundOptions := outbound.VlessOption{ - Encryption: "8min-vless-mlkem768client-" + clientBase64, - } - testInboundVless(t, inboundOptions, outboundOptions) - t.Run("xtls-rprx-vision", func(t *testing.T) { - outboundOptions := outboundOptions - outboundOptions.Flow = "xtls-rprx-vision" + privateKeyBase64, passwordBase64, err := encryption.GenX25519("") + if err != nil { + t.Fatal(err) + return + } + var modes = []string{ + "native", + "divide", + "random", + } + for i := range modes { + mode := modes[i] + t.Run(mode, func(t *testing.T) { + inboundOptions := inbound.VlessOption{ + Decryption: "10min." + mode + ".mlkem768Seed." + privateKeyBase64 + "." + seedBase64, + } + outboundOptions := outbound.VlessOption{ + Encryption: "8min." + mode + ".mlkem768Client." + passwordBase64 + "." + clientBase64, + } testInboundVless(t, inboundOptions, outboundOptions) + t.Run("xtls-rprx-vision", func(t *testing.T) { + outboundOptions := outboundOptions + outboundOptions.Flow = "xtls-rprx-vision" + testInboundVless(t, inboundOptions, outboundOptions) + }) }) - }) - t.Run("-xored-", func(t *testing.T) { - inboundOptions := inbound.VlessOption{ - Decryption: "10min-xored-mlkem768seed-" + seedBase64, - } - outboundOptions := outbound.VlessOption{ - Encryption: "8min-xored-mlkem768client-" + clientBase64, - } - testInboundVless(t, inboundOptions, outboundOptions) - t.Run("xtls-rprx-vision", func(t *testing.T) { - outboundOptions := outboundOptions - outboundOptions.Flow = "xtls-rprx-vision" - testInboundVless(t, inboundOptions, outboundOptions) - }) - }) + } } func TestInboundVless_Wss1(t *testing.T) { diff --git a/transport/vless/encryption/client.go b/transport/vless/encryption/client.go index ecf20692..dfa4fa63 100644 --- a/transport/vless/encryption/client.go +++ b/transport/vless/encryption/client.go @@ -3,6 +3,7 @@ package encryption import ( "bytes" "crypto/cipher" + "crypto/ecdh" "crypto/rand" "errors" "fmt" @@ -40,7 +41,8 @@ type ClientInstance struct { sync.RWMutex nfsEKey *mlkem.EncapsulationKey768 hash11 [11]byte // no more capacity - xorKey []byte + xorMode uint32 + xorPKey *ecdh.PublicKey minutes time.Duration expire time.Time baseKey []byte @@ -60,22 +62,23 @@ type ClientConn struct { input bytes.Reader // peerCache } -func (i *ClientInstance) Init(nfsEKeyBytes []byte, xor uint32, minutes time.Duration) (err error) { +func (i *ClientInstance) Init(nfsEKeyBytes, xorPKeyBytes []byte, xorMode, minutes uint32) (err error) { if i.nfsEKey != nil { err = errors.New("already initialized") return } - i.nfsEKey, err = mlkem.NewEncapsulationKey768(nfsEKeyBytes) - if err != nil { + if i.nfsEKey, err = mlkem.NewEncapsulationKey768(nfsEKeyBytes); err != nil { return } hash32 := sha3.Sum256(nfsEKeyBytes) copy(i.hash11[:], hash32[:]) - if xor > 0 { - xorKey := sha3.Sum256(nfsEKeyBytes) - i.xorKey = xorKey[:] + if xorMode > 0 { + i.xorMode = xorMode + if i.xorPKey, err = ecdh.X25519().NewPublicKey(xorPKeyBytes); err != nil { + return + } } - i.minutes = minutes + i.minutes = time.Duration(minutes) * time.Minute return } @@ -83,8 +86,8 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*ClientConn, error) { if i.nfsEKey == nil { return nil, errors.New("uninitialized") } - if i.xorKey != nil { - conn = NewXorConn(conn, i.xorKey) + if i.xorMode > 0 { + conn, _ = NewXorConn(conn, i.xorMode, i.xorPKey, nil) } c := &ClientConn{Conn: conn} @@ -145,7 +148,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*ClientConn, error) { } c.baseKey = append(pfsKey, nfsKey...) - VLESS, _ := NewAead(ClientCipher, c.baseKey, encapsulatedPfsKey, encapsulatedNfsKey).Open(nil, append(i.hash11[:], ClientCipher), c.ticket[11:], pfsEKeyBytes) + VLESS, _ := NewAEAD(ClientCipher, c.baseKey, encapsulatedPfsKey, encapsulatedNfsKey).Open(nil, append(i.hash11[:], ClientCipher), c.ticket[11:], pfsEKeyBytes) if !bytes.Equal(VLESS, []byte("VLESS")) { return nil, errors.New("invalid server") } @@ -180,7 +183,7 @@ func (c *ClientConn) Write(b []byte) (int, error) { rand.Read(c.random) copy(data[5+32:], c.random) EncodeHeader(data[5+32+32:], 23, len(b)+16) - c.aead = NewAead(ClientCipher, c.baseKey, c.random, c.ticket) + c.aead = NewAEAD(ClientCipher, c.baseKey, c.random, c.ticket) c.nonce = make([]byte, 12) c.aead.Seal(data[:5+32+32+5], c.nonce, b, data[5+32+32:5+32+32+5]) } else { @@ -188,7 +191,7 @@ func (c *ClientConn) Write(b []byte) (int, error) { EncodeHeader(data, 23, len(b)+16) c.aead.Seal(data[:5], c.nonce, b, data[:5]) if bytes.Equal(c.nonce, MaxNonce) { - c.aead = NewAead(ClientCipher, c.baseKey, data[5:], data[:5]) + c.aead = NewAEAD(ClientCipher, c.baseKey, data[5:], data[:5]) } } IncreaseNonce(c.nonce) @@ -229,7 +232,7 @@ func (c *ClientConn) Read(b []byte) (int, error) { if c.random == nil { return 0, errors.New("empty c.random") } - c.peerAead = NewAead(ClientCipher, c.baseKey, peerRandomHello, c.random) + c.peerAead = NewAEAD(ClientCipher, c.baseKey, peerRandomHello, c.random) c.peerNonce = make([]byte, 12) } if c.input.Len() > 0 { @@ -252,7 +255,7 @@ func (c *ClientConn) Read(b []byte) (int, error) { } var peerAead cipher.AEAD if bytes.Equal(c.peerNonce, MaxNonce) { - peerAead = NewAead(ClientCipher, c.baseKey, peerData, h) + peerAead = NewAEAD(ClientCipher, c.baseKey, peerData, h) } _, err = c.peerAead.Open(dst[:0], c.peerNonce, peerData, h) if peerAead != nil { diff --git a/transport/vless/encryption/common.go b/transport/vless/encryption/common.go index a67a6169..d383d22e 100644 --- a/transport/vless/encryption/common.go +++ b/transport/vless/encryption/common.go @@ -73,7 +73,7 @@ func ReadAndDiscardPaddings(conn net.Conn) (h []byte, t byte, l int, err error) } } -func NewAead(c byte, secret, salt, info []byte) (aead cipher.AEAD) { +func NewAEAD(c byte, secret, salt, info []byte) (aead cipher.AEAD) { key := make([]byte, 32) hkdf.New(sha3.New256, secret, salt, info).Read(key) if c&1 == 1 { diff --git a/transport/vless/encryption/doc.go b/transport/vless/encryption/doc.go index 6f0c8f70..10cd320b 100644 --- a/transport/vless/encryption/doc.go +++ b/transport/vless/encryption/doc.go @@ -14,4 +14,5 @@ // https://github.com/XTLS/Xray-core/commit/d1fb48521271251a8c74bd64fcc2fc8700717a3b // https://github.com/XTLS/Xray-core/commit/49580705f6029648399304b816a2737f991582a8 // https://github.com/XTLS/Xray-core/commit/84835bec7d0d8555d0dd30953ed26a272de814c4 +// https://github.com/XTLS/Xray-core/commit/373558ed7abdbac3de41745cf30ec04c9adde604 package encryption diff --git a/transport/vless/encryption/factory.go b/transport/vless/encryption/factory.go index 4988dbbf..d7e4c7e0 100644 --- a/transport/vless/encryption/factory.go +++ b/transport/vless/encryption/factory.go @@ -5,7 +5,6 @@ import ( "fmt" "strconv" "strings" - "time" ) // NewClient new client from encryption string @@ -15,7 +14,7 @@ func NewClient(encryption string) (*ClientInstance, error) { case "", "none": // We will not reject empty string like xray-core does, because we need to ensure compatibility return nil, nil } - if s := strings.SplitN(encryption, "-", 4); len(s) == 4 && s[2] == "mlkem768client" { + if s := strings.Split(encryption, "."); len(s) == 5 && s[2] == "mlkem768Client" { var minutes uint32 if s[0] != "1rtt" { t := strings.TrimSuffix(s[0], "min") @@ -28,27 +27,35 @@ func NewClient(encryption string) (*ClientInstance, error) { } minutes = uint32(i) } - var xor uint32 + var xorMode uint32 switch s[1] { - case "vless": - case "xored": - xor = 1 + case "native": + case "divide": + xorMode = 1 + case "random": + xorMode = 2 default: return nil, fmt.Errorf("invaild vless encryption value: %s", encryption) } - b, err := base64.RawURLEncoding.DecodeString(s[3]) + xorPKeyBytes, err := base64.RawURLEncoding.DecodeString(s[3]) if err != nil { return nil, fmt.Errorf("invaild vless encryption value: %s", encryption) } - if len(b) == MLKEM768ClientLength { - client := &ClientInstance{} - if err = client.Init(b, xor, time.Duration(minutes)*time.Minute); err != nil { - return nil, fmt.Errorf("failed to use mlkem768seed: %w", err) - } - return client, nil - } else { + if len(xorPKeyBytes) != X25519PasswordSize { return nil, fmt.Errorf("invaild vless encryption value: %s", encryption) } + nfsEKeyBytes, err := base64.RawURLEncoding.DecodeString(s[4]) + if err != nil { + return nil, fmt.Errorf("invaild vless encryption value: %s", encryption) + } + if len(nfsEKeyBytes) != MLKEM768ClientLength { + return nil, fmt.Errorf("invaild vless encryption value: %s", encryption) + } + client := &ClientInstance{} + if err = client.Init(nfsEKeyBytes, xorPKeyBytes, xorMode, minutes); err != nil { + return nil, fmt.Errorf("failed to use mlkem768seed: %w", err) + } + return client, nil } return nil, fmt.Errorf("invaild vless encryption value: %s", encryption) } @@ -60,7 +67,7 @@ func NewServer(decryption string) (*ServerInstance, error) { case "", "none": // We will not reject empty string like xray-core does, because we need to ensure compatibility return nil, nil } - if s := strings.SplitN(decryption, "-", 4); len(s) == 4 && s[2] == "mlkem768seed" { + if s := strings.Split(decryption, "."); len(s) == 5 && s[2] == "mlkem768Seed" { var minutes uint32 if s[0] != "1rtt" { t := strings.TrimSuffix(s[0], "min") @@ -73,27 +80,35 @@ func NewServer(decryption string) (*ServerInstance, error) { } minutes = uint32(i) } - var xor uint32 + var xorMode uint32 switch s[1] { - case "vless": - case "xored": - xor = 1 + case "native": + case "divide": + xorMode = 1 + case "random": + xorMode = 2 default: return nil, fmt.Errorf("invaild vless decryption value: %s", decryption) } - b, err := base64.RawURLEncoding.DecodeString(s[3]) + xorSKeyBytes, err := base64.RawURLEncoding.DecodeString(s[3]) if err != nil { return nil, fmt.Errorf("invaild vless decryption value: %s", decryption) } - if len(b) == MLKEM768SeedLength { - server := &ServerInstance{} - if err = server.Init(b, xor, time.Duration(minutes)*time.Minute); err != nil { - return nil, fmt.Errorf("failed to use mlkem768seed: %w", err) - } - return server, nil - } else { + if len(xorSKeyBytes) != X25519PrivateKeySize { return nil, fmt.Errorf("invaild vless decryption value: %s", decryption) } + nfsDKeySeed, err := base64.RawURLEncoding.DecodeString(s[4]) + if err != nil { + return nil, fmt.Errorf("invaild vless decryption value: %s", decryption) + } + if len(nfsDKeySeed) != MLKEM768SeedLength { + return nil, fmt.Errorf("invaild vless decryption value: %s", decryption) + } + server := &ServerInstance{} + if err = server.Init(nfsDKeySeed, xorSKeyBytes, xorMode, minutes); err != nil { + return nil, fmt.Errorf("failed to use mlkem768seed: %w", err) + } + return server, nil } return nil, fmt.Errorf("invaild vless decryption value: %s", decryption) } diff --git a/transport/vless/encryption/key.go b/transport/vless/encryption/key.go index 69b52895..94b650e4 100644 --- a/transport/vless/encryption/key.go +++ b/transport/vless/encryption/key.go @@ -1,6 +1,7 @@ package encryption import ( + "crypto/ecdh" "crypto/rand" "encoding/base64" "fmt" @@ -10,16 +11,18 @@ import ( const MLKEM768SeedLength = mlkem.SeedSize const MLKEM768ClientLength = mlkem.EncapsulationKeySize768 +const X25519PasswordSize = 32 +const X25519PrivateKeySize = 32 func GenMLKEM768(seedStr string) (seedBase64, clientBase64 string, err error) { - var seed [64]byte + var seed [MLKEM768SeedLength]byte if len(seedStr) > 0 { s, _ := base64.RawURLEncoding.DecodeString(seedStr) - if len(s) != 64 { + if len(s) != MLKEM768SeedLength { err = fmt.Errorf("invalid length of ML-KEM-768 seed: %s", seedStr) return } - seed = [64]byte(s) + seed = [MLKEM768SeedLength]byte(s) } else { _, err = rand.Read(seed[:]) if err != nil { @@ -33,3 +36,35 @@ func GenMLKEM768(seedStr string) (seedBase64, clientBase64 string, err error) { clientBase64 = base64.RawURLEncoding.EncodeToString(pub.Bytes()) return } + +func GenX25519(privateKeyStr string) (privateKeyBase64, passwordBase64 string, err error) { + var privateKey [X25519PrivateKeySize]byte + if len(privateKeyStr) > 0 { + s, _ := base64.RawURLEncoding.DecodeString(privateKeyStr) + if len(s) != X25519PrivateKeySize { + err = fmt.Errorf("invalid length of X25519 private key: %s", privateKeyStr) + return + } + privateKey = [X25519PrivateKeySize]byte(s) + } else { + _, err = rand.Read(privateKey[:]) + if err != nil { + return + } + } + + // Modify random bytes using algorithm described at: + // https://cr.yp.to/ecdh.html. + privateKey[0] &= 248 + privateKey[31] &= 127 + privateKey[31] |= 64 + + key, err := ecdh.X25519().NewPrivateKey(privateKey[:]) + if err != nil { + fmt.Println(err.Error()) + return + } + privateKeyBase64 = base64.RawURLEncoding.EncodeToString(privateKey[:]) + passwordBase64 = base64.RawURLEncoding.EncodeToString(key.PublicKey().Bytes()) + return +} diff --git a/transport/vless/encryption/server.go b/transport/vless/encryption/server.go index ea72f2f5..baa8432c 100644 --- a/transport/vless/encryption/server.go +++ b/transport/vless/encryption/server.go @@ -3,6 +3,7 @@ package encryption import ( "bytes" "crypto/cipher" + "crypto/ecdh" "crypto/rand" "errors" "fmt" @@ -26,7 +27,8 @@ type ServerInstance struct { sync.RWMutex nfsDKey *mlkem.DecapsulationKey768 hash11 [11]byte // no more capacity - xorKey []byte + xorMode uint32 + xorSKey *ecdh.PrivateKey minutes time.Duration sessions map[[32]byte]*ServerSession closed bool @@ -45,23 +47,24 @@ type ServerConn struct { nonce []byte } -func (i *ServerInstance) Init(nfsDKeySeed []byte, xor uint32, minutes time.Duration) (err error) { +func (i *ServerInstance) Init(nfsDKeySeed, xorSKeyBytes []byte, xorMode, minutes uint32) (err error) { if i.nfsDKey != nil { err = errors.New("already initialized") return } - i.nfsDKey, err = mlkem.NewDecapsulationKey768(nfsDKeySeed) - if err != nil { + if i.nfsDKey, err = mlkem.NewDecapsulationKey768(nfsDKeySeed); err != nil { return } hash32 := sha3.Sum256(i.nfsDKey.EncapsulationKey().Bytes()) copy(i.hash11[:], hash32[:]) - if xor > 0 { - xorKey := sha3.Sum256(i.nfsDKey.EncapsulationKey().Bytes()) - i.xorKey = xorKey[:] + if xorMode > 0 { + i.xorMode = xorMode + if i.xorSKey, err = ecdh.X25519().NewPrivateKey(xorSKeyBytes); err != nil { + return + } } if minutes > 0 { - i.minutes = minutes + i.minutes = time.Duration(minutes) * time.Minute i.sessions = make(map[[32]byte]*ServerSession) go func() { for { @@ -95,8 +98,11 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*ServerConn, error) { if i.nfsDKey == nil { return nil, errors.New("uninitialized") } - if i.xorKey != nil { - conn = NewXorConn(conn, i.xorKey) + if i.xorMode > 0 { + var err error + if conn, err = NewXorConn(conn, i.xorMode, nil, i.xorSKey); err != nil { + return nil, err + } } c := &ServerConn{Conn: conn} @@ -167,7 +173,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*ServerConn, error) { pfsKey, encapsulatedPfsKey := pfsEKey.Encapsulate() c.baseKey = append(pfsKey, nfsKey...) - c.ticket = append(i.hash11[:], NewAead(c.cipher, c.baseKey, encapsulatedPfsKey, encapsulatedNfsKey).Seal(nil, peerClientHello[:12], []byte("VLESS"), pfsEKeyBytes)...) + c.ticket = append(i.hash11[:], NewAEAD(c.cipher, c.baseKey, encapsulatedPfsKey, encapsulatedNfsKey).Seal(nil, peerClientHello[:12], []byte("VLESS"), pfsEKeyBytes)...) paddingLen := randBetween(100, 1000) @@ -221,7 +227,7 @@ func (c *ServerConn) Read(b []byte) (int, error) { } c.peerRandom = peerTicketHello[32:] } - c.peerAead = NewAead(c.cipher, c.baseKey, c.peerRandom, c.ticket) + c.peerAead = NewAEAD(c.cipher, c.baseKey, c.peerRandom, c.ticket) c.peerNonce = make([]byte, 12) } if c.input.Len() > 0 { @@ -244,7 +250,7 @@ func (c *ServerConn) Read(b []byte) (int, error) { } var peerAead cipher.AEAD if bytes.Equal(c.peerNonce, MaxNonce) { - peerAead = NewAead(c.cipher, c.baseKey, peerData, h) + peerAead = NewAEAD(c.cipher, c.baseKey, peerData, h) } _, err = c.peerAead.Open(dst[:0], c.peerNonce, peerData, h) if peerAead != nil { @@ -280,7 +286,7 @@ func (c *ServerConn) Write(b []byte) (int, error) { EncodeHeader(data, 0, 32) rand.Read(data[5 : 5+32]) EncodeHeader(data[5+32:], 23, len(b)+16) - c.aead = NewAead(c.cipher, c.baseKey, data[5:5+32], c.peerRandom) + c.aead = NewAEAD(c.cipher, c.baseKey, data[5:5+32], c.peerRandom) c.nonce = make([]byte, 12) c.aead.Seal(data[:5+32+5], c.nonce, b, data[5+32:5+32+5]) } else { @@ -288,7 +294,7 @@ func (c *ServerConn) Write(b []byte) (int, error) { EncodeHeader(data, 23, len(b)+16) c.aead.Seal(data[:5], c.nonce, b, data[:5]) if bytes.Equal(c.nonce, MaxNonce) { - c.aead = NewAead(c.cipher, c.baseKey, data[5:], data[:5]) + c.aead = NewAEAD(c.cipher, c.baseKey, data[5:], data[:5]) } } IncreaseNonce(c.nonce) diff --git a/transport/vless/encryption/xor.go b/transport/vless/encryption/xor.go index caad12bf..8d97ff93 100644 --- a/transport/vless/encryption/xor.go +++ b/transport/vless/encryption/xor.go @@ -3,13 +3,21 @@ package encryption import ( "crypto/aes" "crypto/cipher" + "crypto/ecdh" "crypto/rand" + "errors" "io" "net" + + "golang.org/x/crypto/hkdf" + "golang.org/x/crypto/sha3" ) type XorConn struct { net.Conn + Divide bool + + head []byte key []byte ctr cipher.Stream peerCtr cipher.Stream @@ -25,8 +33,55 @@ type XorConn struct { in_skip int } -func NewXorConn(conn net.Conn, key []byte) *XorConn { - return &XorConn{Conn: conn, key: key} +func NewCTR(key, iv []byte, isServer bool) cipher.Stream { + info := "CLIENT" + if isServer { + info = "SERVER" // avoids attackers sending traffic back to the client, though the encryption layer has its own protection + } + hkdf.New(sha3.New256, key, iv, []byte(info)).Read(key) // avoids using pKey directly if attackers sent the basepoint, or whaterver they like + block, _ := aes.NewCipher(key) + return cipher.NewCTR(block, iv) +} + +func NewXorConn(conn net.Conn, mode uint32, pKey *ecdh.PublicKey, sKey *ecdh.PrivateKey) (*XorConn, error) { + if mode == 0 || (pKey == nil && sKey == nil) || (pKey != nil && sKey != nil) { + return nil, errors.New("invalid parameters") + } + c := &XorConn{ + Conn: conn, + Divide: mode == 1, + isHeader: true, + out_header: make([]byte, 0, 5), // important + in_header: make([]byte, 0, 5), // important + } + if pKey != nil { + c.head = make([]byte, 16+32) + rand.Read(c.head) + eSKey, _ := ecdh.X25519().GenerateKey(rand.Reader) + NewCTR(pKey.Bytes(), c.head[:16], false).XORKeyStream(c.head[16:], eSKey.PublicKey().Bytes()) // make X25519 public key distinguishable from random bytes + c.key, _ = eSKey.ECDH(pKey) + c.ctr = NewCTR(c.key, c.head[:16], false) + } + if sKey != nil { + peerHead := make([]byte, 16+32) + if _, err := io.ReadFull(c.Conn, peerHead); err != nil { + return nil, err + } + NewCTR(sKey.PublicKey().Bytes(), peerHead[:16], false).XORKeyStream(peerHead[16:], peerHead[16:]) // we don't use buggy elligator, because we have PSK :) + ePKey, err := ecdh.X25519().NewPublicKey(peerHead[16:]) + if err != nil { + return nil, err + } + key, err := sKey.ECDH(ePKey) + if err != nil { + return nil, err + } + c.peerCtr = NewCTR(key, peerHead[:16], false) + c.head = make([]byte, 16) + rand.Read(c.head) // make sure the server always replies random bytes even when received replays, though it is not important + c.ctr = NewCTR(key, c.head, true) // the same key links the upload & download, though the encryption layer has its own link + } + return c, nil //chacha20.NewUnauthenticatedCipher() } @@ -35,13 +90,6 @@ func (c *XorConn) Write(b []byte) (int, error) { // whole one/two records return 0, nil } if !c.out_after0 { - var iv []byte - if c.ctr == nil { - block, _ := aes.NewCipher(c.key) - iv = make([]byte, 16) - rand.Read(iv) - c.ctr = cipher.NewCTR(block, iv) - } t, l, _ := DecodeHeader(b) if t == 23 { // single 23 l = 5 @@ -49,20 +97,24 @@ func (c *XorConn) Write(b []byte) (int, error) { // whole one/two records l += 10 if t == 0 { c.out_after0 = true - c.out_header = make([]byte, 0, 5) // important + if c.Divide { + l -= 5 + } } } c.ctr.XORKeyStream(b[:l], b[:l]) // caller MUST discard b - if iv != nil { - b = append(iv, b...) + l = len(b) + if c.head != nil { + b = append(c.head, b...) + c.head = nil } if _, err := c.Conn.Write(b); err != nil { return 0, err } - if iv != nil { - b = b[16:] // for len(b) - } - return len(b), nil + return l, nil + } + if c.Divide { + return c.Conn.Write(b) } for p := b; ; { // for XTLS if len(p) <= c.out_skip { @@ -93,14 +145,12 @@ func (c *XorConn) Read(b []byte) (int, error) { // 5-bytes, data, 5-bytes... return 0, nil } if !c.in_after0 || !c.isHeader { - if c.peerCtr == nil { + if c.peerCtr == nil { // for client peerIv := make([]byte, 16) if _, err := io.ReadFull(c.Conn, peerIv); err != nil { return 0, err } - block, _ := aes.NewCipher(c.key) - c.peerCtr = cipher.NewCTR(block, peerIv) - c.isHeader = true + c.peerCtr = NewCTR(c.key, peerIv, true) } if _, err := io.ReadFull(c.Conn, b); err != nil { return 0, err @@ -117,7 +167,6 @@ func (c *XorConn) Read(b []byte) (int, error) { // 5-bytes, data, 5-bytes... c.isHeader = false if t == 0 { c.in_after0 = true - c.in_header = make([]byte, 0, 5) // important } } } else { @@ -125,6 +174,9 @@ func (c *XorConn) Read(b []byte) (int, error) { // 5-bytes, data, 5-bytes... } return len(b), nil } + if c.Divide { + return c.Conn.Read(b) + } n, err := c.Conn.Read(b) for p := b[:n]; ; { // for XTLS if len(p) <= c.in_skip {