From cbcacdbb8cd5706b60bc2308c3bcfbe2daa51ab6 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 19 Dec 2025 12:08:44 +0800 Subject: [PATCH] chore: using tls.Config.GetCertificate/GetClientCertificate to load TLS certificates --- component/ca/config.go | 7 ++-- component/ca/keypair.go | 19 ++++++----- component/tls/utls.go | 54 ++++++++++++++++++++++++++++++- hub/route/server.go | 6 ++-- listener/anytls/server.go | 8 +++-- listener/http/server.go | 10 +++--- listener/mixed/mixed.go | 10 +++--- listener/sing_hysteria2/server.go | 12 ++++--- listener/sing_vless/server.go | 10 +++--- listener/sing_vmess/server.go | 10 +++--- listener/socks/tcp.go | 10 +++--- listener/trojan/server.go | 10 +++--- listener/tuic/server.go | 12 ++++--- 13 files changed, 127 insertions(+), 51 deletions(-) diff --git a/component/ca/config.go b/component/ca/config.go index 9cc8839f..c097ca25 100644 --- a/component/ca/config.go +++ b/component/ca/config.go @@ -107,12 +107,13 @@ func GetTLSConfig(opt Option) (tlsConfig *tls.Config, err error) { } if len(opt.Certificate) > 0 || len(opt.PrivateKey) > 0 { - var cert tls.Certificate - cert, err = LoadTLSKeyPair(opt.Certificate, opt.PrivateKey, C.Path) + certLoader, err := NewTLSKeyPairLoader(opt.Certificate, opt.PrivateKey, C.Path) if err != nil { return nil, err } - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return certLoader() + } } return tlsConfig, nil } diff --git a/component/ca/keypair.go b/component/ca/keypair.go index 13b38dc1..7f869afb 100644 --- a/component/ca/keypair.go +++ b/component/ca/keypair.go @@ -23,24 +23,25 @@ type Path interface { ErrNotSafePath(path string) error } -// LoadTLSKeyPair loads a TLS key pair from the provided certificate and private key data or file paths, supporting fallback resolution. -// Returns a tls.Certificate and an error, where the error indicates issues during parsing or file loading. +// NewTLSKeyPairLoader creates a loader function for TLS key pairs from the provided certificate and private key data or file paths. // If both certificate and privateKey are empty, generates a random TLS RSA key pair. // Accepts a Path interface for resolving file paths when necessary. -func LoadTLSKeyPair(certificate, privateKey string, path Path) (tls.Certificate, error) { +func NewTLSKeyPairLoader(certificate, privateKey string, path Path) (func() (*tls.Certificate, error), error) { if certificate == "" && privateKey == "" { var err error certificate, privateKey, _, err = NewRandomTLSKeyPair(KeyPairTypeRSA) if err != nil { - return tls.Certificate{}, err + return nil, err } } cert, painTextErr := tls.X509KeyPair([]byte(certificate), []byte(privateKey)) if painTextErr == nil { - return cert, nil + return func() (*tls.Certificate, error) { + return &cert, nil + }, nil } if path == nil { - return tls.Certificate{}, painTextErr + return nil, painTextErr } certificate = path.Resolve(certificate) @@ -54,9 +55,11 @@ func LoadTLSKeyPair(certificate, privateKey string, path Path) (tls.Certificate, cert, loadErr = tls.LoadX509KeyPair(certificate, privateKey) } if loadErr != nil { - return tls.Certificate{}, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error()) + return nil, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error()) } - return cert, nil + return func() (*tls.Certificate, error) { + return &cert, nil + }, nil } func LoadCertificates(certificate string, path Path) (*x509.CertPool, error) { diff --git a/component/tls/utls.go b/component/tls/utls.go index 2b33d323..d8ca9716 100644 --- a/component/tls/utls.go +++ b/component/tls/utls.go @@ -1,7 +1,10 @@ package tls import ( + "context" "net" + "reflect" + "unsafe" "github.com/metacubex/mihomo/common/once" "github.com/metacubex/mihomo/common/utils" @@ -126,8 +129,11 @@ type EncryptedClientHelloKey = utls.EncryptedClientHelloKey type Config = utls.Config +var tlsCertificateRequestInfoCtxOffset = utils.MustOK(reflect.TypeOf((*tls.CertificateRequestInfo)(nil)).Elem().FieldByName("ctx")).Offset +var tlsClientHelloInfoCtxOffset = utils.MustOK(reflect.TypeOf((*tls.ClientHelloInfo)(nil)).Elem().FieldByName("ctx")).Offset + func UConfig(config *tls.Config) *utls.Config { - return &utls.Config{ + cfg := &utls.Config{ Rand: config.Rand, Time: config.Time, Certificates: utils.Map(config.Certificates, UCertificate), @@ -147,6 +153,52 @@ func UConfig(config *tls.Config) *utls.Config { SessionTicketsDisabled: config.SessionTicketsDisabled, Renegotiation: utls.RenegotiationSupport(config.Renegotiation), } + if config.GetClientCertificate != nil { + cfg.GetClientCertificate = func(info *utls.CertificateRequestInfo) (*utls.Certificate, error) { + tlsInfo := &tls.CertificateRequestInfo{ + AcceptableCAs: info.AcceptableCAs, + SignatureSchemes: utils.Map(info.SignatureSchemes, func(it utls.SignatureScheme) tls.SignatureScheme { + return tls.SignatureScheme(it) + }), + Version: info.Version, + } + *(*context.Context)(unsafe.Add(unsafe.Pointer(tlsInfo), tlsCertificateRequestInfoCtxOffset)) = info.Context() // for tlsInfo.ctx + cert, err := config.GetClientCertificate(tlsInfo) + if err != nil { + return nil, err + } + uCert := UCertificate(*cert) + return &uCert, err + } + } + if config.GetCertificate != nil { + cfg.GetCertificate = func(info *utls.ClientHelloInfo) (*utls.Certificate, error) { + tlsInfo := &tls.ClientHelloInfo{ + CipherSuites: info.CipherSuites, + ServerName: info.ServerName, + SupportedCurves: utils.Map(info.SupportedCurves, func(it utls.CurveID) tls.CurveID { + return tls.CurveID(it) + }), + SupportedPoints: info.SupportedPoints, + SignatureSchemes: utils.Map(info.SignatureSchemes, func(it utls.SignatureScheme) tls.SignatureScheme { + return tls.SignatureScheme(it) + }), + SupportedProtos: info.SupportedProtos, + SupportedVersions: info.SupportedVersions, + Extensions: info.Extensions, + Conn: info.Conn, + //HelloRetryRequest: info.HelloRetryRequest, + } + *(*context.Context)(unsafe.Add(unsafe.Pointer(tlsInfo), tlsClientHelloInfoCtxOffset)) = info.Context() // for tlsInfo.ctx + cert, err := config.GetCertificate(tlsInfo) + if err != nil { + return nil, err + } + uCert := UCertificate(*cert) + return &uCert, err + } + } + return cfg } // BuildWebsocketHandshakeState it will only send http/1.1 in its ALPN. diff --git a/hub/route/server.go b/hub/route/server.go index 6c47d672..a27ee6a9 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -191,7 +191,7 @@ func startTLS(cfg *Config) { // handle tlsAddr if len(cfg.TLSAddr) > 0 { - cert, err := ca.LoadTLSKeyPair(cfg.Certificate, cfg.PrivateKey, C.Path) + certLoader, err := ca.NewTLSKeyPairLoader(cfg.Certificate, cfg.PrivateKey, C.Path) if err != nil { log.Errorln("External controller tls listen error: %s", err) return @@ -206,7 +206,9 @@ func startTLS(cfg *Config) { log.Infoln("RESTful API tls listening at: %s", l.Addr().String()) tlsConfig := &tls.Config{Time: ntp.Now} tlsConfig.NextProtos = []string{"h2", "http/1.1"} - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } tlsConfig.ClientAuth = ca.ClientAuthTypeFromString(cfg.ClientAuthType) if len(cfg.ClientAuthCert) > 0 { if tlsConfig.ClientAuth == tls.NoClientCert { diff --git a/listener/anytls/server.go b/listener/anytls/server.go index 731f1394..197aabd9 100644 --- a/listener/anytls/server.go +++ b/listener/anytls/server.go @@ -45,11 +45,13 @@ func New(config LC.AnyTLSServer, tunnel C.Tunnel, additions ...inbound.Addition) tlsConfig := &tls.Config{Time: ntp.Now} if config.Certificate != "" && config.PrivateKey != "" { - cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) + certLoader, err := ca.NewTLSKeyPairLoader(config.Certificate, config.PrivateKey, C.Path) if err != nil { return nil, err } - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } if config.EchKey != "" { err = ech.LoadECHKey(config.EchKey, tlsConfig, C.Path) @@ -108,7 +110,7 @@ func New(config LC.AnyTLSServer, tunnel C.Tunnel, additions ...inbound.Addition) if err != nil { return nil, err } - if len(tlsConfig.Certificates) > 0 { + if tlsConfig.GetCertificate != nil { l = tls.NewListener(l, tlsConfig) } else { return nil, errors.New("disallow using AnyTLS without certificates config") diff --git a/listener/http/server.go b/listener/http/server.go index 2aba6fda..66bf86c3 100644 --- a/listener/http/server.go +++ b/listener/http/server.go @@ -71,11 +71,13 @@ func NewWithConfig(config LC.AuthServer, tunnel C.Tunnel, additions ...inbound.A var realityBuilder *reality.Builder if config.Certificate != "" && config.PrivateKey != "" { - cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) + certLoader, err := ca.NewTLSKeyPairLoader(config.Certificate, config.PrivateKey, C.Path) if err != nil { return nil, err } - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } if config.EchKey != "" { err = ech.LoadECHKey(config.EchKey, tlsConfig, C.Path) @@ -98,7 +100,7 @@ func NewWithConfig(config LC.AuthServer, tunnel C.Tunnel, additions ...inbound.A tlsConfig.ClientCAs = pool } if config.RealityConfig.PrivateKey != "" { - if tlsConfig.Certificates != nil { + if tlsConfig.GetCertificate != nil { return nil, errors.New("certificate is unavailable in reality") } if tlsConfig.ClientAuth != tls.NoClientCert { @@ -112,7 +114,7 @@ func NewWithConfig(config LC.AuthServer, tunnel C.Tunnel, additions ...inbound.A if realityBuilder != nil { l = realityBuilder.NewListener(l) - } else if len(tlsConfig.Certificates) > 0 { + } else if tlsConfig.GetCertificate != nil { l = tls.NewListener(l, tlsConfig) } diff --git a/listener/mixed/mixed.go b/listener/mixed/mixed.go index 995822b0..1efbb40f 100644 --- a/listener/mixed/mixed.go +++ b/listener/mixed/mixed.go @@ -67,11 +67,13 @@ func NewWithConfig(config LC.AuthServer, tunnel C.Tunnel, additions ...inbound.A var realityBuilder *reality.Builder if config.Certificate != "" && config.PrivateKey != "" { - cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) + certLoader, err := ca.NewTLSKeyPairLoader(config.Certificate, config.PrivateKey, C.Path) if err != nil { return nil, err } - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } if config.EchKey != "" { err = ech.LoadECHKey(config.EchKey, tlsConfig, C.Path) @@ -94,7 +96,7 @@ func NewWithConfig(config LC.AuthServer, tunnel C.Tunnel, additions ...inbound.A tlsConfig.ClientCAs = pool } if config.RealityConfig.PrivateKey != "" { - if tlsConfig.Certificates != nil { + if tlsConfig.GetCertificate != nil { return nil, errors.New("certificate is unavailable in reality") } if tlsConfig.ClientAuth != tls.NoClientCert { @@ -108,7 +110,7 @@ func NewWithConfig(config LC.AuthServer, tunnel C.Tunnel, additions ...inbound.A if realityBuilder != nil { l = realityBuilder.NewListener(l) - } else if len(tlsConfig.Certificates) > 0 { + } else if tlsConfig.GetCertificate != nil { l = tls.NewListener(l, tlsConfig) } diff --git a/listener/sing_hysteria2/server.go b/listener/sing_hysteria2/server.go index becb06b1..94dd0db6 100644 --- a/listener/sing_hysteria2/server.go +++ b/listener/sing_hysteria2/server.go @@ -56,15 +56,17 @@ func New(config LC.Hysteria2Server, tunnel C.Tunnel, additions ...inbound.Additi sl = &Listener{false, config, nil, nil} - cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) - if err != nil { - return nil, err - } tlsConfig := &tls.Config{ Time: ntp.Now, MinVersion: tls.VersionTLS13, } - tlsConfig.Certificates = []tls.Certificate{cert} + certLoader, err := ca.NewTLSKeyPairLoader(config.Certificate, config.PrivateKey, C.Path) + if err != nil { + return nil, err + } + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } tlsConfig.ClientAuth = ca.ClientAuthTypeFromString(config.ClientAuthType) if len(config.ClientAuthCert) > 0 { if tlsConfig.ClientAuth == tls.NoClientCert { diff --git a/listener/sing_vless/server.go b/listener/sing_vless/server.go index 049f5eb1..10cb6e2c 100644 --- a/listener/sing_vless/server.go +++ b/listener/sing_vless/server.go @@ -81,11 +81,13 @@ func New(config LC.VlessServer, tunnel C.Tunnel, additions ...inbound.Addition) var httpServer http.Server if config.Certificate != "" && config.PrivateKey != "" { - cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) + certLoader, err := ca.NewTLSKeyPairLoader(config.Certificate, config.PrivateKey, C.Path) if err != nil { return nil, err } - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } if config.EchKey != "" { err = ech.LoadECHKey(config.EchKey, tlsConfig, C.Path) @@ -108,7 +110,7 @@ func New(config LC.VlessServer, tunnel C.Tunnel, additions ...inbound.Addition) tlsConfig.ClientCAs = pool } if config.RealityConfig.PrivateKey != "" { - if tlsConfig.Certificates != nil { + if tlsConfig.GetCertificate != nil { return nil, errors.New("certificate is unavailable in reality") } if tlsConfig.ClientAuth != tls.NoClientCert { @@ -153,7 +155,7 @@ func New(config LC.VlessServer, tunnel C.Tunnel, additions ...inbound.Addition) } if realityBuilder != nil { l = realityBuilder.NewListener(l) - } else if len(tlsConfig.Certificates) > 0 { + } else if tlsConfig.GetCertificate != nil { l = tls.NewListener(l, tlsConfig) } else if sl.decryption == nil { return nil, errors.New("disallow using Vless without any certificates/reality/decryption config") diff --git a/listener/sing_vmess/server.go b/listener/sing_vmess/server.go index 956aa708..5ca6a159 100644 --- a/listener/sing_vmess/server.go +++ b/listener/sing_vmess/server.go @@ -81,11 +81,13 @@ func New(config LC.VmessServer, tunnel C.Tunnel, additions ...inbound.Addition) var httpServer http.Server if config.Certificate != "" && config.PrivateKey != "" { - cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) + certLoader, err := ca.NewTLSKeyPairLoader(config.Certificate, config.PrivateKey, C.Path) if err != nil { return nil, err } - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } if config.EchKey != "" { err = ech.LoadECHKey(config.EchKey, tlsConfig, C.Path) @@ -108,7 +110,7 @@ func New(config LC.VmessServer, tunnel C.Tunnel, additions ...inbound.Addition) tlsConfig.ClientCAs = pool } if config.RealityConfig.PrivateKey != "" { - if tlsConfig.Certificates != nil { + if tlsConfig.GetCertificate != nil { return nil, errors.New("certificate is unavailable in reality") } if tlsConfig.ClientAuth != tls.NoClientCert { @@ -153,7 +155,7 @@ func New(config LC.VmessServer, tunnel C.Tunnel, additions ...inbound.Addition) } if realityBuilder != nil { l = realityBuilder.NewListener(l) - } else if len(tlsConfig.Certificates) > 0 { + } else if tlsConfig.GetCertificate != nil { l = tls.NewListener(l, tlsConfig) } sl.listeners = append(sl.listeners, l) diff --git a/listener/socks/tcp.go b/listener/socks/tcp.go index 55e9e594..60a34e1c 100644 --- a/listener/socks/tcp.go +++ b/listener/socks/tcp.go @@ -66,11 +66,13 @@ func NewWithConfig(config LC.AuthServer, tunnel C.Tunnel, additions ...inbound.A var realityBuilder *reality.Builder if config.Certificate != "" && config.PrivateKey != "" { - cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) + certLoader, err := ca.NewTLSKeyPairLoader(config.Certificate, config.PrivateKey, C.Path) if err != nil { return nil, err } - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } if config.EchKey != "" { err = ech.LoadECHKey(config.EchKey, tlsConfig, C.Path) @@ -93,7 +95,7 @@ func NewWithConfig(config LC.AuthServer, tunnel C.Tunnel, additions ...inbound.A tlsConfig.ClientCAs = pool } if config.RealityConfig.PrivateKey != "" { - if tlsConfig.Certificates != nil { + if tlsConfig.GetCertificate != nil { return nil, errors.New("certificate is unavailable in reality") } if tlsConfig.ClientAuth != tls.NoClientCert { @@ -107,7 +109,7 @@ func NewWithConfig(config LC.AuthServer, tunnel C.Tunnel, additions ...inbound.A if realityBuilder != nil { l = realityBuilder.NewListener(l) - } else if len(tlsConfig.Certificates) > 0 { + } else if tlsConfig.GetCertificate != nil { l = tls.NewListener(l, tlsConfig) } diff --git a/listener/trojan/server.go b/listener/trojan/server.go index 6155d209..a5e123d0 100644 --- a/listener/trojan/server.go +++ b/listener/trojan/server.go @@ -76,11 +76,13 @@ func New(config LC.TrojanServer, tunnel C.Tunnel, additions ...inbound.Addition) var httpServer http.Server if config.Certificate != "" && config.PrivateKey != "" { - cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) + certLoader, err := ca.NewTLSKeyPairLoader(config.Certificate, config.PrivateKey, C.Path) if err != nil { return nil, err } - tlsConfig.Certificates = []tls.Certificate{cert} + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } if config.EchKey != "" { err = ech.LoadECHKey(config.EchKey, tlsConfig, C.Path) @@ -103,7 +105,7 @@ func New(config LC.TrojanServer, tunnel C.Tunnel, additions ...inbound.Addition) tlsConfig.ClientCAs = pool } if config.RealityConfig.PrivateKey != "" { - if tlsConfig.Certificates != nil { + if tlsConfig.GetCertificate != nil { return nil, errors.New("certificate is unavailable in reality") } if tlsConfig.ClientAuth != tls.NoClientCert { @@ -148,7 +150,7 @@ func New(config LC.TrojanServer, tunnel C.Tunnel, additions ...inbound.Addition) } if realityBuilder != nil { l = realityBuilder.NewListener(l) - } else if len(tlsConfig.Certificates) > 0 { + } else if tlsConfig.GetCertificate != nil { l = tls.NewListener(l, tlsConfig) } else if !config.TrojanSSOption.Enabled { return nil, errors.New("disallow using Trojan without both certificates/reality/ss config") diff --git a/listener/tuic/server.go b/listener/tuic/server.go index 30845515..da492129 100644 --- a/listener/tuic/server.go +++ b/listener/tuic/server.go @@ -49,15 +49,17 @@ func New(config LC.TuicServer, tunnel C.Tunnel, additions ...inbound.Addition) ( return nil, err } - cert, err := ca.LoadTLSKeyPair(config.Certificate, config.PrivateKey, C.Path) - if err != nil { - return nil, err - } tlsConfig := &tls.Config{ Time: ntp.Now, MinVersion: tls.VersionTLS13, } - tlsConfig.Certificates = []tls.Certificate{cert} + certLoader, err := ca.NewTLSKeyPairLoader(config.Certificate, config.PrivateKey, C.Path) + if err != nil { + return nil, err + } + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return certLoader() + } tlsConfig.ClientAuth = ca.ClientAuthTypeFromString(config.ClientAuthType) if len(config.ClientAuthCert) > 0 { if tlsConfig.ClientAuth == tls.NoClientCert {