chore: allow automatic reloading when the TLS server's certificate, private-key or ech-key is a local file

This commit is contained in:
wwqgtxx 2025-12-19 18:51:14 +08:00
parent 93cf46e430
commit 4a723e8d3f
2 changed files with 56 additions and 7 deletions

View File

@ -12,10 +12,13 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"os" "os"
"runtime"
"sync"
"time" "time"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/fswatch"
"github.com/metacubex/tls" "github.com/metacubex/tls"
) )
@ -49,7 +52,25 @@ func NewTLSKeyPairLoader(certificate, privateKey string) (func() (*tls.Certifica
if loadErr != nil { if loadErr != nil {
return nil, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error()) return nil, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
} }
gcFlag := new(os.File)
updateMutex := sync.RWMutex{}
if watcher, err := fswatch.NewWatcher(fswatch.Options{Path: []string{certificate, privateKey}, Callback: func(path string) {
updateMutex.Lock()
defer updateMutex.Unlock()
if newCert, err := tls.LoadX509KeyPair(certificate, privateKey); err == nil {
cert = newCert
}
}}); err == nil {
if err = watcher.Start(); err == nil {
runtime.SetFinalizer(gcFlag, func(f *os.File) {
_ = watcher.Close()
})
}
}
return func() (*tls.Certificate, error) { return func() (*tls.Certificate, error) {
defer runtime.KeepAlive(gcFlag)
updateMutex.RLock()
defer updateMutex.RUnlock()
return &cert, nil return &cert, nil
}, nil }, nil
} }

View File

@ -8,9 +8,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"runtime"
"sync"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/fswatch"
"github.com/metacubex/tls" "github.com/metacubex/tls"
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
) )
@ -108,8 +111,11 @@ func LoadECHKey(key string, tlsConfig *tls.Config) error {
if key == "" { if key == "" {
return nil return nil
} }
painTextErr := loadECHKey([]byte(key), tlsConfig) echKeys, painTextErr := loadECHKey([]byte(key))
if painTextErr == nil { if painTextErr == nil {
tlsConfig.GetEncryptedClientHelloKeys = func(info *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
return echKeys, nil
}
return nil return nil
} }
key = C.Path.Resolve(key) key = C.Path.Resolve(key)
@ -120,24 +126,46 @@ func LoadECHKey(key string, tlsConfig *tls.Config) error {
var echKey []byte var echKey []byte
echKey, loadErr = os.ReadFile(key) echKey, loadErr = os.ReadFile(key)
if loadErr == nil { if loadErr == nil {
loadErr = loadECHKey(echKey, tlsConfig) echKeys, loadErr = loadECHKey(echKey)
} }
} }
if loadErr != nil { if loadErr != nil {
return fmt.Errorf("parse ECH keys failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error()) return fmt.Errorf("parse ECH keys failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
} }
gcFlag := new(os.File)
updateMutex := sync.RWMutex{}
if watcher, err := fswatch.NewWatcher(fswatch.Options{Path: []string{key}, Callback: func(path string) {
updateMutex.Lock()
defer updateMutex.Unlock()
if echKey, err := os.ReadFile(key); err == nil {
if newEchKeys, err := loadECHKey(echKey); err == nil {
echKeys = newEchKeys
}
}
}}); err == nil {
if err = watcher.Start(); err == nil {
runtime.SetFinalizer(gcFlag, func(f *os.File) {
_ = watcher.Close()
})
}
}
tlsConfig.GetEncryptedClientHelloKeys = func(info *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
defer runtime.KeepAlive(gcFlag)
updateMutex.RLock()
defer updateMutex.RUnlock()
return echKeys, nil
}
return nil return nil
} }
func loadECHKey(echKey []byte, tlsConfig *tls.Config) error { func loadECHKey(echKey []byte) ([]tls.EncryptedClientHelloKey, error) {
block, rest := pem.Decode(echKey) block, rest := pem.Decode(echKey)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 { if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return errors.New("invalid ECH keys pem") return nil, errors.New("invalid ECH keys pem")
} }
echKeys, err := UnmarshalECHKeys(block.Bytes) echKeys, err := UnmarshalECHKeys(block.Bytes)
if err != nil { if err != nil {
return fmt.Errorf("parse ECH keys: %w", err) return nil, fmt.Errorf("parse ECH keys: %w", err)
} }
tlsConfig.EncryptedClientHelloKeys = echKeys return echKeys, err
return nil
} }