diff --git a/component/tls/reality.go b/component/tls/reality.go index c315e527..88445a1e 100644 --- a/component/tls/reality.go +++ b/component/tls/reality.go @@ -36,13 +36,13 @@ type RealityConfig struct { SupportX25519MLKEM768 bool } -func GetRealityConn(ctx context.Context, conn net.Conn, fingerprint UClientHelloID, tlsConfig *Config, realityConfig *RealityConfig) (net.Conn, error) { +func GetRealityConn(ctx context.Context, conn net.Conn, fingerprint UClientHelloID, serverName string, realityConfig *RealityConfig) (net.Conn, error) { for retry := 0; ; retry++ { verifier := &realityVerifier{ - serverName: tlsConfig.ServerName, + serverName: serverName, } uConfig := &utls.Config{ - ServerName: tlsConfig.ServerName, + ServerName: serverName, InsecureSkipVerify: true, SessionTicketsDisabled: true, VerifyPeerCertificate: verifier.VerifyPeerCertificate, diff --git a/transport/gun/gun.go b/transport/gun/gun.go index 4067ce70..35639c12 100644 --- a/transport/gun/gun.go +++ b/transport/gun/gun.go @@ -259,14 +259,13 @@ func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, clientFingerprint stri } if clientFingerprint, ok := tlsC.GetFingerprint(clientFingerprint); ok { - tlsConfig := tlsC.UConfig(cfg) - err := echConfig.ClientHandleUTLS(ctx, tlsConfig) - if err != nil { - pconn.Close() - return nil, err - } - if realityConfig == nil { + tlsConfig := tlsC.UConfig(cfg) + err := echConfig.ClientHandleUTLS(ctx, tlsConfig) + if err != nil { + pconn.Close() + return nil, err + } tlsConn := tlsC.UClient(pconn, tlsConfig, clientFingerprint) if err := tlsConn.HandshakeContext(ctx); err != nil { pconn.Close() @@ -279,7 +278,7 @@ func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, clientFingerprint stri } return tlsConn, nil } else { - realityConn, err := tlsC.GetRealityConn(ctx, pconn, clientFingerprint, tlsConfig, realityConfig) + realityConn, err := tlsC.GetRealityConn(ctx, pconn, clientFingerprint, cfg.ServerName, realityConfig) if err != nil { pconn.Close() return nil, err @@ -296,25 +295,10 @@ func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, clientFingerprint stri return nil, errors.New("REALITY is based on uTLS, please set a client-fingerprint") } - if echConfig != nil { - tlsConfig := tlsC.UConfig(cfg) - err := echConfig.ClientHandleUTLS(ctx, tlsConfig) - if err != nil { - pconn.Close() - return nil, err - } - - conn := tlsC.Client(pconn, tlsConfig) - if err := conn.HandshakeContext(ctx); err != nil { - pconn.Close() - return nil, err - } - state := conn.ConnectionState() - if p := state.NegotiatedProtocol; p != http.Http2NextProtoTLS { - conn.Close() - return nil, fmt.Errorf("http2: unexpected ALPN protocol %s, want %s", p, http.Http2NextProtoTLS) - } - return conn, nil + err = echConfig.ClientHandle(ctx, cfg) + if err != nil { + pconn.Close() + return nil, err } conn := tls.Client(pconn, cfg) diff --git a/transport/vmess/tls.go b/transport/vmess/tls.go index c304d95f..4c4c50ec 100644 --- a/transport/vmess/tls.go +++ b/transport/vmess/tls.go @@ -44,13 +44,12 @@ func StreamTLSConn(ctx context.Context, conn net.Conn, cfg *TLSConfig) (net.Conn } if clientFingerprint, ok := tlsC.GetFingerprint(cfg.ClientFingerprint); ok { - tlsConfig := tlsC.UConfig(tlsConfig) - err = cfg.ECH.ClientHandleUTLS(ctx, tlsConfig) - if err != nil { - return nil, err - } - if cfg.Reality == nil { + tlsConfig := tlsC.UConfig(tlsConfig) + err = cfg.ECH.ClientHandleUTLS(ctx, tlsConfig) + if err != nil { + return nil, err + } tlsConn := tlsC.UClient(conn, tlsConfig, clientFingerprint) err = tlsConn.HandshakeContext(ctx) if err != nil { @@ -58,24 +57,16 @@ func StreamTLSConn(ctx context.Context, conn net.Conn, cfg *TLSConfig) (net.Conn } return tlsConn, nil } else { - return tlsC.GetRealityConn(ctx, conn, clientFingerprint, tlsConfig, cfg.Reality) + return tlsC.GetRealityConn(ctx, conn, clientFingerprint, tlsConfig.ServerName, cfg.Reality) } } if cfg.Reality != nil { return nil, errors.New("REALITY is based on uTLS, please set a client-fingerprint") } - if cfg.ECH != nil { - tlsConfig := tlsC.UConfig(tlsConfig) - err = cfg.ECH.ClientHandleUTLS(ctx, tlsConfig) - if err != nil { - return nil, err - } - - tlsConn := tlsC.Client(conn, tlsConfig) - - err = tlsConn.HandshakeContext(ctx) - return tlsConn, err + err = cfg.ECH.ClientHandle(ctx, tlsConfig) + if err != nil { + return nil, err } tlsConn := tls.Client(conn, tlsConfig) diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index e73540d2..0ee1b3ab 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -370,17 +370,11 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, return nil, err } conn = tlsConn - } else if c.ECHConfig != nil { - tlsConfig := tlsC.UConfig(config) - err = c.ECHConfig.ClientHandleUTLS(ctx, tlsConfig) + } else { + err = c.ECHConfig.ClientHandle(ctx, config) if err != nil { return nil, err } - tlsConn := tlsC.Client(conn, tlsConfig) - - err = tlsConn.HandshakeContext(ctx) - conn = tlsConn - } else { tlsConn := tls.Client(conn, config) err = tlsConn.HandshakeContext(ctx) if err != nil {