chore: allow automatic reloading when the TLS server's certificate, private-key or ech-key is a local file

This commit is contained in:
wwqgtxx 2025-12-19 18:51:14 +08:00
parent 93cf46e430
commit 4a723e8d3f
2 changed files with 56 additions and 7 deletions

View File

@ -12,10 +12,13 @@ import (
"fmt"
"math/big"
"os"
"runtime"
"sync"
"time"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/fswatch"
"github.com/metacubex/tls"
)
@ -49,7 +52,25 @@ func NewTLSKeyPairLoader(certificate, privateKey string) (func() (*tls.Certifica
if loadErr != nil {
return nil, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
}
gcFlag := new(os.File)
updateMutex := sync.RWMutex{}
if watcher, err := fswatch.NewWatcher(fswatch.Options{Path: []string{certificate, privateKey}, Callback: func(path string) {
updateMutex.Lock()
defer updateMutex.Unlock()
if newCert, err := tls.LoadX509KeyPair(certificate, privateKey); err == nil {
cert = newCert
}
}}); err == nil {
if err = watcher.Start(); err == nil {
runtime.SetFinalizer(gcFlag, func(f *os.File) {
_ = watcher.Close()
})
}
}
return func() (*tls.Certificate, error) {
defer runtime.KeepAlive(gcFlag)
updateMutex.RLock()
defer updateMutex.RUnlock()
return &cert, nil
}, nil
}

View File

@ -8,9 +8,12 @@ import (
"errors"
"fmt"
"os"
"runtime"
"sync"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/fswatch"
"github.com/metacubex/tls"
"golang.org/x/crypto/cryptobyte"
)
@ -108,8 +111,11 @@ func LoadECHKey(key string, tlsConfig *tls.Config) error {
if key == "" {
return nil
}
painTextErr := loadECHKey([]byte(key), tlsConfig)
echKeys, painTextErr := loadECHKey([]byte(key))
if painTextErr == nil {
tlsConfig.GetEncryptedClientHelloKeys = func(info *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
return echKeys, nil
}
return nil
}
key = C.Path.Resolve(key)
@ -120,24 +126,46 @@ func LoadECHKey(key string, tlsConfig *tls.Config) error {
var echKey []byte
echKey, loadErr = os.ReadFile(key)
if loadErr == nil {
loadErr = loadECHKey(echKey, tlsConfig)
echKeys, loadErr = loadECHKey(echKey)
}
}
if loadErr != nil {
return fmt.Errorf("parse ECH keys failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
}
gcFlag := new(os.File)
updateMutex := sync.RWMutex{}
if watcher, err := fswatch.NewWatcher(fswatch.Options{Path: []string{key}, Callback: func(path string) {
updateMutex.Lock()
defer updateMutex.Unlock()
if echKey, err := os.ReadFile(key); err == nil {
if newEchKeys, err := loadECHKey(echKey); err == nil {
echKeys = newEchKeys
}
}
}}); err == nil {
if err = watcher.Start(); err == nil {
runtime.SetFinalizer(gcFlag, func(f *os.File) {
_ = watcher.Close()
})
}
}
tlsConfig.GetEncryptedClientHelloKeys = func(info *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
defer runtime.KeepAlive(gcFlag)
updateMutex.RLock()
defer updateMutex.RUnlock()
return echKeys, nil
}
return nil
}
func loadECHKey(echKey []byte, tlsConfig *tls.Config) error {
func loadECHKey(echKey []byte) ([]tls.EncryptedClientHelloKey, error) {
block, rest := pem.Decode(echKey)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return errors.New("invalid ECH keys pem")
return nil, errors.New("invalid ECH keys pem")
}
echKeys, err := UnmarshalECHKeys(block.Bytes)
if err != nil {
return fmt.Errorf("parse ECH keys: %w", err)
return nil, fmt.Errorf("parse ECH keys: %w", err)
}
tlsConfig.EncryptedClientHelloKeys = echKeys
return nil
return echKeys, err
}